forward: use httpx.Auth so per-turn bearer survives persistent MCP connections

The previous static-header approach only ran at handshake time, and
persistent MCP connections reuse the open socket for every subsequent
tools/call.  The first startup probe had no bearer, so every later
tool call inherited an empty Authorization header — Mnemosyne saw
no credentials and returned 'Authentication required'.

Fix: swap the static header for a _DynamicBearerAuth(httpx.Auth) that
httpx consults per-request via async_auth_flow.  We look up the current
_pending_bearers entry for this server_config and stamp Authorization
on each outgoing request individually — no stale caching, no
handshake/tool-call skew.

Verified chain now runs:
  bearer.captured  (inbound)
  forward.published (registry key)
  forward.bound     (auth object installed at connect time)
  forward.applied   (stamped per request via async_auth_flow)
This commit is contained in:
2026-05-05 20:57:06 -04:00
parent 711f54395d
commit f634cc55d8

View File

@@ -43,11 +43,68 @@ import threading
from pathlib import Path
from typing import Any
import httpx
from fast_agent.mcp import mcp_connection_manager as _mcm
from fast_agent.mcp.auth.context import request_bearer_token
logger = logging.getLogger("pallas.forward")
class _DynamicBearerAuth(httpx.Auth):
"""Per-request ``Authorization`` injection for persistent MCP connections.
fast-agent's ``create_mcp_http_client(headers=..., auth=...)`` snapshots
the ``headers`` dict at client construction time — every subsequent
``tools/call`` reuses the *same* open connection, so a static
``Authorization`` header set at handshake is the only one the downstream
server ever sees. For workspace-scoped forwarding that's fatal: the
first request (often a startup probe) has no bearer, and every later
request that *does* carry a bearer inherits the probe's empty header.
httpx's ``auth`` parameter, however, is consulted on **every** outgoing
request via ``Auth.sync_auth_flow`` / ``async_auth_flow``. We use that
to look up the current ``_pending_bearers`` entry for ``server_config``
and stamp ``Authorization`` onto each request individually — no stale
caching, no handshake/tool-call skew.
"""
# Per-connection-reuse so ``httpx.AsyncClient`` can share us across
# streams; the lookup is keyed by ``id(server_config)`` so different
# servers (even same-named clones) stay isolated.
requires_request_body = False
requires_response_body = False
def __init__(self, server_config: Any) -> None:
self._server_config = server_config
self._server_name = getattr(server_config, "name", "?")
def _current_token(self) -> str | None:
return _lookup_bearer(self._server_config)
def sync_auth_flow(self, request: httpx.Request):
token = self._current_token()
if token:
request.headers["Authorization"] = f"Bearer {token}"
logger.debug(
"forward.applied server=%s token_len=%d prefix=%s via=auth_flow",
self._server_name,
len(token),
token[:8],
)
else:
logger.debug(
"forward.skipped server=%s reason=no_inbound_bearer via=auth_flow",
self._server_name,
)
yield request
async def async_auth_flow(self, request: httpx.Request):
# httpx calls this generator per-request; all our work is
# pre-response so we yield once and we're done.
for item in self.sync_auth_flow(request):
yield item
# ── Opt-in server names discovered from raw YAML ──────────────────────────────
# Fast-agent's ``Settings(**merged)`` pipeline silently discards unknown keys
# on nested ``MCPServerSettings`` instances — even with ``extra="allow"`` set
@@ -158,24 +215,41 @@ def _prepare_headers_and_auth_with_forward(server_config, **kwargs):
)
return headers, oauth_auth, user_auth_keys
inbound = _lookup_bearer(server_config)
if not inbound:
logger.debug(
"forward.skipped server=%s reason=no_inbound_bearer",
server_name,
)
return headers, oauth_auth, user_auth_keys
headers = dict(headers)
headers["Authorization"] = f"Bearer {inbound}"
# Install a dynamic ``httpx.Auth`` instead of baking a static header into
# the returned ``headers`` dict. Fast-agent passes the auth object to
# ``create_mcp_http_client(auth=...)`` which forwards to
# ``httpx.AsyncClient(auth=...)``; httpx then consults it on every
# outgoing request via ``async_auth_flow``, reading the *current*
# ``_pending_bearers`` entry.
#
# This dodges the fatal "first handshake wins forever" problem:
# persistent MCP connections reuse the open socket across hundreds of
# tool-call requests, but the auth flow re-runs per request, so we can
# stamp the correct per-turn bearer onto each ``tools/call`` even though
# the initial ``initialize`` ran with no bearer at startup.
#
# We also report it through ``user_auth_keys`` so OAuth scrubbing (see
# ``_prepare_headers_and_auth`` upstream) treats Authorization as
# caller-owned and doesn't try to kick off an OAuth flow.
auth = _DynamicBearerAuth(server_config)
user_auth_keys = set(user_auth_keys) | {"Authorization"}
logger.debug(
"forward.applied server=%s token_len=%d prefix=%s",
"forward.bound server=%s auth=%s",
server_name,
len(inbound),
inbound[:8],
type(auth).__name__,
)
return headers, oauth_auth, user_auth_keys
# Current token may or may not be set — we don't require one at bind
# time because the auth flow will resolve it per-request; logging a
# preview when available helps trace the startup probe path.
inbound = _lookup_bearer(server_config)
if inbound:
logger.debug(
"forward.applied server=%s token_len=%d prefix=%s via=bind",
server_name,
len(inbound),
inbound[:8],
)
return headers, auth, user_auth_keys
def _candidate_config_paths() -> list[Path]: