From 679a809f66e3b1f90da1f370ca776a8f9ee727df Mon Sep 17 00:00:00 2001 From: Robert Helewka Date: Tue, 5 May 2026 12:09:51 -0400 Subject: [PATCH] Fix bearer forwarding across anyio TaskGroup boundary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Mnemosyne Authorization: Bearer token was being dropped on outbound MCP calls because fast-agent runs downstream transports inside a long-lived anyio TaskGroup whose context is snapshotted at manager startup — request_bearer_token.get() inside _prepare_headers_and_auth therefore always resolved to None even when the request handler had just set it. Fix: * pallas/_fastagent_patch.py - add _pending_bearers registry keyed by id(server_config) with a threading.Lock; publish_bearer / revoke_bearer helpers. - patched _prepare_headers_and_auth reads the registry first, falls back to the ContextVar for non-persistent probe paths. - emit INFO log on install() so the journal shows the patch ran; verbose flow logs at DEBUG on pallas.forward. * pallas/multimodal_server.py - send_message resolves the agent's opted-in downstreams, publishes the inbound bearer for each, and revokes them all in the finally. - bearer/header diagnostics go to pallas.auth (DEBUG) instead of /tmp/pallas-bearer.log which is invisible under systemd PrivateTmp. * pallas/log.py - honour PALLAS_LOG_LEVEL env var (default INFO) so operators can flip the forward/auth diagnostics on without a code change. * docs/pallas.md, docs/mnemosyne_integration.md - document the registry-based forwarding and the task-group ContextVar constraint that forced it. --- docs/mnemosyne_integration.md | 9 ++- docs/pallas.md | 15 +++- pallas/_fastagent_patch.py | 148 ++++++++++++++++++++++++++++------ pallas/log.py | 12 ++- pallas/multimodal_server.py | 84 +++++++++++++++---- 5 files changed, 220 insertions(+), 48 deletions(-) diff --git a/docs/mnemosyne_integration.md b/docs/mnemosyne_integration.md index c60a5e6..8d81764 100644 --- a/docs/mnemosyne_integration.md +++ b/docs/mnemosyne_integration.md @@ -74,11 +74,14 @@ async def _shawn(): 2. Daedalus calls Pallas's `send_message` tool with `Authorization: Bearer ` in the HTTP request headers. -3. Pallas's `MultimodalAgentMCPServer` captures the token via FastMCP's `get_access_token()` into the `request_bearer_token` context variable (see `pallas/multimodal_server.py`). +3. Pallas's `MultimodalAgentMCPServer` captures the token by reading the request's `Authorization` header directly through `fastmcp.server.dependencies.get_http_request()` — `get_access_token()` returns `None` because Pallas runs without the FastMCP auth middleware. The token is pushed into the `request_bearer_token` ContextVar (for LLM-provider passthrough) and **also** registered in a per-request bearer registry keyed by each opted-in downstream's `MCPServerSettings` object. -4. The fast-agent patch in `pallas/_fastagent_patch.py` (installed at import time in `pallas/__init__.py`) wraps `_prepare_headers_and_auth`. When a server config has `forward_inbound_auth: true`, the patch reads `request_bearer_token.get()` and injects `Authorization: Bearer ` into the outgoing HTTP headers for that MCP call. +4. The fast-agent patch in `pallas/_fastagent_patch.py` (installed at import time in `pallas/__init__.py`) wraps `_prepare_headers_and_auth`. When a server config has `forward_inbound_auth: true`, the patch reads the bearer out of the per-request registry (with the ContextVar as a fallback) and injects `Authorization: Bearer ` into the outgoing HTTP headers for that MCP call. The registry is required because fast-agent's `MCPConnectionManager` runs the transport in its own anyio `TaskGroup`, which does not inherit the request handler's `contextvars.Context`. + +5. The request handler's `finally` clause revokes every bearer it published, so per-request tokens never outlive the call and no stale credentials can be reused. + +6. Mnemosyne receives the same token, validates the HMAC signature against its `MCPSigningKey` table, and scopes all search Cypher queries to `ws` from the claims. -5. Mnemosyne receives the same token, validates the HMAC signature against its `MCPSigningKey` table, and scopes all search Cypher queries to `ws` from the claims. The `forward_inbound_auth` flag is **per-server** — other servers in the same agent (`argos`, `neo4j_cypher`, `time`, etc.) never receive the bearer. diff --git a/docs/pallas.md b/docs/pallas.md index 8848fc2..0033dde 100644 --- a/docs/pallas.md +++ b/docs/pallas.md @@ -417,10 +417,19 @@ For agents with `instance_scope != "request"`, a `{agent}_history` prompt is reg ### Bearer Token Propagation -The server captures the authenticated bearer token from the incoming MCP request into the `request_bearer_token` context variable. Two consumers read it: +The server captures the authenticated bearer token from the incoming MCP request's `Authorization: Bearer …` header via `fastmcp.server.dependencies.get_http_request()` (FastMCP's `get_access_token()` returns `None` because Pallas runs without the auth middleware). Two consumers read it: + +- **LLM-provider passthrough** — the token is also pushed into the `request_bearer_token` ContextVar for the agent's LLM provider key manager to pick up automatically (used by HuggingFace and any other token-passthrough providers). The ContextVar works here because the LLM call runs in a child task of the request handler. +- **Downstream MCP servers (opt-in)** — outgoing MCP calls inherit the same bearer when the downstream server is marked `forward_inbound_auth: true` in `fastagent.config.yaml`. Without that flag, the inbound bearer is **not** forwarded to MCP transport calls — `server_config.headers` is the only header source. + +The forwarding is per-server so a FastAgent attached to both a credentialed downstream (e.g. Mnemosyne) and an unrelated public server doesn't leak the bearer to the latter. + +#### Why a simple ContextVar forward isn't enough + +fast-agent's `MCPConnectionManager` runs each downstream transport inside a long-lived `anyio.TaskGroup` created at manager startup. `TaskGroup.start_soon` snapshots the owner's `contextvars.Context` at spawn time — the request-handler's context is invisible to the transport task. A straight `request_bearer_token.get()` inside `_prepare_headers_and_auth` therefore always resolves to `None` even when the inbound handler has `set` the token a few frames up. The persistent connection is additionally reused across requests, so the first-call context (often empty) would be cached forever. + +Pallas works around this in `pallas._fastagent_patch` by maintaining a process-wide `_pending_bearers` registry keyed by `id(server_config)`. `multimodal_server.send_message` calls `publish_bearer(cfg, token)` for every opted-in downstream the agent is allowed to reach; the patched `_prepare_headers_and_auth` looks it up there (with the ContextVar as a fallback for non-persistent probe paths); and the request handler's `finally` block calls `revoke_bearer(cfg)` to clear the entry. Per-request bearers therefore survive the task-group boundary without any mutation of shared config. -- **LLM-provider passthrough** — the agent's LLM provider key manager picks it up automatically (used by HuggingFace and any other token-passthrough providers). -- **Downstream MCP servers (opt-in)** — outgoing MCP calls inherit the same bearer when the downstream server is marked `forward_inbound_auth: true` in `fastagent.config.yaml`. Without that flag, `request_bearer_token` is **not** forwarded to MCP transport calls — `server_config.headers` is the only header source. This is implemented as a fast-agent monkey-patch in `pallas._fastagent_patch` and is per-server so a FastAgent attached to both a credentialed downstream (e.g. Mnemosyne) and an unrelated public server doesn't leak the bearer to the latter. Example: diff --git a/pallas/_fastagent_patch.py b/pallas/_fastagent_patch.py index 731ab13..3cdccbe 100644 --- a/pallas/_fastagent_patch.py +++ b/pallas/_fastagent_patch.py @@ -1,16 +1,36 @@ """Forward the inbound bearer token to opted-in downstream MCP servers. -fast-agent (≤0.6.19) captures the inbound `Authorization: Bearer ` into -the `request_bearer_token` ContextVar but does NOT propagate it to -outgoing MCP transport calls — `_prepare_headers_and_auth` only reads -`server_config.headers`. This patch wraps that function so a -downstream server marked `forward_inbound_auth: true` in -fastagent.config.yaml receives the same bearer the FastAgent itself -was called with. +fast-agent (≤0.6.19) captures the inbound ``Authorization: Bearer `` into +the ``request_bearer_token`` ContextVar, but does NOT propagate that value to +outgoing MCP transport calls — ``_prepare_headers_and_auth`` only reads +``server_config.headers``. This module patches ``_prepare_headers_and_auth`` +so a downstream server marked ``forward_inbound_auth: true`` in +``fastagent.config.yaml`` receives the same bearer the FastAgent itself was +called with. Opt-in is per-server because a FastAgent with multiple downstream MCP -attachments (e.g. Mnemosyne + a public weather server) must not leak -its credentials to every endpoint. +attachments (e.g. Mnemosyne + a public weather server) must not leak its +credentials to every endpoint. + +Why the simple "read from ``request_bearer_token``" approach does NOT work +---------------------------------------------------------------------------- +``MCPConnectionManager.launch_server`` spawns the server's transport task in +``self._tg`` — a long-lived ``anyio.TaskGroup`` created at manager startup. +``TaskGroup.start_soon`` copies the owning task's ``contextvars.Context`` at +spawn time, which is the *startup* context, not the per-request context. +The transport-preparation code therefore sees ``request_bearer_token.get()`` +as ``None`` even when the MCP request handler has just ``set`` it a few +frames up. Worse, ``launch_server`` runs once per downstream and the +persistent connection is reused, so the very first request's (often +empty) context is cached forever. + +The fix is to hand the bearer through the only object that *is* shared +between the two tasks: the ``MCPServerSettings`` instance that both paths +pass into ``_prepare_headers_and_auth``. ``pallas.multimodal_server`` +registers the inbound bearer against ``id(server_config)`` in a +process-wide registry for the duration of each MCP request; this patch +reads it there and forges an ``Authorization`` header onto the outgoing +transport. Cleanup is guaranteed in the request handler's ``finally``. TODO: drop after the equivalent change lands in fast-agent upstream. """ @@ -18,6 +38,8 @@ TODO: drop after the equivalent change lands in fast-agent upstream. from __future__ import annotations import logging +import threading +from typing import Any from fast_agent.config import MCPServerSettings as _MCPServerSettings from fast_agent.mcp import mcp_connection_manager as _mcm @@ -25,49 +47,119 @@ from fast_agent.mcp.auth.context import request_bearer_token logger = logging.getLogger("pallas.forward") -_AUTH_HEADER_KEYS = {"authorization", "x-hf-authorization"} _original_prepare = _mcm._prepare_headers_and_auth +# ── Per-request bearer registry ────────────────────────────────────────────── +# Keyed by ``id(server_config)`` so a request handler can publish the bearer +# that applies to each opted-in downstream server. The registry survives the +# context-var-loss hop across anyio task groups because ``id()`` is stable and +# the config object itself is held by fast-agent's ServerRegistry. +# +# A threading.Lock (not asyncio) is used because both the publishing side +# (request handler) and the reading side (``_prepare_headers_and_auth``, run +# inside the connection manager's task group) may execute on different anyio +# worker threads under uvicorn's default thread-portal setup. Access is +# microsecond-scoped — no contention concerns. +_pending_bearers: dict[int, str] = {} +_pending_lock = threading.Lock() -def _diag_write(line: str) -> None: - """Append a diagnostic line to /tmp/pallas-bearer.log, never raises.""" + +def publish_bearer(server_config: Any, token: str) -> None: + """Register ``token`` as the inbound bearer to forward to this server. + + Called by ``pallas.multimodal_server.send_message`` for every downstream + whose config carries ``forward_inbound_auth: true``. Must be paired with + ``revoke_bearer`` in the same ``try/finally``. + """ + if not token: + return + with _pending_lock: + _pending_bearers[id(server_config)] = token + logger.debug( + "forward.published server=%s token_len=%d prefix=%s", + getattr(server_config, "name", "?"), + len(token), + token[:8], + ) + + +def revoke_bearer(server_config: Any) -> None: + """Clear any bearer previously published for ``server_config``. + + Always safe to call — a missing key is silently ignored, so request + handlers can ``finally: revoke_bearer(cfg)`` without pre-checks. + """ + with _pending_lock: + _pending_bearers.pop(id(server_config), None) + logger.debug( + "forward.revoked server=%s", + getattr(server_config, "name", "?"), + ) + + +def _lookup_bearer(server_config: Any) -> str | None: + """Resolve the bearer to forward for ``server_config``. + + Tries the per-request registry first (works across task groups) and + falls back to the ContextVar for cases where the caller lives in the + same task (e.g. fast-agent's own non-persistent probe path). + """ + with _pending_lock: + token = _pending_bearers.get(id(server_config)) + if token: + return token try: - from datetime import datetime - with open("/tmp/pallas-bearer.log", "a") as f: - f.write(f"{datetime.now().isoformat()} {line}\n") - except Exception: - pass + return request_bearer_token.get() + except LookupError: + return None def _prepare_headers_and_auth_with_forward(server_config, **kwargs): headers, oauth_auth, user_auth_keys = _original_prepare(server_config, **kwargs) - server_name = getattr(server_config, "name", None) + server_name = getattr(server_config, "name", None) or "?" forward_flag = getattr(server_config, "forward_inbound_auth", False) - _diag_write(f"FORWARD check server={server_name} flag={forward_flag}") + logger.debug( + "forward.check server=%s forward_inbound_auth=%s", + server_name, + forward_flag, + ) if not forward_flag: return headers, oauth_auth, user_auth_keys if user_auth_keys: - _diag_write(f"FORWARD skipped_user_auth server={server_name}") + logger.debug( + "forward.skipped server=%s reason=user_auth_present keys=%s", + server_name, + sorted(user_auth_keys), + ) return headers, oauth_auth, user_auth_keys if oauth_auth is not None: - _diag_write(f"FORWARD skipped_oauth server={server_name}") + logger.debug( + "forward.skipped server=%s reason=oauth_active", + server_name, + ) return headers, oauth_auth, user_auth_keys - inbound = request_bearer_token.get() + inbound = _lookup_bearer(server_config) if not inbound: - _diag_write(f"FORWARD no_inbound server={server_name}") + logger.debug( + "forward.skipped server=%s reason=no_inbound_bearer", + server_name, + ) return headers, oauth_auth, user_auth_keys headers = dict(headers) headers["Authorization"] = f"Bearer {inbound}" user_auth_keys = set(user_auth_keys) | {"Authorization"} - _diag_write( - f"FORWARD applied server={server_name} token_len={len(inbound)} prefix={inbound[:8]}" + logger.debug( + "forward.applied server=%s token_len=%d prefix=%s", + server_name, + len(inbound), + inbound[:8], ) return headers, oauth_auth, user_auth_keys @@ -90,3 +182,9 @@ def install() -> None: _allow_extras_on_server_settings() _prepare_headers_and_auth_with_forward._pallas_forward_patched = True # type: ignore[attr-defined] _mcm._prepare_headers_and_auth = _prepare_headers_and_auth_with_forward + # INFO so it always appears in the journal at boot — greppable proof + # that the patch ran before any agent started. + logger.info( + "bearer-forwarding patch installed " + "(forward_inbound_auth-aware _prepare_headers_and_auth)" + ) diff --git a/pallas/log.py b/pallas/log.py index 15b94db..d7df4cf 100644 --- a/pallas/log.py +++ b/pallas/log.py @@ -17,6 +17,7 @@ Level conventions (Ouranos Lab — Python services use UPPERCASE): import json import logging +import os import re from datetime import datetime, timezone @@ -70,15 +71,22 @@ class _HealthAccessFilter(logging.Filter): def setup_logging() -> None: """Configure Pallas logging. - - ``pallas.*`` logger: INFO, JSON to stdout + - ``pallas.*`` logger: INFO by default, JSON to stdout. Override via the + ``PALLAS_LOG_LEVEL`` environment variable (``DEBUG``, ``INFO``, + ``WARNING``, ``ERROR``). DEBUG unlocks bearer-forwarding diagnostics + on the ``pallas.forward`` and ``pallas.auth`` loggers — essential when + troubleshooting Mnemosyne / workspace-scoped agent calls. - ``httpx`` / ``httpcore``: WARNING (prevent request-level debug flooding) - ``uvicorn.access``: health path filter applied """ handler = logging.StreamHandler() handler.setFormatter(_JSONFormatter()) + level_name = os.environ.get("PALLAS_LOG_LEVEL", "INFO").upper() + level = getattr(logging, level_name, logging.INFO) + pallas_logger = logging.getLogger("pallas") - pallas_logger.setLevel(logging.INFO) + pallas_logger.setLevel(level) if not pallas_logger.handlers: pallas_logger.addHandler(handler) pallas_logger.propagate = False diff --git a/pallas/multimodal_server.py b/pallas/multimodal_server.py index 2e7c1fe..b18c5a3 100644 --- a/pallas/multimodal_server.py +++ b/pallas/multimodal_server.py @@ -19,6 +19,7 @@ fast-agent instance whose ``message_history`` is seeded from the caller's memory, no restart amnesia. """ +import logging import time from typing import Any @@ -28,6 +29,7 @@ from fast_agent.mcp.auth.context import request_bearer_token from fast_agent.mcp.server import AgentMCPServer from fast_agent.types import PromptMessageExtended, RequestParams +from pallas._fastagent_patch import publish_bearer, revoke_bearer from pallas.progress import EnrichedMCPToolProgressManager from fastmcp import Context as MCPContext from fastmcp.prompts import Message @@ -37,15 +39,11 @@ from starlette.responses import JSONResponse, Response logger = get_logger(__name__) - -def _diag_write(line: str) -> None: - """Append a diagnostic line to /tmp/pallas-bearer.log, never raises.""" - try: - from datetime import datetime - with open("/tmp/pallas-bearer.log", "a") as f: - f.write(f"{datetime.now().isoformat()} {line}\n") - except Exception: - pass +# Separate stdlib logger for bearer-token diagnostics — routed through +# ``pallas.log`` JSON handler to stdout / systemd journal. Gated at DEBUG so +# it is off by default in production but trivially flipped on via +# ``PALLAS_LOG_LEVEL=DEBUG`` for troubleshooting agent auth issues. +_auth_log = logging.getLogger("pallas.auth") def _get_request_bearer_token() -> str | None: @@ -54,7 +52,8 @@ def _get_request_bearer_token() -> str | None: Reads the header directly rather than going through get_access_token() because Pallas runs without FastMCP auth middleware — there is no AuthenticatedUser in the request scope, so get_access_token() always returns None here. The token - is an opaque string forwarded to opted-in downstream servers by _fastagent_patch. + is an opaque string forwarded to opted-in downstream servers by + ``pallas._fastagent_patch``. """ try: from fastmcp.server.dependencies import get_http_request @@ -63,16 +62,48 @@ def _get_request_bearer_token() -> str | None: auth = request.headers.get("authorization", "") if auth.lower().startswith("bearer "): token = auth[7:] - _diag_write( - f"BEARER captured len={len(token)} prefix={token[:8]}" + _auth_log.debug( + "bearer.captured len=%d prefix=%s", len(token), token[:8] ) return token - _diag_write(f"BEARER absent has_auth={bool(auth)}") + _auth_log.debug("bearer.absent has_auth=%s", bool(auth)) except Exception as exc: - _diag_write(f"BEARER error={exc}") + _auth_log.debug("bearer.error %s", exc) return None +def _forwardable_server_configs(agent) -> list[Any]: + """Return the ``MCPServerSettings`` objects the agent is entitled to and + which are marked ``forward_inbound_auth: true``. + + Restricting to the agent's own ``servers`` list ensures a request never + publishes its bearer against a server the calling agent does not use — + e.g. a Harper→Mnemosyne call must not flag Scotty→Mnemosyne's config. + + Safe to call before the agent is constructed: returns an empty list if + any attribute lookup fails. + """ + try: + agent_servers = set(getattr(agent.config, "servers", []) or []) + if not agent_servers: + return [] + registry = agent.context.server_registry + if registry is None: + return [] + configs: list[Any] = [] + for name in agent_servers: + cfg = registry.registry.get(name) + if cfg is None: + continue + if getattr(cfg, "forward_inbound_auth", False): + configs.append(cfg) + return configs + except Exception as exc: + _auth_log.debug("bearer.registry_lookup_failed %s", exc) + return [] + + + def _history_to_fastmcp_messages( message_history: list[PromptMessageExtended], ) -> list[Message]: @@ -237,17 +268,34 @@ class MultimodalAgentMCPServer(AgentMCPServer): Optional opaque identifier, logged for trace correlation. Pallas does not interpret it. """ - saved_token = request_bearer_token.set(_get_request_bearer_token()) + inbound_bearer = _get_request_bearer_token() + saved_token = request_bearer_token.set(inbound_bearer) report_progress = self._build_progress_reporter(ctx) request_params = RequestParams( tool_execution_handler=EnrichedMCPToolProgressManager(report_progress), emit_loop_progress=True, ) + # Track which downstream server configs we publish the bearer + # against so the ``finally`` block below can revoke every one of + # them even if the agent send raises halfway through. + published_configs: list[Any] = [] try: instance = await self._acquire_instance(ctx) agent = instance.app[agent_name] agent_context = getattr(agent, "context", None) + # Register the inbound bearer against each downstream server + # config the agent is allowed to reach and which opts-in via + # ``forward_inbound_auth: true``. This is how the bearer + # crosses the anyio task-group boundary that ContextVars + # cannot hop — see ``pallas._fastagent_patch`` for the + # full explanation. + if inbound_bearer: + for srv_cfg in _forwardable_server_configs(agent): + publish_bearer(srv_cfg, inbound_bearer) + published_configs.append(srv_cfg) + + # Seed the freshly-created instance's message_history from the # caller-supplied history so the agent sees the full # conversation the caller is tracking. Safe no-op when the @@ -311,8 +359,14 @@ class MultimodalAgentMCPServer(AgentMCPServer): finally: await self._release_instance(ctx, instance) finally: + # Always revoke every bearer we published, then restore the + # ContextVar — order matters only for tidiness; a revoke that + # finds nothing is a no-op. + for srv_cfg in published_configs: + revoke_bearer(srv_cfg) request_bearer_token.reset(saved_token) + if self._instance_scope == "request": # With request-scoped instances there is no persistent server-side # history to expose — the caller owns it. We still register the