Fix bearer forwarding across anyio TaskGroup boundary

The Mnemosyne Authorization: Bearer token was being dropped on outbound MCP
calls. fast-agent runs downstream transports inside a long-lived anyio
TaskGroup whose context is snapshotted at manager startup, so
request_bearer_token.get() inside _prepare_headers_and_auth always resolved
to None, even when the request handler had just set it.

Fix:
* pallas/_fastagent_patch.py
    - add _pending_bearers registry keyed by id(server_config) with a
      threading.Lock; publish_bearer / revoke_bearer helpers.
    - patched _prepare_headers_and_auth reads the registry first, falls
      back to the ContextVar for non-persistent probe paths.
    - emit INFO log on install() so the journal shows the patch ran;
      verbose flow logs at DEBUG on pallas.forward.

* pallas/multimodal_server.py
    - send_message resolves the agent's opted-in downstreams, publishes
      the inbound bearer for each, and revokes them all in the finally.
    - bearer/header diagnostics go to pallas.auth (DEBUG) instead of
      /tmp/pallas-bearer.log which is invisible under systemd PrivateTmp.

* pallas/log.py
    - honour PALLAS_LOG_LEVEL env var (default INFO) so operators can
      flip the forward/auth diagnostics on without a code change.

* docs/pallas.md, docs/mnemosyne_integration.md
    - document the registry-based forwarding and the task-group
      ContextVar constraint that forced it.
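
The ContextVar constraint described above can be reproduced in isolation. The following is a minimal sketch and not code from this commit: it uses plain `asyncio` from the standard library instead of anyio (anyio's asyncio backend should behave the same way, but that is an assumption here), and the `worker` and `demo` names are hypothetical.

```python
import asyncio
import contextvars

# Stand-in for fast-agent's ContextVar of the same name.
request_bearer_token = contextvars.ContextVar("request_bearer_token", default=None)


async def worker(inbox: asyncio.Queue, outbox: asyncio.Queue) -> None:
    # Long-lived task: its Context was copied when the task was created,
    # so later set() calls made by other tasks are not visible here.
    while True:
        await inbox.get()
        outbox.put_nowait(request_bearer_token.get())


async def demo():
    inbox: asyncio.Queue = asyncio.Queue()
    outbox: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(worker(inbox, outbox))  # context copied here
    request_bearer_token.set("secret-token")  # the "request handler" sets it later
    await inbox.put(None)  # ask the worker what it can see
    seen = await outbox.get()
    task.cancel()
    return seen


seen_by_worker = asyncio.run(demo())
```

The worker reports the value from its own copied context, which is why the fix routes the bearer through shared state rather than through the ContextVar.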
2026-05-05 12:09:51 -04:00
parent 24c7374f3d
commit 679a809f66
5 changed files with 220 additions and 48 deletions


@@ -74,11 +74,14 @@ async def _shawn():
 2. Daedalus calls Pallas's `send_message` tool with `Authorization: Bearer <token>` in the HTTP request headers.
-3. Pallas's `MultimodalAgentMCPServer` captures the token via FastMCP's `get_access_token()` into the `request_bearer_token` context variable (see `pallas/multimodal_server.py`).
+3. Pallas's `MultimodalAgentMCPServer` captures the token by reading the request's `Authorization` header directly through `fastmcp.server.dependencies.get_http_request()` — `get_access_token()` returns `None` because Pallas runs without the FastMCP auth middleware. The token is pushed into the `request_bearer_token` ContextVar (for LLM-provider passthrough) and **also** registered in a per-request bearer registry keyed by each opted-in downstream's `MCPServerSettings` object.
-4. The fast-agent patch in `pallas/_fastagent_patch.py` (installed at import time in `pallas/__init__.py`) wraps `_prepare_headers_and_auth`. When a server config has `forward_inbound_auth: true`, the patch reads `request_bearer_token.get()` and injects `Authorization: Bearer <token>` into the outgoing HTTP headers for that MCP call.
+4. The fast-agent patch in `pallas/_fastagent_patch.py` (installed at import time in `pallas/__init__.py`) wraps `_prepare_headers_and_auth`. When a server config has `forward_inbound_auth: true`, the patch reads the bearer out of the per-request registry (with the ContextVar as a fallback) and injects `Authorization: Bearer <token>` into the outgoing HTTP headers for that MCP call. The registry is required because fast-agent's `MCPConnectionManager` runs the transport in its own anyio `TaskGroup`, which does not inherit the request handler's `contextvars.Context`.
+5. The request handler's `finally` clause revokes every bearer it published, so per-request tokens never outlive the call and no stale credentials can be reused.
+6. Mnemosyne receives the same token, validates the HMAC signature against its `MCPSigningKey` table, and scopes all search Cypher queries to `ws` from the claims.
-5. Mnemosyne receives the same token, validates the HMAC signature against its `MCPSigningKey` table, and scopes all search Cypher queries to `ws` from the claims.
 The `forward_inbound_auth` flag is **per-server** — other servers in the same agent (`argos`, `neo4j_cypher`, `time`, etc.) never receive the bearer.


@@ -417,10 +417,19 @@ For agents with `instance_scope != "request"`, a `{agent}_history` prompt is reg
 ### Bearer Token Propagation
-The server captures the authenticated bearer token from the incoming MCP request into the `request_bearer_token` context variable. Two consumers read it:
+The server captures the authenticated bearer token from the incoming MCP request's `Authorization: Bearer …` header via `fastmcp.server.dependencies.get_http_request()` (FastMCP's `get_access_token()` returns `None` because Pallas runs without the auth middleware). Two consumers read it:
+- **LLM-provider passthrough** — the token is also pushed into the `request_bearer_token` ContextVar for the agent's LLM provider key manager to pick up automatically (used by HuggingFace and any other token-passthrough providers). The ContextVar works here because the LLM call runs in a child task of the request handler.
+- **Downstream MCP servers (opt-in)** — outgoing MCP calls inherit the same bearer when the downstream server is marked `forward_inbound_auth: true` in `fastagent.config.yaml`. Without that flag, the inbound bearer is **not** forwarded to MCP transport calls — `server_config.headers` is the only header source.
+The forwarding is per-server so a FastAgent attached to both a credentialed downstream (e.g. Mnemosyne) and an unrelated public server doesn't leak the bearer to the latter.
+#### Why a simple ContextVar forward isn't enough
+fast-agent's `MCPConnectionManager` runs each downstream transport inside a long-lived `anyio.TaskGroup` created at manager startup. `TaskGroup.start_soon` snapshots the owner's `contextvars.Context` at spawn time — the request-handler's context is invisible to the transport task. A straight `request_bearer_token.get()` inside `_prepare_headers_and_auth` therefore always resolves to `None` even when the inbound handler has `set` the token a few frames up. The persistent connection is additionally reused across requests, so the first-call context (often empty) would be cached forever.
+Pallas works around this in `pallas._fastagent_patch` by maintaining a process-wide `_pending_bearers` registry keyed by `id(server_config)`. `multimodal_server.send_message` calls `publish_bearer(cfg, token)` for every opted-in downstream the agent is allowed to reach; the patched `_prepare_headers_and_auth` looks it up there (with the ContextVar as a fallback for non-persistent probe paths); and the request handler's `finally` block calls `revoke_bearer(cfg)` to clear the entry. Per-request bearers therefore survive the task-group boundary without any mutation of shared config.
-- **LLM-provider passthrough** — the agent's LLM provider key manager picks it up automatically (used by HuggingFace and any other token-passthrough providers).
-- **Downstream MCP servers (opt-in)** — outgoing MCP calls inherit the same bearer when the downstream server is marked `forward_inbound_auth: true` in `fastagent.config.yaml`. Without that flag, `request_bearer_token` is **not** forwarded to MCP transport calls — `server_config.headers` is the only header source. This is implemented as a fast-agent monkey-patch in `pallas._fastagent_patch` and is per-server so a FastAgent attached to both a credentialed downstream (e.g. Mnemosyne) and an unrelated public server doesn't leak the bearer to the latter.
 Example:


@@ -1,16 +1,36 @@
 """Forward the inbound bearer token to opted-in downstream MCP servers.
-fast-agent (≤0.6.19) captures the inbound `Authorization: Bearer <X>` into
-the `request_bearer_token` ContextVar but does NOT propagate it to
-outgoing MCP transport calls — `_prepare_headers_and_auth` only reads
-`server_config.headers`. This patch wraps that function so a
-downstream server marked `forward_inbound_auth: true` in
-fastagent.config.yaml receives the same bearer the FastAgent itself
-was called with.
+fast-agent (≤0.6.19) captures the inbound ``Authorization: Bearer <X>`` into
+the ``request_bearer_token`` ContextVar, but does NOT propagate that value to
+outgoing MCP transport calls — ``_prepare_headers_and_auth`` only reads
+``server_config.headers``. This module patches ``_prepare_headers_and_auth``
+so a downstream server marked ``forward_inbound_auth: true`` in
+``fastagent.config.yaml`` receives the same bearer the FastAgent itself was
+called with.
 Opt-in is per-server because a FastAgent with multiple downstream MCP
-attachments (e.g. Mnemosyne + a public weather server) must not leak
-its credentials to every endpoint.
+attachments (e.g. Mnemosyne + a public weather server) must not leak its
+credentials to every endpoint.
+Why the simple "read from ``request_bearer_token``" approach does NOT work
+----------------------------------------------------------------------------
+``MCPConnectionManager.launch_server`` spawns the server's transport task in
+``self._tg`` — a long-lived ``anyio.TaskGroup`` created at manager startup.
+``TaskGroup.start_soon`` copies the owning task's ``contextvars.Context`` at
+spawn time, which is the *startup* context, not the per-request context.
+The transport-preparation code therefore sees ``request_bearer_token.get()``
+as ``None`` even when the MCP request handler has just ``set`` it a few
+frames up. Worse, ``launch_server`` runs once per downstream and the
+persistent connection is reused, so the very first request's (often
+empty) context is cached forever.
+The fix is to hand the bearer through the only object that *is* shared
+between the two tasks: the ``MCPServerSettings`` instance that both paths
+pass into ``_prepare_headers_and_auth``. ``pallas.multimodal_server``
+registers the inbound bearer against ``id(server_config)`` in a
+process-wide registry for the duration of each MCP request; this patch
+reads it there and forges an ``Authorization`` header onto the outgoing
+transport. Cleanup is guaranteed in the request handler's ``finally``.
 TODO: drop after the equivalent change lands in fast-agent upstream.
 """
@@ -18,6 +38,8 @@ TODO: drop after the equivalent change lands in fast-agent upstream.
 from __future__ import annotations
 import logging
+import threading
+from typing import Any
 from fast_agent.config import MCPServerSettings as _MCPServerSettings
 from fast_agent.mcp import mcp_connection_manager as _mcm
@@ -25,49 +47,119 @@ from fast_agent.mcp.auth.context import request_bearer_token
 logger = logging.getLogger("pallas.forward")
_AUTH_HEADER_KEYS = {"authorization", "x-hf-authorization"}
 _original_prepare = _mcm._prepare_headers_and_auth
+# ── Per-request bearer registry ──────────────────────────────────────────────
+# Keyed by ``id(server_config)`` so a request handler can publish the bearer
+# that applies to each opted-in downstream server. The registry survives the
+# context-var-loss hop across anyio task groups because ``id()`` is stable and
+# the config object itself is held by fast-agent's ServerRegistry.
+#
+# A threading.Lock (not asyncio) is used because both the publishing side
+# (request handler) and the reading side (``_prepare_headers_and_auth``, run
+# inside the connection manager's task group) may execute on different anyio
+# worker threads under uvicorn's default thread-portal setup. Access is
+# microsecond-scoped — no contention concerns.
+_pending_bearers: dict[int, str] = {}
+_pending_lock = threading.Lock()
-def _diag_write(line: str) -> None:
-    """Append a diagnostic line to /tmp/pallas-bearer.log, never raises."""
+def publish_bearer(server_config: Any, token: str) -> None:
+    """Register ``token`` as the inbound bearer to forward to this server.
+    Called by ``pallas.multimodal_server.send_message`` for every downstream
+    whose config carries ``forward_inbound_auth: true``. Must be paired with
+    ``revoke_bearer`` in the same ``try/finally``.
+    """
+    if not token:
+        return
+    with _pending_lock:
+        _pending_bearers[id(server_config)] = token
+    logger.debug(
+        "forward.published server=%s token_len=%d prefix=%s",
+        getattr(server_config, "name", "?"),
+        len(token),
+        token[:8],
+    )
+def revoke_bearer(server_config: Any) -> None:
+    """Clear any bearer previously published for ``server_config``.
+    Always safe to call — a missing key is silently ignored, so request
+    handlers can ``finally: revoke_bearer(cfg)`` without pre-checks.
+    """
+    with _pending_lock:
+        _pending_bearers.pop(id(server_config), None)
+    logger.debug(
+        "forward.revoked server=%s",
+        getattr(server_config, "name", "?"),
+    )
+def _lookup_bearer(server_config: Any) -> str | None:
+    """Resolve the bearer to forward for ``server_config``.
+    Tries the per-request registry first (works across task groups) and
+    falls back to the ContextVar for cases where the caller lives in the
+    same task (e.g. fast-agent's own non-persistent probe path).
+    """
+    with _pending_lock:
+        token = _pending_bearers.get(id(server_config))
+    if token:
+        return token
     try:
-        from datetime import datetime
-        with open("/tmp/pallas-bearer.log", "a") as f:
-            f.write(f"{datetime.now().isoformat()} {line}\n")
-    except Exception:
-        pass
+        return request_bearer_token.get()
+    except LookupError:
+        return None
 def _prepare_headers_and_auth_with_forward(server_config, **kwargs):
     headers, oauth_auth, user_auth_keys = _original_prepare(server_config, **kwargs)
-    server_name = getattr(server_config, "name", None)
+    server_name = getattr(server_config, "name", None) or "?"
     forward_flag = getattr(server_config, "forward_inbound_auth", False)
-    _diag_write(f"FORWARD check server={server_name} flag={forward_flag}")
+    logger.debug(
+        "forward.check server=%s forward_inbound_auth=%s",
+        server_name,
+        forward_flag,
+    )
     if not forward_flag:
         return headers, oauth_auth, user_auth_keys
     if user_auth_keys:
-        _diag_write(f"FORWARD skipped_user_auth server={server_name}")
+        logger.debug(
+            "forward.skipped server=%s reason=user_auth_present keys=%s",
+            server_name,
+            sorted(user_auth_keys),
+        )
         return headers, oauth_auth, user_auth_keys
     if oauth_auth is not None:
-        _diag_write(f"FORWARD skipped_oauth server={server_name}")
+        logger.debug(
+            "forward.skipped server=%s reason=oauth_active",
+            server_name,
+        )
         return headers, oauth_auth, user_auth_keys
-    inbound = request_bearer_token.get()
+    inbound = _lookup_bearer(server_config)
     if not inbound:
-        _diag_write(f"FORWARD no_inbound server={server_name}")
+        logger.debug(
+            "forward.skipped server=%s reason=no_inbound_bearer",
+            server_name,
+        )
         return headers, oauth_auth, user_auth_keys
     headers = dict(headers)
     headers["Authorization"] = f"Bearer {inbound}"
     user_auth_keys = set(user_auth_keys) | {"Authorization"}
-    _diag_write(
-        f"FORWARD applied server={server_name} token_len={len(inbound)} prefix={inbound[:8]}"
+    logger.debug(
+        "forward.applied server=%s token_len=%d prefix=%s",
+        server_name,
+        len(inbound),
+        inbound[:8],
     )
     return headers, oauth_auth, user_auth_keys
@@ -90,3 +182,9 @@ def install() -> None:
     _allow_extras_on_server_settings()
     _prepare_headers_and_auth_with_forward._pallas_forward_patched = True  # type: ignore[attr-defined]
     _mcm._prepare_headers_and_auth = _prepare_headers_and_auth_with_forward
+    # INFO so it always appears in the journal at boot — greppable proof
+    # that the patch ran before any agent started.
+    logger.info(
+        "bearer-forwarding patch installed "
+        "(forward_inbound_auth-aware _prepare_headers_and_auth)"
+    )
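
The publish / look up / revoke cycle in this file can be exercised on its own. The sketch below is an illustration under stated assumptions, not the module's code: `FakeConfig` stands in for `MCPServerSettings`, and the `forwarded_bearer` context manager is a hypothetical helper that wraps the `try/finally` pairing the request handler performs by hand.

```python
from __future__ import annotations

import threading
from contextlib import contextmanager
from typing import Any, Iterable, Iterator

# Process-wide registry: id(config object) -> bearer token.
_pending: dict[int, str] = {}
_lock = threading.Lock()


def publish(cfg: Any, token: str) -> None:
    with _lock:
        _pending[id(cfg)] = token


def revoke(cfg: Any) -> None:
    with _lock:
        _pending.pop(id(cfg), None)  # a missing key is ignored


def lookup(cfg: Any) -> str | None:
    with _lock:
        return _pending.get(id(cfg))


@contextmanager
def forwarded_bearer(configs: Iterable[Any], token: str) -> Iterator[None]:
    """Hypothetical helper: publish for the block's duration, then revoke."""
    published: list[Any] = []
    try:
        for cfg in configs:
            publish(cfg, token)
            published.append(cfg)
        yield
    finally:
        # Revoke only what was actually published, even if the body raised.
        for cfg in published:
            revoke(cfg)


class FakeConfig:
    """Stand-in for MCPServerSettings (assumption for this sketch)."""


cfg = FakeConfig()
with forwarded_bearer([cfg], "tok-123"):
    inside = lookup(cfg)
after = lookup(cfg)
```

A registry keyed only by config identity holds one token per config object at a time, so overlapping requests that share a config object would overwrite each other's entry. Whether that can occur depends on how instances are scoped.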


@@ -17,6 +17,7 @@ Level conventions (Ouranos Lab — Python services use UPPERCASE):
 import json
 import logging
+import os
 import re
 from datetime import datetime, timezone
@@ -70,15 +71,22 @@ class _HealthAccessFilter(logging.Filter):
 def setup_logging() -> None:
     """Configure Pallas logging.
-    - ``pallas.*`` logger: INFO, JSON to stdout
+    - ``pallas.*`` logger: INFO by default, JSON to stdout. Override via the
+      ``PALLAS_LOG_LEVEL`` environment variable (``DEBUG``, ``INFO``,
+      ``WARNING``, ``ERROR``). DEBUG unlocks bearer-forwarding diagnostics
+      on the ``pallas.forward`` and ``pallas.auth`` loggers — essential when
+      troubleshooting Mnemosyne / workspace-scoped agent calls.
     - ``httpx`` / ``httpcore``: WARNING (prevent request-level debug flooding)
     - ``uvicorn.access``: health path filter applied
     """
     handler = logging.StreamHandler()
     handler.setFormatter(_JSONFormatter())
+    level_name = os.environ.get("PALLAS_LOG_LEVEL", "INFO").upper()
+    level = getattr(logging, level_name, logging.INFO)
     pallas_logger = logging.getLogger("pallas")
-    pallas_logger.setLevel(logging.INFO)
+    pallas_logger.setLevel(level)
     if not pallas_logger.handlers:
         pallas_logger.addHandler(handler)
     pallas_logger.propagate = False
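
The level lookup above accepts any attribute name the `logging` module exposes, including non-level attributes such as `BASIC_FORMAT`, which is a string. A slightly stricter variant is sketched below as one possible hardening. It is not part of this commit, and `resolve_log_level` is a hypothetical name.

```python
from __future__ import annotations

import logging
import os
from typing import Mapping


def resolve_log_level(
    env: Mapping[str, str] | None = None,
    default: int = logging.INFO,
) -> int:
    """Map PALLAS_LOG_LEVEL to a numeric level, falling back to ``default``."""
    env = os.environ if env is None else env
    name = env.get("PALLAS_LOG_LEVEL", "").strip().upper()
    level = getattr(logging, name, None) if name else None
    # Only accept real numeric levels; anything else falls back to the default.
    return level if isinstance(level, int) else default


debug_level = resolve_log_level({"PALLAS_LOG_LEVEL": "debug"})
fallback_level = resolve_log_level({"PALLAS_LOG_LEVEL": "BASIC_FORMAT"})
```

Passing the environment as a mapping keeps the function easy to test without mutating `os.environ`.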


@@ -19,6 +19,7 @@ fast-agent instance whose ``message_history`` is seeded from the caller's
 memory, no restart amnesia.
 """
+import logging
 import time
 from typing import Any
@@ -28,6 +29,7 @@ from fast_agent.mcp.auth.context import request_bearer_token
 from fast_agent.mcp.server import AgentMCPServer
 from fast_agent.types import PromptMessageExtended, RequestParams
+from pallas._fastagent_patch import publish_bearer, revoke_bearer
 from pallas.progress import EnrichedMCPToolProgressManager
 from fastmcp import Context as MCPContext
 from fastmcp.prompts import Message
@@ -37,15 +39,11 @@ from starlette.responses import JSONResponse, Response
 logger = get_logger(__name__)
-def _diag_write(line: str) -> None:
-    """Append a diagnostic line to /tmp/pallas-bearer.log, never raises."""
-    try:
-        from datetime import datetime
-        with open("/tmp/pallas-bearer.log", "a") as f:
-            f.write(f"{datetime.now().isoformat()} {line}\n")
-    except Exception:
-        pass
+# Separate stdlib logger for bearer-token diagnostics — routed through
+# ``pallas.log`` JSON handler to stdout / systemd journal. Gated at DEBUG so
+# it is off by default in production but trivially flipped on via
+# ``PALLAS_LOG_LEVEL=DEBUG`` for troubleshooting agent auth issues.
+_auth_log = logging.getLogger("pallas.auth")
 def _get_request_bearer_token() -> str | None:
@@ -54,7 +52,8 @@ def _get_request_bearer_token() -> str | None:
     Reads the header directly rather than going through get_access_token() because
     Pallas runs without FastMCP auth middleware — there is no AuthenticatedUser in
     the request scope, so get_access_token() always returns None here. The token
-    is an opaque string forwarded to opted-in downstream servers by _fastagent_patch.
+    is an opaque string forwarded to opted-in downstream servers by
+    ``pallas._fastagent_patch``.
     """
     try:
         from fastmcp.server.dependencies import get_http_request
@@ -63,16 +62,48 @@ def _get_request_bearer_token() -> str | None:
         auth = request.headers.get("authorization", "")
         if auth.lower().startswith("bearer "):
             token = auth[7:]
-            _diag_write(
-                f"BEARER captured len={len(token)} prefix={token[:8]}"
+            _auth_log.debug(
+                "bearer.captured len=%d prefix=%s", len(token), token[:8]
             )
             return token
-        _diag_write(f"BEARER absent has_auth={bool(auth)}")
+        _auth_log.debug("bearer.absent has_auth=%s", bool(auth))
     except Exception as exc:
-        _diag_write(f"BEARER error={exc}")
+        _auth_log.debug("bearer.error %s", exc)
     return None
+def _forwardable_server_configs(agent) -> list[Any]:
+    """Return the ``MCPServerSettings`` objects the agent is entitled to and
+    which are marked ``forward_inbound_auth: true``.
+    Restricting to the agent's own ``servers`` list ensures a request never
+    publishes its bearer against a server the calling agent does not use —
+    e.g. a Harper→Mnemosyne call must not flag Scotty→Mnemosyne's config.
+    Safe to call before the agent is constructed: returns an empty list if
+    any attribute lookup fails.
+    """
+    try:
+        agent_servers = set(getattr(agent.config, "servers", []) or [])
+        if not agent_servers:
+            return []
+        registry = agent.context.server_registry
+        if registry is None:
+            return []
+        configs: list[Any] = []
+        for name in agent_servers:
+            cfg = registry.registry.get(name)
+            if cfg is None:
+                continue
+            if getattr(cfg, "forward_inbound_auth", False):
+                configs.append(cfg)
+        return configs
+    except Exception as exc:
+        _auth_log.debug("bearer.registry_lookup_failed %s", exc)
+        return []
 def _history_to_fastmcp_messages(
     message_history: list[PromptMessageExtended],
 ) -> list[Message]:
@@ -237,17 +268,34 @@ class MultimodalAgentMCPServer(AgentMCPServer):
                 Optional opaque identifier, logged for trace correlation.
                 Pallas does not interpret it.
             """
-            saved_token = request_bearer_token.set(_get_request_bearer_token())
+            inbound_bearer = _get_request_bearer_token()
+            saved_token = request_bearer_token.set(inbound_bearer)
             report_progress = self._build_progress_reporter(ctx)
             request_params = RequestParams(
                 tool_execution_handler=EnrichedMCPToolProgressManager(report_progress),
                 emit_loop_progress=True,
             )
+            # Track which downstream server configs we publish the bearer
+            # against so the ``finally`` block below can revoke every one of
+            # them even if the agent send raises halfway through.
+            published_configs: list[Any] = []
             try:
                 instance = await self._acquire_instance(ctx)
                 agent = instance.app[agent_name]
                 agent_context = getattr(agent, "context", None)
+                # Register the inbound bearer against each downstream server
+                # config the agent is allowed to reach and which opts-in via
+                # ``forward_inbound_auth: true``. This is how the bearer
+                # crosses the anyio task-group boundary that ContextVars
+                # cannot hop — see ``pallas._fastagent_patch`` for the
+                # full explanation.
+                if inbound_bearer:
+                    for srv_cfg in _forwardable_server_configs(agent):
+                        publish_bearer(srv_cfg, inbound_bearer)
+                        published_configs.append(srv_cfg)
                 # Seed the freshly-created instance's message_history from the
                 # caller-supplied history so the agent sees the full
                 # conversation the caller is tracking. Safe no-op when the
@@ -311,8 +359,14 @@ class MultimodalAgentMCPServer(AgentMCPServer):
                 finally:
                     await self._release_instance(ctx, instance)
             finally:
+                # Always revoke every bearer we published, then restore the
+                # ContextVar — order matters only for tidiness; a revoke that
+                # finds nothing is a no-op.
+                for srv_cfg in published_configs:
+                    revoke_bearer(srv_cfg)
                 request_bearer_token.reset(saved_token)
         if self._instance_scope == "request":
             # With request-scoped instances there is no persistent server-side
             # history to expose — the caller owns it. We still register the