From f634cc55d8550a7a80724e2887d53d12c8c3b52d Mon Sep 17 00:00:00 2001 From: Robert Helewka Date: Tue, 5 May 2026 20:57:06 -0400 Subject: [PATCH] forward: use httpx.Auth so per-turn bearer survives persistent MCP connections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous static-header approach only ran at handshake time, and persistent MCP connections reuse the open socket for every subsequent tools/call. The first startup probe had no bearer, so every later tool call inherited an empty Authorization header — Mnemosyne saw no credentials and returned 'Authentication required'. Fix: swap the static header for a _DynamicBearerAuth(httpx.Auth) that httpx consults per-request via async_auth_flow. We look up the current _pending_bearers entry for this server_config and stamp Authorization on each outgoing request individually — no stale caching, no handshake/tool-call skew. Verified chain now runs: bearer.captured (inbound) forward.published (registry key) forward.bound (auth object installed at connect time) forward.applied (stamped per request via async_auth_flow) --- pallas/_fastagent_patch.py | 102 ++++++++++++++++++++++++++++++++----- 1 file changed, 88 insertions(+), 14 deletions(-) diff --git a/pallas/_fastagent_patch.py b/pallas/_fastagent_patch.py index f7415b9..c6001ec 100644 --- a/pallas/_fastagent_patch.py +++ b/pallas/_fastagent_patch.py @@ -43,11 +43,68 @@ import threading from pathlib import Path from typing import Any +import httpx + from fast_agent.mcp import mcp_connection_manager as _mcm from fast_agent.mcp.auth.context import request_bearer_token logger = logging.getLogger("pallas.forward") + +class _DynamicBearerAuth(httpx.Auth): + """Per-request ``Authorization`` injection for persistent MCP connections. + + fast-agent's ``create_mcp_http_client(headers=..., auth=...)`` snapshots + the ``headers`` dict at client construction time — every subsequent + ``tools/call`` reuses the *same* open connection, so a static + ``Authorization`` header set at handshake is the only one the downstream + server ever sees. For workspace-scoped forwarding that's fatal: the + first request (often a startup probe) has no bearer, and every later + request that *does* carry a bearer inherits the probe's empty header. + + httpx's ``auth`` parameter, however, is consulted on **every** outgoing + request via ``Auth.sync_auth_flow`` / ``async_auth_flow``. We use that + to look up the current ``_pending_bearers`` entry for ``server_config`` + and stamp ``Authorization`` onto each request individually — no stale + caching, no handshake/tool-call skew. + """ + + # Per-connection-reuse so ``httpx.AsyncClient`` can share us across + # streams; the lookup is keyed by ``id(server_config)`` so different + # servers (even same-named clones) stay isolated. + requires_request_body = False + requires_response_body = False + + def __init__(self, server_config: Any) -> None: + self._server_config = server_config + self._server_name = getattr(server_config, "name", "?") + + def _current_token(self) -> str | None: + return _lookup_bearer(self._server_config) + + def sync_auth_flow(self, request: httpx.Request): + token = self._current_token() + if token: + request.headers["Authorization"] = f"Bearer {token}" + logger.debug( + "forward.applied server=%s token_len=%d prefix=%s via=auth_flow", + self._server_name, + len(token), + token[:8], + ) + else: + logger.debug( + "forward.skipped server=%s reason=no_inbound_bearer via=auth_flow", + self._server_name, + ) + yield request + + async def async_auth_flow(self, request: httpx.Request): + # httpx calls this generator per-request; all our work is + # pre-response so we yield once and we're done. + for item in self.sync_auth_flow(request): + yield item + # ── Opt-in server names discovered from raw YAML ────────────────────────────── # Fast-agent's ``Settings(**merged)`` pipeline silently discards unknown keys # on nested ``MCPServerSettings`` instances — even with ``extra="allow"`` set @@ -158,24 +215,41 @@ def _prepare_headers_and_auth_with_forward(server_config, **kwargs): ) return headers, oauth_auth, user_auth_keys - inbound = _lookup_bearer(server_config) - if not inbound: - logger.debug( - "forward.skipped server=%s reason=no_inbound_bearer", - server_name, - ) - return headers, oauth_auth, user_auth_keys - - headers = dict(headers) - headers["Authorization"] = f"Bearer {inbound}" + # Install a dynamic ``httpx.Auth`` instead of baking a static header into + # the returned ``headers`` dict. Fast-agent passes the auth object to + # ``create_mcp_http_client(auth=...)`` which forwards to + # ``httpx.AsyncClient(auth=...)``; httpx then consults it on every + # outgoing request via ``async_auth_flow``, reading the *current* + # ``_pending_bearers`` entry. + # + # This dodges the fatal "first handshake wins forever" problem: + # persistent MCP connections reuse the open socket across hundreds of + # tool-call requests, but the auth flow re-runs per request, so we can + # stamp the correct per-turn bearer onto each ``tools/call`` even though + # the initial ``initialize`` ran with no bearer at startup. + # + # We also report it through ``user_auth_keys`` so OAuth scrubbing (see + # ``_prepare_headers_and_auth`` upstream) treats Authorization as + # caller-owned and doesn't try to kick off an OAuth flow. + auth = _DynamicBearerAuth(server_config) user_auth_keys = set(user_auth_keys) | {"Authorization"} logger.debug( - "forward.applied server=%s token_len=%d prefix=%s", + "forward.bound server=%s auth=%s", server_name, - len(inbound), - inbound[:8], + type(auth).__name__, ) - return headers, oauth_auth, user_auth_keys + # Current token may or may not be set — we don't require one at bind + # time because the auth flow will resolve it per-request; logging a + # preview when available helps trace the startup probe path. + inbound = _lookup_bearer(server_config) + if inbound: + logger.debug( + "forward.applied server=%s token_len=%d prefix=%s via=bind", + server_name, + len(inbound), + inbound[:8], + ) + return headers, auth, user_auth_keys def _candidate_config_paths() -> list[Path]: