forward: use httpx.Auth so per-turn bearer survives persistent MCP connections
The previous static-header approach only ran at handshake time, and persistent MCP connections reuse the open socket for every subsequent tools/call. The first startup probe had no bearer, so every later tool call inherited an empty Authorization header — Mnemosyne saw no credentials and returned 'Authentication required'. Fix: swap the static header for a _DynamicBearerAuth(httpx.Auth) that httpx consults per-request via async_auth_flow. We look up the current _pending_bearers entry for this server_config and stamp Authorization on each outgoing request individually — no stale caching, no handshake/tool-call skew. Verified chain now runs: bearer.captured (inbound) forward.published (registry key) forward.bound (auth object installed at connect time) forward.applied (stamped per request via async_auth_flow)
This commit is contained in:
@@ -43,11 +43,68 @@ import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from fast_agent.mcp import mcp_connection_manager as _mcm
|
||||
from fast_agent.mcp.auth.context import request_bearer_token
|
||||
|
||||
logger = logging.getLogger("pallas.forward")
|
||||
|
||||
|
||||
class _DynamicBearerAuth(httpx.Auth):
|
||||
"""Per-request ``Authorization`` injection for persistent MCP connections.
|
||||
|
||||
fast-agent's ``create_mcp_http_client(headers=..., auth=...)`` snapshots
|
||||
the ``headers`` dict at client construction time — every subsequent
|
||||
``tools/call`` reuses the *same* open connection, so a static
|
||||
``Authorization`` header set at handshake is the only one the downstream
|
||||
server ever sees. For workspace-scoped forwarding that's fatal: the
|
||||
first request (often a startup probe) has no bearer, and every later
|
||||
request that *does* carry a bearer inherits the probe's empty header.
|
||||
|
||||
httpx's ``auth`` parameter, however, is consulted on **every** outgoing
|
||||
request via ``Auth.sync_auth_flow`` / ``async_auth_flow``. We use that
|
||||
to look up the current ``_pending_bearers`` entry for ``server_config``
|
||||
and stamp ``Authorization`` onto each request individually — no stale
|
||||
caching, no handshake/tool-call skew.
|
||||
"""
|
||||
|
||||
# Per-connection-reuse so ``httpx.AsyncClient`` can share us across
|
||||
# streams; the lookup is keyed by ``id(server_config)`` so different
|
||||
# servers (even same-named clones) stay isolated.
|
||||
requires_request_body = False
|
||||
requires_response_body = False
|
||||
|
||||
def __init__(self, server_config: Any) -> None:
|
||||
self._server_config = server_config
|
||||
self._server_name = getattr(server_config, "name", "?")
|
||||
|
||||
def _current_token(self) -> str | None:
|
||||
return _lookup_bearer(self._server_config)
|
||||
|
||||
def sync_auth_flow(self, request: httpx.Request):
|
||||
token = self._current_token()
|
||||
if token:
|
||||
request.headers["Authorization"] = f"Bearer {token}"
|
||||
logger.debug(
|
||||
"forward.applied server=%s token_len=%d prefix=%s via=auth_flow",
|
||||
self._server_name,
|
||||
len(token),
|
||||
token[:8],
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"forward.skipped server=%s reason=no_inbound_bearer via=auth_flow",
|
||||
self._server_name,
|
||||
)
|
||||
yield request
|
||||
|
||||
async def async_auth_flow(self, request: httpx.Request):
|
||||
# httpx calls this generator per-request; all our work is
|
||||
# pre-response so we yield once and we're done.
|
||||
for item in self.sync_auth_flow(request):
|
||||
yield item
|
||||
|
||||
# ── Opt-in server names discovered from raw YAML ──────────────────────────────
|
||||
# Fast-agent's ``Settings(**merged)`` pipeline silently discards unknown keys
|
||||
# on nested ``MCPServerSettings`` instances — even with ``extra="allow"`` set
|
||||
@@ -158,24 +215,41 @@ def _prepare_headers_and_auth_with_forward(server_config, **kwargs):
|
||||
)
|
||||
return headers, oauth_auth, user_auth_keys
|
||||
|
||||
inbound = _lookup_bearer(server_config)
|
||||
if not inbound:
|
||||
logger.debug(
|
||||
"forward.skipped server=%s reason=no_inbound_bearer",
|
||||
server_name,
|
||||
)
|
||||
return headers, oauth_auth, user_auth_keys
|
||||
|
||||
headers = dict(headers)
|
||||
headers["Authorization"] = f"Bearer {inbound}"
|
||||
# Install a dynamic ``httpx.Auth`` instead of baking a static header into
|
||||
# the returned ``headers`` dict. Fast-agent passes the auth object to
|
||||
# ``create_mcp_http_client(auth=...)`` which forwards to
|
||||
# ``httpx.AsyncClient(auth=...)``; httpx then consults it on every
|
||||
# outgoing request via ``async_auth_flow``, reading the *current*
|
||||
# ``_pending_bearers`` entry.
|
||||
#
|
||||
# This dodges the fatal "first handshake wins forever" problem:
|
||||
# persistent MCP connections reuse the open socket across hundreds of
|
||||
# tool-call requests, but the auth flow re-runs per request, so we can
|
||||
# stamp the correct per-turn bearer onto each ``tools/call`` even though
|
||||
# the initial ``initialize`` ran with no bearer at startup.
|
||||
#
|
||||
# We also report it through ``user_auth_keys`` so OAuth scrubbing (see
|
||||
# ``_prepare_headers_and_auth`` upstream) treats Authorization as
|
||||
# caller-owned and doesn't try to kick off an OAuth flow.
|
||||
auth = _DynamicBearerAuth(server_config)
|
||||
user_auth_keys = set(user_auth_keys) | {"Authorization"}
|
||||
logger.debug(
|
||||
"forward.applied server=%s token_len=%d prefix=%s",
|
||||
"forward.bound server=%s auth=%s",
|
||||
server_name,
|
||||
type(auth).__name__,
|
||||
)
|
||||
# Current token may or may not be set — we don't require one at bind
|
||||
# time because the auth flow will resolve it per-request; logging a
|
||||
# preview when available helps trace the startup probe path.
|
||||
inbound = _lookup_bearer(server_config)
|
||||
if inbound:
|
||||
logger.debug(
|
||||
"forward.applied server=%s token_len=%d prefix=%s via=bind",
|
||||
server_name,
|
||||
len(inbound),
|
||||
inbound[:8],
|
||||
)
|
||||
return headers, oauth_auth, user_auth_keys
|
||||
return headers, auth, user_auth_keys
|
||||
|
||||
|
||||
def _candidate_config_paths() -> list[Path]:
|
||||
|
||||
Reference in New Issue
Block a user