In FastMCP's on_call_tool hook the middleware context is already MiddlewareContext[CallToolRequestParams] (per fastmcp's own middleware.py:158), so the tool name lives at context.message.name, not at context.message.params.name. The latter always returned None, which silently broke the PUBLIC_TOOLS bypass for get_health and made the per-tool ACL check short-circuit. Also wrap call_next in a traced helper that logs any exception with a full traceback and logs the result type on the success path. During the Pallas↔Mnemosyne shakedown, tool results were coming back to fast-agent as the literal string "object NoneType can't be used in 'await' expression" with no trace in either process — that is Python's TypeError for `await X` where X is None. If that TypeError is raised inside FastMCP dispatch, we want the frame in Mnemosyne's own log rather than letting Pallas's aggregator turn it into a terse CallToolResult(isError=True) with no stack.
429 lines
16 KiB
Python
429 lines
16 KiB
Python
"""MCP token resolution and FastMCP middleware for bearer-token auth.
|
|
|
|
Two token shapes are supported:
|
|
|
|
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
|
|
used by the dashboard / Claude Desktop / admin tooling. Plaintext
|
|
hashes to a row in `mcp_token`.
|
|
* **Signed JWT** — per-turn token minted by Daedalus. Carries
|
|
`{ws, libs}` claims. Validated entirely off the signature + claims;
|
|
no database lookup of the token itself, only of the signing key
|
|
(`MCPSigningKey`) referenced by the JWT header's `kid`.
|
|
|
|
Detection: a bearer with three base64url segments separated by dots and
|
|
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
|
|
else falls through to the opaque path.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import OrderedDict
|
|
|
|
import jwt as pyjwt
|
|
from asgiref.sync import sync_to_async
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
from fastmcp.server.dependencies import get_http_request
|
|
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
|
|
from .metrics import mcp_auth_failures_total
|
|
from .models import MCPSigningKey, MCPToken, hash_token
|
|
|
|
logger = logging.getLogger(__name__)

# Keys under which MCPAuthMiddleware stores the resolved identity on the
# FastMCP context state, for downstream tools to read back.
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"

# Allowed clock skew (seconds) for JWT time claims; handed to PyJWT as its
# symmetric ``leeway`` and reused by the jti replay check.
_JWT_LEEWAY_SECONDS = 30

# Required ``iss`` claim: Daedalus is the only issuer whose JWTs we accept.
_JWT_ISS = "daedalus"

# Best-effort, per-process replay guard: a bounded LRU mapping
# ``jti -> exp``. The real protection is the HMAC signature plus a short
# ``exp``; this cache is belt-and-braces on top of that.
#
# A single Daedalus chat turn mints one JWT and then makes several tool
# calls with it (list_libraries -> search -> get_document ...), so seeing a
# ``jti`` twice is normal and must NOT be treated as replay. Only a ``jti``
# that reappears after its own ``exp`` (+ leeway) is flagged -- a case
# PyJWT's ``exp`` check should already have rejected, kept here to cover
# clock skew or a captured token resurfacing just past its lifetime.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
|
|
|
|
|
class MCPAuthError(Exception):
    """A bearer token could not be resolved to a valid, active identity.

    Raised by the opaque-token and JWT resolvers in this module. The
    message is human-readable; when authentication is required the
    middleware re-raises it to the caller as the text of a
    ``PermissionError``.
    """
|
|
|
|
|
|
# --- Opaque token path -----------------------------------------------------
|
|
|
|
|
|
def resolve_mcp_user(token_string: str):
    """Look up an opaque bearer token and return ``(user, MCPToken)``.

    The bearer is hashed with ``hash_token`` and matched against
    ``MCPToken.token_hash``; the plaintext is never stored or compared.
    A token with no ``expires_at`` never expires. On success the token's
    ``record_usage()`` is called before returning.

    Raises:
        MCPAuthError: if no token matches the hash, the token is
            deactivated, its ``expires_at`` is in the past, or the owning
            user account is inactive.
    """
    digest = hash_token(token_string)
    try:
        token = MCPToken.objects.select_related("user").get(token_hash=digest)
    except MCPToken.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    if not token.is_active:
        raise MCPAuthError("Token has been deactivated.")

    expires_at = token.expires_at
    if expires_at and expires_at < timezone.now():
        raise MCPAuthError("Token has expired.")

    if not token.user.is_active:
        raise MCPAuthError("User account is disabled.")

    token.record_usage()
    return token.user, token
|
|
|
|
|
|
# --- JWT path --------------------------------------------------------------
|
|
|
|
|
|
def looks_like_jwt(token_string: str) -> bool:
    """Cheaply decide whether a bearer string is shaped like a JWT.

    Returns True iff the string has exactly three dot-separated segments
    and the first one base64url-decodes to a JSON object carrying a string
    ``alg`` field. This is only a routing hint: ``alg``/``kid`` and the
    signature are validated later, so any malformed bearer must simply
    yield ``False`` here rather than raise.
    """
    parts = token_string.split(".")
    if len(parts) != 3:
        return False
    try:
        header = json.loads(_b64url_decode(parts[0]))
    except (ValueError, RecursionError):
        # ValueError covers bad base64 (binascii.Error), non-ASCII input,
        # undecodable bytes (UnicodeDecodeError) and invalid JSON
        # (JSONDecodeError). RecursionError is what ``json`` raises for a
        # pathologically nested header such as "[[[[..." -- the header is
        # request-controlled, so that must not escape as a server error.
        return False
    return isinstance(header, dict) and isinstance(header.get("alg"), str)


def _b64url_decode(segment: str) -> bytes:
    """Decode an unpadded base64url segment (JWT encoding) to raw bytes."""
    # JWT segments drop the trailing "=" padding; restore it to a multiple
    # of 4 characters before handing off to the stdlib decoder.
    pad = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + pad)
|
|
|
|
|
|
def _remember_jti(jti: str, exp: float) -> bool:
    """Track ``jti`` in the per-process cache and report whether it is a replay.

    Returns ``True`` only when ``jti`` is already cached *and* the current
    time is past its recorded ``exp`` plus ``_JWT_LEEWAY_SECONDS`` -- a
    token that should no longer be presentable, which PyJWT's own ``exp``
    check would normally have rejected already. Presenting the same ``jti``
    again while it is still valid is the expected case (Daedalus mints one
    token per chat turn and a turn fires several tool calls with it), so
    that returns ``False``.

    Side effects on ``_JTI_CACHE``:
      * entries at the LRU head whose ``exp`` is more than an hour in the
        past are evicted first, bounding memory under a flood of new jtis;
      * a cache hit is moved to the MRU end; its stored ``exp`` is kept;
      * a miss is inserted as ``float(exp)``, first evicting the LRU entry
        if the cache already holds ``_JTI_CACHE_MAX`` items.
    """
    cache = _JTI_CACHE
    now = time.time()

    # Only the head is inspected, since entries are (re)appended as they
    # are seen. An hour of slack is generous assuming short-lived tokens
    # (the design note says exp <= 600s -- not enforced here).
    horizon = now - 3600
    while cache and next(iter(cache.values())) < horizon:
        cache.popitem(last=False)

    seen_exp = cache.get(jti)
    if seen_exp is None:
        # First sighting: make room if necessary, then record it.
        if len(cache) >= _JTI_CACHE_MAX:
            cache.popitem(last=False)
        cache[jti] = float(exp)
        return False

    cache.move_to_end(jti)
    # Same symmetric leeway PyJWT applied during signature verification.
    return now > seen_exp + _JWT_LEEWAY_SECONDS
|
|
|
|
|
|
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a Daedalus-signed HS256 JWT and return its normalized claims.

    Steps:
      1. Read the unverified header; require ``alg == "HS256"`` and a ``kid``.
      2. Look up the ``MCPSigningKey`` for that ``kid``; it must exist and
         be active, and its ``secret_hex`` must decode as hex.
      3. Verify the signature and time claims via PyJWT (with
         ``_JWT_LEEWAY_SECONDS`` leeway) and ``iss == _JWT_ISS``; the
         ``exp``, ``iat``, ``iss``, ``sub`` and ``jti`` claims are required.
      4. Reject a ``jti`` that ``_remember_jti`` reports as a replay.
      5. Normalize ``ws`` (falsy -> ``None``) and ``libs`` (falsy -> ``[]``,
         otherwise must be a list of strings), and add ``kid``.

    Does not touch ``MCPToken`` -- JWTs are stateless; only their signing
    key (``MCPSigningKey``) lives in the database.

    Raises:
        MCPAuthError: on any validation failure. When the failure came from
            PyJWT or hex decoding, that original exception is chained as
            ``__cause__`` so tracebacks keep the underlying detail.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}") from exc

    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    # Pin the algorithm before any database work, so tokens claiming some
    # other alg are rejected without a key lookup.
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")

    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError as exc:
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") from exc

    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": ["exp", "iat", "iss", "sub", "jti"]},
            issuer=_JWT_ISS,
        )
    except pyjwt.ExpiredSignatureError as exc:
        raise MCPAuthError("Token has expired.") from exc
    except pyjwt.InvalidIssuerError as exc:
        raise MCPAuthError("Invalid token issuer.") from exc
    except pyjwt.InvalidSignatureError as exc:
        raise MCPAuthError("Invalid MCP token.") from exc
    except pyjwt.MissingRequiredClaimError as exc:
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}") from exc
    except pyjwt.PyJWTError as exc:
        # Catch-all for the remaining PyJWT failures; must stay last since
        # the specific errors above subclass PyJWTError.
        raise MCPAuthError(f"Invalid JWT: {exc}") from exc

    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")
    exp = claims.get("exp")
    if not isinstance(exp, (int, float)):
        # PyJWT has already required ``exp``; this additionally insists on a
        # real JSON number (defence in depth should PyJWT coerce values or
        # change its own check).
        raise MCPAuthError("JWT exp must be numeric.")
    if _remember_jti(jti, float(exp)):
        raise MCPAuthError("Token replay detected.")

    # Normalize claim shapes: ws may be null/absent, libs default to [].
    claims["ws"] = claims.get("ws") or None
    libs = claims.get("libs") or []
    if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
        raise MCPAuthError("JWT libs must be a list of strings.")
    claims["libs"] = libs
    claims["kid"] = kid
    return claims
|
|
|
|
|
|
# --- Middleware ------------------------------------------------------------
|
|
|
|
|
|
class MCPAuthMiddleware(Middleware):
    """FastMCP middleware that authenticates tool calls via Bearer tokens.

    Only ``on_call_tool`` is intercepted, so listing tools/resources is
    permitted unauthenticated and clients can discover the surface. Calling
    a tool requires a valid token unless ``MCP_REQUIRE_AUTH=False``. While
    auth is required, tools in ``_PUBLIC_TOOLS`` are dispatched without
    looking at the token at all.

    On success the resolved identity is stored on the FastMCP context state
    (``STATE_KEY_USER``, plus ``STATE_KEY_TOKEN`` for opaque tokens or
    ``STATE_KEY_CLAIMS`` for JWTs) for downstream tools to read.
    """

    # Tools that don't touch user data and must be callable without a token
    # (e.g. Pallas health pollers, agent startup probes).
    _PUBLIC_TOOLS = {"get_health"}

    # MCPAuthError messages that contain no request-controlled text and are
    # therefore safe to use verbatim as a Prometheus label value. Messages
    # embedding e.g. the JWT ``kid`` or a PyJWT error string are bucketed
    # as "other"; otherwise any client could mint unbounded label values
    # (one new time series per distinct ``kid`` it sends).
    _STATIC_FAILURE_REASONS = frozenset(
        {
            "Invalid MCP token.",
            "Token has been deactivated.",
            "Token has expired.",
            "User account is disabled.",
            "JWT header missing kid.",
            "Unsupported JWT alg.",
            "Invalid token issuer.",
            "JWT jti must be a non-empty string.",
            "JWT exp must be numeric.",
            "Token replay detected.",
            "JWT libs must be a list of strings.",
        }
    )

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Authenticate the bearer on a tool call, then dispatch it.

        Raises:
            PermissionError: when auth is required and the bearer is missing
                or fails resolution, or when an opaque token's ACL does not
                allow the requested tool.
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)

        tool_name = self._extract_tool_name(context)
        logger.info(
            "mcp_auth.on_call_tool tool=%s require_auth=%s",
            tool_name,
            require_auth,
        )

        if require_auth and tool_name in self._PUBLIC_TOOLS:
            return await self._call_next_with_trace(tool_name, call_next, context)

        token_string = self._extract_token()

        user = None
        token = None
        claims: dict | None = None

        if token_string:
            try:
                if looks_like_jwt(token_string):
                    claims = await sync_to_async(
                        resolve_mcp_jwt, thread_sensitive=True
                    )(token_string)
                    user = await sync_to_async(
                        _resolve_jwt_actor, thread_sensitive=True
                    )(claims)
                else:
                    user, token = await sync_to_async(
                        resolve_mcp_user, thread_sensitive=True
                    )(token_string)
            except MCPAuthError as exc:
                # Full (possibly request-controlled) reason goes to the log,
                # repr'd so it cannot inject newlines; the metric gets a
                # bounded label.
                logger.warning(
                    "mcp_auth.rejected tool=%s require_auth=%s reason=%r",
                    tool_name,
                    require_auth,
                    str(exc),
                )
                mcp_auth_failures_total.labels(
                    reason=self._failure_reason(exc)
                ).inc()
                if require_auth:
                    raise PermissionError(str(exc)) from exc
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        # Per-tool ACL applies to opaque tokens only; JWT scope lives in claims.
        if token and tool_name and not token.can_use_tool(tool_name):
            mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
            raise PermissionError(
                f"Token does not have permission to call '{tool_name}'."
            )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx is not None:
            if user is not None:
                await self._set_state(fastmcp_ctx, STATE_KEY_USER, user)
            if token is not None:
                await self._set_state(fastmcp_ctx, STATE_KEY_TOKEN, token)
            if claims is not None:
                await self._set_state(fastmcp_ctx, STATE_KEY_CLAIMS, claims)

        return await self._call_next_with_trace(tool_name, call_next, context)

    @classmethod
    def _failure_reason(cls, exc: Exception) -> str:
        """Map an auth failure to a bounded Prometheus ``reason`` label."""
        message = str(exc)
        if message in cls._STATIC_FAILURE_REASONS:
            return message
        return "other"

    @staticmethod
    async def _set_state(fastmcp_ctx, key: str, value) -> None:
        """Store ``value`` under ``key`` on the FastMCP context state.

        Depending on the installed FastMCP version, ``Context.set_state`` is
        either a plain method returning ``None`` or a coroutine function.
        Unconditionally ``await``-ing the plain form raises ``TypeError:
        object NoneType can't be used in 'await' expression``, so the result
        is only awaited when it actually is awaitable.
        NOTE(review): this matches the shakedown symptom described for
        ``_call_next_with_trace`` -- confirm against the pinned fastmcp
        version.

        Any failure is logged with a full traceback before propagating, so
        the frame lands in this process's log.
        """
        import inspect

        try:
            result = fastmcp_ctx.set_state(key, value)
            if inspect.isawaitable(result):
                await result
        except Exception:
            logger.exception("mcp_auth.set_state_failed key=%s", key)
            raise

    @staticmethod
    async def _call_next_with_trace(tool_name, call_next, context):
        """Run ``call_next`` and log any exception with a full traceback.

        During the Pallas<->Mnemosyne shakedown, tool results came back to
        fast-agent as the string ``"object NoneType can't be used in 'await'
        expression"`` with no trace in either process. That string is
        Python's ``TypeError`` for ``await X`` where ``X`` is ``None``. If it
        is raised inside FastMCP dispatch we want the full traceback in
        Mnemosyne's own log rather than only the ``CallToolResult(isError=
        True)`` Pallas's aggregator turns it into. The exception is re-raised
        unchanged.

        The *type* of a successful return is logged too, to confirm FastMCP
        returns ``ToolResult`` / ``CallToolResult`` as expected. Keep at INFO
        until green, then demote to DEBUG.
        """
        try:
            result = await call_next(context)
        except Exception:
            logger.exception("mcp_auth.call_next_failed tool=%s", tool_name)
            raise
        logger.info(
            "mcp_auth.call_next_ok tool=%s result_type=%s",
            tool_name,
            type(result).__name__,
        )
        return result

    @staticmethod
    def _extract_token() -> str | None:
        """Pull the Bearer credential off the current HTTP request, if any.

        The MCP SDK stores the Starlette ``Request`` on ``request_ctx`` for
        every tool-call dispatch (including follow-up calls on a stateful
        session), so ``get_http_request()`` should succeed here. If it
        doesn't -- e.g. in a background task or a pre-session initialize
        hook -- ``None`` is returned and the caller decides.

        The auth scheme is matched case-insensitively (RFC 7235), so
        ``Bearer``, ``bearer`` and ``BEARER`` are all accepted. Only the
        scheme -- never any part of the credential -- is logged.

        INFO-level logging is intentional until bearer-forwarding from
        Daedalus/Pallas is fully shaken out; demote to DEBUG once green.
        """
        try:
            request = get_http_request()
        except RuntimeError as exc:
            logger.warning(
                "mcp_auth.extract outcome=no_http_request reason=%r",
                str(exc),
            )
            return None

        # Starlette ``Headers`` lookups are case-insensitive, so this single
        # lookup also matches a lower-cased ``authorization`` header.
        auth_header = request.headers.get("Authorization", "")
        scheme, sep, credentials = auth_header.partition(" ")

        logger.info(
            "mcp_auth.extract outcome=%s len=%d scheme=%r path=%s header_names=%s",
            "present" if auth_header else "missing",
            len(auth_header),
            # Without a separator the whole header could be a bare secret.
            scheme if sep else "",
            str(getattr(request, "url", "")),
            sorted(request.headers.keys()),
        )

        if not sep or scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Pull the tool name off a FastMCP ``on_call_tool`` context.

        Inside ``on_call_tool`` the context is a
        ``MiddlewareContext[mt.CallToolRequestParams]``, so ``context.message``
        already *is* the params object and the name sits on
        ``context.message.name``. The older ``message.params.name`` access
        always returned ``None``, which silently skipped the public-tools
        bypass and the per-tool ACL; it is kept only as a fallback in case
        the shape ever diverges. Returns ``None`` if neither is present.
        """
        msg = getattr(context, "message", None)
        if msg is None:
            return None
        name = getattr(msg, "name", None)
        if name:
            return name
        params = getattr(msg, "params", None)
        return getattr(params, "name", None) if params is not None else None
|
|
|
|
|
|
# --- Helpers ---------------------------------------------------------------
|
|
|
|
|
|
def _resolve_jwt_actor(claims: dict):
    """Return the service account that acts for a JWT-authenticated turn.

    JWTs are not bound to individual user accounts -- their claims carry
    the authorization scope -- so every JWT call runs as one shared user,
    named by ``settings.MCP_JWT_SERVICE_USERNAME`` (default
    ``"daedalus-service"``). ``claims`` is currently not read.

    Raises:
        MCPAuthError: if that user does not exist or is inactive.
    """
    from django.contrib.auth import get_user_model

    user_model = get_user_model()
    service_username = getattr(
        settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service"
    )

    try:
        actor = user_model.objects.get(username=service_username)
    except user_model.DoesNotExist:
        raise MCPAuthError(
            f"JWT service user {service_username!r} does not exist; "
            "provision via management command."
        )

    if not actor.is_active:
        raise MCPAuthError(f"JWT service user {service_username!r} is disabled.")
    return actor
|