"""MCP token resolution and FastMCP middleware for bearer-token auth. Three credential types are accepted — see ``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2 for the full model: 1. **Opaque ``MCPToken``** (long-lived, hashed at rest). Authorization scope is its ``allowed_libraries`` JSON list. 2. **Per-turn signed JWT** (``iss=daedalus``, ≤10 min, legacy — retires in Phase 4 when Daedalus chat itself becomes a Pallas Team). Scope is the ``libs`` claim. 3. **Team JWT** (``iss=mnemosyne``, ``typ=team``, 10-year lifetime). Scope is resolved live by joining ``TeamWorkspaceAssignment`` rows to Neo4j ``Library.workspace_id``. Every branch populates a single :data:`STATE_KEY_RESOLVED_LIBRARIES` value on the FastMCP context — a ``list[str]`` of Library UIDs the downstream tools are permitted to read. Tools never consult claim shapes; they read this list via ``mcp_server.context.get_mcp_resolved_libraries``. Detection: a bearer with three base64url segments separated by dots and a parseable ``{"alg":"HS256","kid":...}`` header is treated as JWT; anything else falls through to the opaque path. """ from __future__ import annotations import base64 import json import logging import time import uuid from collections import OrderedDict import jwt as pyjwt from asgiref.sync import sync_to_async from django.conf import settings from django.utils import timezone from fastmcp.server.dependencies import get_http_request from fastmcp.server.middleware import Middleware, MiddlewareContext from .metrics import mcp_auth_failures_total from .models import MCPSigningKey, MCPToken, Team, hash_token logger = logging.getLogger(__name__) STATE_KEY_USER = "mcp_user" STATE_KEY_TOKEN = "mcp_token" STATE_KEY_CLAIMS = "mcp_claims" STATE_KEY_RESOLVED_LIBRARIES = "mcp_resolved_libraries" # Permitted clock skew when validating JWT exp/iat. PyJWT applies this # symmetrically as ``leeway``. _JWT_LEEWAY_SECONDS = 30 # Accepted JWT issuers. # # ``daedalus`` — per-turn tokens minted by Daedalus chat (legacy path, # retires with Phase 4). # ``mnemosyne`` — team tokens minted by this service. ``typ=team`` # distinguishes them from any future self-issued credential. _JWT_ISS_VALUES = {"daedalus", "mnemosyne"} # Bounded LRU of recently-seen jti values to discourage replay within # a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this # is best-effort and per-process. # # Daedalus mints one JWT per *chat turn*, and a single turn routinely # drives several tool calls (list_libraries → search → get_document …) # that all re-use that bearer. Treating any repeat ``jti`` as replay # breaks the legitimate use case, so we store ``jti -> exp`` and only # flag a request as replay if the ``jti`` shows up *after* its own # ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check # would have already rejected, this is belt-and-braces for clock skew # or a resurrected captured token. # # Team tokens (``typ=team``) bypass this cache entirely — they are # reused on every request by design. Revocation for those tokens runs # against the live ``Team`` row (``active`` + ``active_jti``). _JTI_CACHE_MAX = 4096 _JTI_CACHE: "OrderedDict[str, float]" = OrderedDict() class MCPAuthError(Exception): """Raised when a bearer token cannot be resolved to a valid user.""" # --- Opaque token path ----------------------------------------------------- def resolve_mcp_user(token_string: str): """Resolve an opaque bearer token to (user, MCPToken). Hashes the incoming bearer and looks up by the hash — plaintext is never stored or compared directly. """ try: token = ( MCPToken.objects .select_related("user") .get(token_hash=hash_token(token_string)) ) except MCPToken.DoesNotExist: raise MCPAuthError("Invalid MCP token.") if not token.is_active: raise MCPAuthError("Token has been deactivated.") if token.expires_at and token.expires_at < timezone.now(): raise MCPAuthError("Token has expired.") if not token.user.is_active: raise MCPAuthError("User account is disabled.") token.record_usage() return token.user, token # --- JWT path -------------------------------------------------------------- def looks_like_jwt(token_string: str) -> bool: """Cheap pre-check: 3 dot-separated segments and a valid base64url header that decodes to JSON with an alg field.""" parts = token_string.split(".") if len(parts) != 3: return False try: header_bytes = _b64url_decode(parts[0]) header = json.loads(header_bytes) except (ValueError, json.JSONDecodeError): return False return isinstance(header, dict) and isinstance(header.get("alg"), str) def _b64url_decode(segment: str) -> bytes: pad = "=" * (-len(segment) % 4) return base64.urlsafe_b64decode(segment + pad) def _remember_jti(jti: str, exp: float) -> bool: """Record or re-check a jti; return True iff this is a genuine replay. A "genuine replay" is a ``jti`` we've already recorded whose ``exp`` has passed (+ leeway) — that token *should not* be presentable anymore and PyJWT's ``exp`` check would normally have rejected it; this is the belt-and-braces path for clock skew or a captured token being resurrected slightly after its natural lifetime. Re-use of the same ``jti`` *within* its validity window is the intended case: Daedalus mints one token per chat turn and a turn commonly fires several Mnemosyne tool calls against that same bearer. That is **not** replay and we return False. """ now = time.time() # GC anything whose exp passed more than an hour ago — generous # given exp ≤ 600s, bounds the cache even for pathological jti # floods. gc_cutoff = now - 3600 while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < gc_cutoff: _JTI_CACHE.popitem(last=False) cached_exp = _JTI_CACHE.get(jti) if cached_exp is not None: # Refresh LRU position. _JTI_CACHE.move_to_end(jti) # Replay iff the token is already past its own expiry (with # the same symmetric leeway PyJWT applied on signature check). if now > cached_exp + _JWT_LEEWAY_SECONDS: return True return False if len(_JTI_CACHE) >= _JTI_CACHE_MAX: _JTI_CACHE.popitem(last=False) _JTI_CACHE[jti] = float(exp) return False def resolve_mcp_jwt(token_string: str) -> dict: """Validate a signed JWT and return its claims dict. Accepts both the legacy per-turn issuer (``iss=daedalus``) and the new team issuer (``iss=mnemosyne``, ``typ=team``). The returned claims dict is normalized so the middleware doesn't have to guess: * ``claims["iss"]`` — as presented (``daedalus`` or ``mnemosyne``). * ``claims["typ"]`` — ``"team"`` for team tokens, otherwise absent. * ``claims["libs"]`` — per-turn only; normalized to ``list[str]``. * ``claims["ws"]`` — per-turn only; may be ``None``. Not consulted for authorization (kept for diagnostics). * ``claims["team_id"]`` — team only; ``UUID`` parsed from ``sub == "team:"``. * ``claims["kid"]`` — copy of the JWT header's ``kid``. Raises :class:`MCPAuthError` on any failure. The per-turn path runs the ``_remember_jti`` replay check; the team path skips it (team JWTs are intentionally reused across the token's lifetime). """ try: unverified_header = pyjwt.get_unverified_header(token_string) except pyjwt.PyJWTError as exc: raise MCPAuthError(f"Malformed JWT header: {exc}") kid = unverified_header.get("kid") if not kid: raise MCPAuthError("JWT header missing kid.") if unverified_header.get("alg") != "HS256": raise MCPAuthError("Unsupported JWT alg.") key = MCPSigningKey.objects.by_kid(kid) if key is None: raise MCPAuthError(f"Unknown signing key kid={kid!r}.") if not key.is_active: raise MCPAuthError(f"Signing key {kid!r} has been retired.") try: secret = bytes.fromhex(key.secret_hex) except ValueError: raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") try: claims = pyjwt.decode( token_string, secret, algorithms=["HS256"], leeway=_JWT_LEEWAY_SECONDS, options={ "require": ["exp", "iat", "iss", "sub", "jti"], # Team JWTs carry ``aud=mnemosyne`` for informational # purposes; per-turn JWTs omit ``aud`` entirely. We # don't enforce either shape because ``iss`` + ``typ`` # already partition the two token populations. "verify_aud": False, }, # ``issuer=`` accepts ``str | Iterable[str]`` and raises # ``InvalidIssuerError`` if the claim is outside the set. issuer=list(_JWT_ISS_VALUES), ) except pyjwt.ExpiredSignatureError: raise MCPAuthError("Token has expired.") except pyjwt.InvalidIssuerError: raise MCPAuthError("Invalid token issuer.") except pyjwt.InvalidSignatureError: raise MCPAuthError("Invalid MCP token.") except pyjwt.MissingRequiredClaimError as exc: raise MCPAuthError(f"JWT missing required claim: {exc.claim}") except pyjwt.PyJWTError as exc: raise MCPAuthError(f"Invalid JWT: {exc}") jti = claims.get("jti") if not isinstance(jti, str) or not jti: raise MCPAuthError("JWT jti must be a non-empty string.") exp = claims.get("exp") if not isinstance(exp, (int, float)): # ``require=["exp", ...]`` above guarantees presence + numeric; this # is defence in depth against future PyJWT changes. raise MCPAuthError("JWT exp must be numeric.") typ = claims.get("typ") if typ == "team": # Team tokens: no replay cache, no ``libs`` or ``ws`` claims. # Verify the ``sub`` shape and parse the embedded team UUID so # the middleware doesn't have to re-parse it later. sub = claims.get("sub") if not isinstance(sub, str) or not sub.startswith("team:"): raise MCPAuthError("Invalid MCP token.") try: claims["team_id"] = uuid.UUID(sub[len("team:"):]) except ValueError: raise MCPAuthError("Invalid MCP token.") else: # Per-turn (legacy) path: replay-cache gate + normalize claims. if _remember_jti(jti, float(exp)): raise MCPAuthError("Token replay detected.") claims["ws"] = claims.get("ws") or None libs = claims.get("libs") or [] if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs): raise MCPAuthError("JWT libs must be a list of strings.") claims["libs"] = libs claims["kid"] = kid return claims # --- Team-JWT library resolution ------------------------------------------ def _libraries_for_team(team_id: uuid.UUID, jti: str) -> list[str]: """Resolve a team token to the Library UIDs it may read. Runs two cheap queries in sequence: 1. Fetch the ``Team`` row by UUID. Reject if it doesn't exist, is inactive, or its ``active_jti`` doesn't match the incoming ``jti`` — this is how rotation / soft-delete revocation becomes effective on the *next* request. 2. If active: read every ``TeamWorkspaceAssignment.workspace_id`` for the team and translate them into Library UIDs via a single Cypher query against Neo4j. Returns an empty list when the team has no workspace assignments (fail-closed — a team pointing at no workspaces sees no libraries). """ try: team = Team.objects.get(pk=team_id) except Team.DoesNotExist: raise MCPAuthError("Invalid MCP token.") if not team.active: raise MCPAuthError("Token has been deactivated.") if team.active_jti is None or str(team.active_jti) != jti: raise MCPAuthError("Invalid MCP token.") workspace_ids = list( team.workspace_assignments.values_list("workspace_id", flat=True) ) if not workspace_ids: return [] from neomodel import db rows, _ = db.cypher_query( "MATCH (l:Library) WHERE l.workspace_id IN $ws RETURN l.uid", {"ws": workspace_ids}, ) return [row[0] for row in rows if row and row[0]] # --- Middleware ------------------------------------------------------------ class MCPAuthMiddleware(Middleware): """ FastMCP middleware that authenticates tool calls via Bearer tokens. Listing tools/resources is permitted unauthenticated so clients can discover the surface; calling a tool requires a valid token unless ``MCP_REQUIRE_AUTH=False``. On every authenticated call the middleware attaches four values to the FastMCP ``Context`` state for downstream tools to consume via :mod:`mcp_server.context`: * ``STATE_KEY_USER`` — Django user. * ``STATE_KEY_TOKEN`` — MCPToken row (opaque callers only). * ``STATE_KEY_CLAIMS`` — JWT claims dict (JWT callers only). * ``STATE_KEY_RESOLVED_LIBRARIES`` — authorization-resolved Library UID list. Tools read this; they never read ``STATE_KEY_CLAIMS`` for authorization. """ # Tools that don't touch user data and must be callable without a token # (e.g. Pallas health pollers, agent startup probes). _PUBLIC_TOOLS = {"get_health"} async def on_call_tool(self, context: MiddlewareContext, call_next): require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True) tool_name = self._extract_tool_name(context) logger.info( "mcp_auth.on_call_tool tool=%s require_auth=%s", tool_name, require_auth, ) if require_auth and tool_name in self._PUBLIC_TOOLS: return await self._call_next_with_trace( tool_name, call_next, context ) token_string = self._extract_token() user = None token = None claims: dict | None = None resolved_libraries: list[str] | None = None if token_string: try: if looks_like_jwt(token_string): claims = await sync_to_async( resolve_mcp_jwt, thread_sensitive=True )(token_string) user = await sync_to_async( _resolve_jwt_actor, thread_sensitive=True )(claims) resolved_libraries = await sync_to_async( _resolved_libraries_for_jwt, thread_sensitive=True )(claims) else: user, token = await sync_to_async( resolve_mcp_user, thread_sensitive=True )(token_string) # Opaque tokens store the Library UID list directly. # Empty list = fail-closed; not "everything". resolved_libraries = list(token.allowed_libraries or []) except MCPAuthError as exc: mcp_auth_failures_total.labels(reason=str(exc)).inc() if require_auth: raise PermissionError(str(exc)) elif require_auth: mcp_auth_failures_total.labels(reason="missing_token").inc() raise PermissionError("Authentication required. Provide a Bearer token.") if token and tool_name and not token.can_use_tool(tool_name): mcp_auth_failures_total.labels(reason="tool_not_allowed").inc() raise PermissionError( f"Token does not have permission to call '{tool_name}'." ) fastmcp_ctx = getattr(context, "fastmcp_context", None) if fastmcp_ctx is not None: if user is not None: await fastmcp_ctx.set_state(STATE_KEY_USER, user) if token is not None: await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token) if claims is not None: await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims) # Always publish resolved_libraries — None means "no auth # information" and the tools treat that as fail-closed. await fastmcp_ctx.set_state( STATE_KEY_RESOLVED_LIBRARIES, resolved_libraries ) logger.info( "mcp_auth.resolved tool=%s principal=%s lib_count=%s", tool_name, self._describe_principal(user, token, claims), "none" if resolved_libraries is None else len(resolved_libraries), ) return await self._call_next_with_trace(tool_name, call_next, context) @staticmethod async def _call_next_with_trace(tool_name, call_next, context): """Run ``call_next`` and log any exception with a full traceback. During the Pallas↔Mnemosyne shakedown we were seeing tool results come back to fast-agent as the string ``"object NoneType can't be used in 'await' expression"`` with no trace anywhere in either process. That string is Python's ``TypeError`` for ``await X`` where ``X`` is ``None`` (i.e. someone awaited a non-coroutine). If that TypeError is raised inside the FastMCP dispatch we want the full traceback in Mnemosyne's own log rather than silently letting it propagate back to Pallas where the aggregator turns it into a ``CallToolResult(isError=True)`` and we lose the frame. We also log the *type* of the successful return so we can verify later that FastMCP is returning ``ToolResult`` / ``CallToolResult`` the way we expect. Keep at INFO until green, demote to DEBUG afterwards. """ try: result = await call_next(context) except Exception: logger.exception( "mcp_auth.call_next_failed tool=%s", tool_name ) raise logger.info( "mcp_auth.call_next_ok tool=%s result_type=%s", tool_name, type(result).__name__, ) return result @staticmethod def _describe_principal(user, token, claims) -> str: """Compact, log-friendly principal summary. No PII beyond usernames.""" if claims is not None: typ = claims.get("typ") if typ == "team": return f"team:{claims.get('team_id')}" return f"jwt:{claims.get('sub')}" if token is not None: return f"mcptoken:{token.get_masked_token()}" if user is not None: return f"user:{user.username}" return "anonymous" @staticmethod def _extract_token() -> str | None: """Pull the Bearer token off the current HTTP request, if any. The MCP SDK stores the Starlette ``Request`` on ``request_ctx`` for every tool-call dispatch (including follow-up calls on a stateful session), so ``get_http_request()`` should succeed here. If it *doesn't* — e.g. because we're in a background task or pre-session initialize hook — we return ``None`` and let the caller decide. INFO-level logging is intentional until bearer-forwarding from Daedalus/Pallas is fully shaken out; demote to DEBUG once green. """ try: request = get_http_request() except RuntimeError as exc: logger.warning( "mcp_auth.extract outcome=no_http_request reason=%r", str(exc), ) return None # Header lookup is case-insensitive on Starlette ``Headers`` but # different proxies normalize case differently — belt and braces. auth_header = request.headers.get("Authorization", "") if not auth_header: auth_header = request.headers.get("authorization", "") header_names = sorted(request.headers.keys()) logger.info( "mcp_auth.extract outcome=%s len=%d prefix=%r path=%s header_names=%s", "present" if auth_header else "missing", len(auth_header), auth_header[:16] if auth_header else "", str(getattr(request, "url", "")), header_names, ) if auth_header.startswith("Bearer "): return auth_header[7:].strip() or None if auth_header.startswith("bearer "): # some clients lowercase return auth_header[7:].strip() or None return None @staticmethod def _extract_tool_name(context: MiddlewareContext) -> str | None: """Pull the tool name off a FastMCP ``on_call_tool`` context. In FastMCP middleware, ``context.message`` inside ``on_call_tool`` is already a ``CallToolRequestParams`` (see ``fastmcp.server.middleware.middleware.Middleware.on_call_tool`` signature: ``MiddlewareContext[mt.CallToolRequestParams]``), so ``name`` lives directly on ``context.message`` — there is no nested ``.params``. The older ``message.params.name`` access we had here always returned ``None``, which caused the public-tools bypass to silently miss ``get_health`` and made the per-tool ACL short-circuit. Fall back to ``.params.name`` only as a legacy safety net in case the shape ever diverges. """ msg = getattr(context, "message", None) if msg is None: return None name = getattr(msg, "name", None) if name: return name params = getattr(msg, "params", None) return getattr(params, "name", None) if params is not None else None # --- Helpers --------------------------------------------------------------- def _resolve_jwt_actor(claims: dict): """Resolve the synthetic actor for a JWT-authenticated turn. Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default ``daedalus-service``). The user must exist and be active. JWT tokens are not tied to per-user accounts — claims encode all authorization. Used for both per-turn and team JWTs. The service user is a hook for usage accounting (LLMUsage / search metrics) and for the audit trail; authorization does not depend on it. """ from django.contrib.auth import get_user_model User = get_user_model() username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service") try: user = User.objects.get(username=username) except User.DoesNotExist: raise MCPAuthError( f"JWT service user {username!r} does not exist; provision via management command." ) if not user.is_active: raise MCPAuthError(f"JWT service user {username!r} is disabled.") return user def _resolved_libraries_for_jwt(claims: dict) -> list[str]: """Pick the right resolver branch for a validated JWT claims dict. * ``typ == "team"`` → live lookup via :func:`_libraries_for_team`. * otherwise → legacy ``claims["libs"]`` (per-turn JWT). """ if claims.get("typ") == "team": return _libraries_for_team(claims["team_id"], claims["jti"]) return list(claims.get("libs") or [])