Files
mnemosyne/mnemosyne/mcp_server/auth.py
Robert Helewka e0fa825189
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 50s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m26s
auth: read tool name off context.message directly; trace call_next failures
In FastMCP's on_call_tool hook the middleware context is already
MiddlewareContext[CallToolRequestParams] (per fastmcp's own
middleware.py:158), so tool name lives at context.message.name, not
at context.message.params.name — the latter always returned None,
silently breaking the PUBLIC_TOOLS bypass for get_health and making
the per-tool ACL short-circuit.

Also wrap call_next in a traced helper that logs any exception with
a full traceback and logs the success-path result type.  During the
Pallas↔Mnemosyne shakedown the tool results were coming back to
fast-agent as the literal string "object NoneType can't be used in
'await' expression" with no trace in either process — that's Python's
TypeError for 'await X' where X is None.  If that TypeError is raised
inside FastMCP dispatch we want the frame in Mnemosyne's own log
rather than having Pallas's aggregator turn it into a terse
CallToolResult(isError=True) with no stack.
2026-05-06 19:47:52 -04:00

429 lines
16 KiB
Python

"""MCP token resolution and FastMCP middleware for bearer-token auth.
Two token shapes are supported:
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
used by the dashboard / Claude Desktop / admin tooling. Plaintext
hashes to a row in `mcp_token`.
* **Signed JWT** — per-turn token minted by Daedalus. Carries
`{ws, libs}` claims. Validated entirely off the signature + claims;
no database lookup of the token itself, only of the signing key
(`MCPSigningKey`) referenced by the JWT header's `kid`.
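Detection: a bearer with three dot-separated segments whose first segment
base64url-decodes to a JSON object with a string `alg` field is treated as
a JWT; anything else falls through to the opaque path. The stricter header
requirements (`alg` must be `HS256`, `kid` must be present) are enforced
afterwards, during JWT validation, and fail the request rather than
falling back to the opaque path.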
"""
from __future__ import annotations
import base64
import hashlib
import json
import logging
import time
from collections import OrderedDict
import jwt as pyjwt
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils import timezone
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from .metrics import mcp_auth_failures_total
from .models import MCPSigningKey, MCPToken, hash_token
# Module-level logger; every auth event in this file is emitted through it.
logger = logging.getLogger(__name__)

# Keys under which ``MCPAuthMiddleware.on_call_tool`` publishes the resolved
# identity on the FastMCP context (via ``set_state``): the Django user, the
# opaque ``MCPToken`` row (opaque path only) and the JWT claims dict (JWT
# path only).
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"

# Permitted clock skew, in seconds, when validating JWT exp/iat. Passed to
# PyJWT as ``leeway`` and reused by ``_remember_jti`` so both checks agree
# on when a token counts as expired.
_JWT_LEEWAY_SECONDS = 30

# The only accepted ``iss`` value; enforced through ``issuer=`` in
# ``resolve_mcp_jwt``.
# NOTE(review): Mnemosyne is described as the audience, but no ``aud``
# claim is validated anywhere in this module — confirm that is intended.
_JWT_ISS = "daedalus"

# Bounded LRU of recently-seen jti values to discourage replay within
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
# is best-effort and per-process.
#
# Daedalus mints one JWT per *chat turn*, and a single turn routinely
# drives several tool calls (list_libraries → search → get_document …)
# that all re-use that bearer. Treating any repeat ``jti`` as replay
# would break that legitimate use case, so the cache maps ``jti -> exp``
# and a request is only flagged as replay if the ``jti`` shows up *after*
# its own recorded ``exp`` has passed — a scenario PyJWT's ``exp`` check
# should already have rejected; this is belt-and-braces for clock skew
# or a resurrected captured token.
#
# Maximum number of jti entries kept; the least-recently-used entry is
# evicted once the cache is full.
_JTI_CACHE_MAX = 4096
# jti -> exp (Unix seconds), ordered least- to most-recently used.
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
class MCPAuthError(Exception):
    """Signal that a presented bearer token could not be authenticated.

    Raised by the opaque-token and JWT resolvers in this module; the
    exception message is the human-readable failure reason.
    """
# --- Opaque token path -----------------------------------------------------
def resolve_mcp_user(token_string: str):
    """Authenticate an opaque bearer token.

    The bearer is hashed with ``hash_token`` and the ``MCPToken`` row is
    looked up by that hash, so the plaintext is never stored or compared.

    Args:
        token_string: The plaintext bearer presented by the client.

    Returns:
        A ``(user, token)`` tuple: the owning user and the ``MCPToken`` row.

    Raises:
        MCPAuthError: If no row matches, the token is deactivated or
            expired, or the owning user account is disabled.
    """
    try:
        record = MCPToken.objects.select_related("user").get(
            token_hash=hash_token(token_string)
        )
    except MCPToken.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    # Checks run in a fixed order so the reported reason is deterministic.
    if not record.is_active:
        raise MCPAuthError("Token has been deactivated.")

    expiry = record.expires_at
    if expiry and expiry < timezone.now():
        raise MCPAuthError("Token has expired.")

    owner = record.user
    if not owner.is_active:
        raise MCPAuthError("User account is disabled.")

    # Only a fully validated token gets its usage recorded.
    record.record_usage()
    return owner, record
# --- JWT path --------------------------------------------------------------
def looks_like_jwt(token_string: str) -> bool:
    """Cheaply decide whether a bearer should take the JWT path.

    A bearer qualifies when it has exactly three dot-separated segments
    and its first segment base64url-decodes to a JSON object carrying a
    string ``alg`` field. No signature or claim validation happens here.

    Args:
        token_string: The raw bearer credential (untrusted input).

    Returns:
        True if the bearer has the shape of a JWT, False otherwise. This
        function never raises for malformed input.
    """
    segments = token_string.split(".")
    if len(segments) != 3:
        return False
    try:
        header = json.loads(_b64url_decode(segments[0]))
    except (ValueError, RecursionError):
        # ValueError covers bad base64 (binascii.Error), undecodable bytes
        # (UnicodeDecodeError) and bad JSON (json.JSONDecodeError).
        # RecursionError is what the JSON parser raises for a pathologically
        # nested header; the bearer is attacker-controlled, so treat that as
        # "not a JWT" instead of letting it escape as an unhandled error.
        return False
    return isinstance(header, dict) and isinstance(header.get("alg"), str)


def _b64url_decode(segment: str) -> bytes:
    """Decode an unpadded base64url segment, restoring the stripped ``=``."""
    pad = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + pad)
def _remember_jti(jti: str, exp: float) -> bool:
    """Track a ``jti`` in the per-process cache and report genuine replays.

    The first sighting of a ``jti`` stores its ``exp`` and is never a
    replay. A later sighting is a replay only when the *recorded* ``exp``
    plus ``_JWT_LEEWAY_SECONDS`` is already in the past. Re-use inside the
    validity window is expected (one token per chat turn, several tool
    calls per turn) and returns False.

    Args:
        jti: The token's unique id claim.
        exp: The token's expiry as a Unix timestamp; stored on first sight.

    Returns:
        True iff the ``jti`` was already cached and its recorded expiry
        (plus leeway) has passed; False otherwise.
    """
    current = time.time()

    # Trim from the least-recently-used end while the head entry expired
    # more than an hour ago. Only the head is inspected, so this is a
    # best-effort bound rather than a full sweep.
    # NOTE(review): the original note cites exp ≤ 600s; that cap is not
    # enforced in this function.
    stale_before = current - 3600
    while _JTI_CACHE:
        head_exp = next(iter(_JTI_CACHE.values()))
        if not head_exp < stale_before:
            break
        _JTI_CACHE.popitem(last=False)

    known_exp = _JTI_CACHE.get(jti)
    if known_exp is None:
        # First sighting: make room if the cache is full, then record it.
        if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
            _JTI_CACHE.popitem(last=False)
        _JTI_CACHE[jti] = float(exp)
        return False

    # Seen before: bump to most-recently-used, then judge against the expiry
    # recorded at first sight, using the same leeway given to PyJWT.
    _JTI_CACHE.move_to_end(jti)
    if current > known_exp + _JWT_LEEWAY_SECONDS:
        return True
    return False
def resolve_mcp_jwt(token_string: str) -> dict:
    """Verify a signed JWT bearer and hand back its normalized claims.

    The signing key is looked up via the header's ``kid``
    (``MCPSigningKey``); the token itself is never looked up in
    ``MCPToken``.

    Args:
        token_string: The compact-serialized JWT presented as the bearer.

    Returns:
        The decoded claims dict, with ``ws`` normalized to a value or
        ``None``, ``libs`` normalized to a list of strings, and the
        header's ``kid`` added under ``"kid"``.

    Raises:
        MCPAuthError: On a malformed header, missing ``kid``, non-HS256
            ``alg``, unknown/retired/corrupt signing key, any PyJWT
            validation failure, bad ``jti``/``exp``/``libs`` shapes, or
            detected replay.
    """
    try:
        header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}")

    kid = header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    if header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    signing_key = MCPSigningKey.objects.by_kid(kid)
    if signing_key is None:
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not signing_key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")

    try:
        hmac_secret = bytes.fromhex(signing_key.secret_hex)
    except ValueError:
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.")

    required_claims = ["exp", "iat", "iss", "sub", "jti"]
    try:
        claims = pyjwt.decode(
            token_string,
            hmac_secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": required_claims},
            issuer=_JWT_ISS,
        )
    except pyjwt.PyJWTError as exc:
        # Map the specific PyJWT failures onto stable messages; everything
        # else falls through to the generic wording.
        if isinstance(exc, pyjwt.ExpiredSignatureError):
            reason = "Token has expired."
        elif isinstance(exc, pyjwt.InvalidIssuerError):
            reason = "Invalid token issuer."
        elif isinstance(exc, pyjwt.InvalidSignatureError):
            reason = "Invalid MCP token."
        elif isinstance(exc, pyjwt.MissingRequiredClaimError):
            reason = f"JWT missing required claim: {exc.claim}"
        else:
            reason = f"Invalid JWT: {exc}"
        raise MCPAuthError(reason)

    token_id = claims.get("jti")
    if not (isinstance(token_id, str) and token_id):
        raise MCPAuthError("JWT jti must be a non-empty string.")

    expiry = claims.get("exp")
    if not isinstance(expiry, (int, float)):
        # The ``require`` option already demands ``exp``; this re-check is
        # defence in depth against future PyJWT changes.
        raise MCPAuthError("JWT exp must be numeric.")

    if _remember_jti(token_id, float(expiry)):
        raise MCPAuthError("Token replay detected.")

    # Normalize claim shapes: a falsy/absent ``ws`` becomes None and a
    # falsy/absent ``libs`` becomes an empty list.
    claims["ws"] = claims.get("ws") or None
    allowed_libs = claims.get("libs") or []
    if not isinstance(allowed_libs, list) or any(
        not isinstance(item, str) for item in allowed_libs
    ):
        raise MCPAuthError("JWT libs must be a list of strings.")
    claims["libs"] = allowed_libs
    claims["kid"] = kid
    return claims
# --- Middleware ------------------------------------------------------------
class MCPAuthMiddleware(Middleware):
    """
    FastMCP middleware that authenticates tool calls via Bearer tokens.

    Listing tools/resources is permitted unauthenticated so clients can
    discover the surface; calling a tool requires a valid token unless
    MCP_REQUIRE_AUTH=False.

    After successful authentication the resolved user, opaque token row
    and/or JWT claims are stored on the FastMCP context under
    ``STATE_KEY_USER`` / ``STATE_KEY_TOKEN`` / ``STATE_KEY_CLAIMS``.
    """

    # Tools that don't touch user data and must be callable without a token
    # (e.g. Pallas health pollers, agent startup probes).
    _PUBLIC_TOOLS = {"get_health"}

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Authenticate a tool call, enforce the per-tool ACL, then dispatch.

        Args:
            context: FastMCP middleware context for the tool call.
            call_next: Continuation that runs the rest of the chain.

        Returns:
            Whatever ``call_next(context)`` returns.

        Raises:
            PermissionError: If auth is required and the bearer is missing
                or invalid, or if an opaque token is not allowed to call
                the requested tool (including when the tool name cannot
                be determined).
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
        tool_name = self._extract_tool_name(context)
        logger.info(
            "mcp_auth.on_call_tool tool=%s require_auth=%s",
            tool_name,
            require_auth,
        )

        # Public tools skip authentication entirely when auth is enforced.
        if require_auth and tool_name in self._PUBLIC_TOOLS:
            return await self._call_next_with_trace(
                tool_name, call_next, context
            )

        token_string = self._extract_token()
        user = None
        token = None
        claims: dict | None = None

        if token_string:
            try:
                if looks_like_jwt(token_string):
                    claims = await sync_to_async(
                        resolve_mcp_jwt, thread_sensitive=True
                    )(token_string)
                    user = await sync_to_async(
                        _resolve_jwt_actor, thread_sensitive=True
                    )(claims)
                else:
                    user, token = await sync_to_async(
                        resolve_mcp_user, thread_sensitive=True
                    )(token_string)
            except MCPAuthError as exc:
                # NOTE(review): some MCPAuthError messages embed
                # request-derived values (e.g. the JWT ``kid``), so
                # ``reason`` is not a closed set of label values — confirm
                # that is acceptable for the metrics backend.
                mcp_auth_failures_total.labels(reason=str(exc)).inc()
                if require_auth:
                    raise PermissionError(str(exc)) from exc
                # Auth is optional, so continue anonymously — but truly
                # anonymously: a JWT whose claims validated but whose
                # service actor failed to resolve must not leave its
                # claims behind to be published on the context below.
                user = None
                token = None
                claims = None
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        # Per-tool ACL applies to opaque tokens only (JWTs carry no token
        # row). Fail closed: if the tool name cannot be determined we cannot
        # prove the token may call it, so refuse instead of skipping the ACL.
        if token is not None:
            if not tool_name:
                mcp_auth_failures_total.labels(
                    reason="tool_name_unresolved"
                ).inc()
                logger.warning("mcp_auth.acl outcome=tool_name_unresolved")
                raise PermissionError(
                    "Unable to determine the requested tool; refusing the call."
                )
            if not token.can_use_tool(tool_name):
                mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
                raise PermissionError(
                    f"Token does not have permission to call '{tool_name}'."
                )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx is not None:
            if user is not None:
                await fastmcp_ctx.set_state(STATE_KEY_USER, user)
            if token is not None:
                await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
            if claims is not None:
                await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)

        return await self._call_next_with_trace(tool_name, call_next, context)

    @staticmethod
    async def _call_next_with_trace(tool_name, call_next, context):
        """Run ``call_next`` and log any exception with a full traceback.

        During the Pallas↔Mnemosyne shakedown we were seeing tool results
        come back to fast-agent as the string ``"object NoneType can't be
        used in 'await' expression"`` with no trace anywhere in either
        process. That string is Python's ``TypeError`` for ``await X``
        where ``X`` is ``None`` (i.e. someone awaited a non-coroutine).
        If that TypeError is raised inside the FastMCP dispatch we want
        the full traceback in Mnemosyne's own log rather than silently
        letting it propagate back to Pallas where the aggregator turns
        it into a ``CallToolResult(isError=True)`` and we lose the frame.

        We also log the *type* of the successful return so we can verify
        later that FastMCP is returning ``ToolResult`` / ``CallToolResult``
        the way we expect. Keep at INFO until green, demote to DEBUG
        afterwards.

        Args:
            tool_name: Tool name used only for log correlation; may be None.
            call_next: Continuation to await.
            context: Middleware context forwarded to ``call_next``.

        Returns:
            The value produced by ``call_next(context)``.

        Raises:
            Exception: Whatever ``call_next`` raised, re-raised unchanged
                after being logged.
        """
        try:
            result = await call_next(context)
        except Exception:
            logger.exception(
                "mcp_auth.call_next_failed tool=%s", tool_name
            )
            raise
        logger.info(
            "mcp_auth.call_next_ok tool=%s result_type=%s",
            tool_name,
            type(result).__name__,
        )
        return result

    @staticmethod
    def _extract_token() -> str | None:
        """Pull the Bearer token off the current HTTP request, if any.

        ``get_http_request()`` is expected to succeed during tool-call
        dispatch. If it raises ``RuntimeError`` — e.g. because we're in a
        background task or pre-session initialize hook — we return
        ``None`` and let the caller decide.

        INFO-level logging is intentional until bearer-forwarding from
        Daedalus/Pallas is fully shaken out; demote to DEBUG once green.
        Only the auth *scheme* is logged, never any part of the credential.

        Returns:
            The bearer credential with surrounding whitespace stripped, or
            ``None`` when there is no HTTP request, no Authorization
            header, a non-Bearer scheme, or an empty credential.
        """
        try:
            request = get_http_request()
        except RuntimeError as exc:
            logger.warning(
                "mcp_auth.extract outcome=no_http_request reason=%r",
                str(exc),
            )
            return None

        # Header lookup is case-insensitive on Starlette ``Headers`` but
        # different proxies normalize case differently — belt and braces.
        auth_header = request.headers.get("Authorization", "")
        if not auth_header:
            auth_header = request.headers.get("authorization", "")

        scheme, separator, credential = auth_header.partition(" ")
        header_names = sorted(request.headers.keys())
        logger.info(
            "mcp_auth.extract outcome=%s len=%d prefix=%r path=%s header_names=%s",
            "present" if auth_header else "missing",
            len(auth_header),
            # Log the scheme only. Without a separator the whole header
            # could be a raw secret, so log nothing in that case.
            scheme[:16] if separator else "",
            str(getattr(request, "url", "")),
            header_names,
        )

        # Auth scheme names are case-insensitive, so accept any casing of
        # "Bearer" rather than just the two most common spellings.
        if separator and scheme.lower() == "bearer":
            return credential.strip() or None
        return None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Pull the tool name off a FastMCP ``on_call_tool`` context.

        In FastMCP middleware, ``context.message`` inside ``on_call_tool``
        is already a ``CallToolRequestParams`` (see
        ``fastmcp.server.middleware.middleware.Middleware.on_call_tool``
        signature: ``MiddlewareContext[mt.CallToolRequestParams]``), so
        ``name`` lives directly on ``context.message`` — there is no
        nested ``.params``. The older ``message.params.name`` access we
        had here always returned ``None``, which caused the public-tools
        bypass to silently miss ``get_health`` and made the per-tool ACL
        short-circuit. Fall back to ``.params.name`` only as a legacy
        safety net in case the shape ever diverges.

        Returns:
            The tool name, or ``None`` if neither location provides one.
        """
        msg = getattr(context, "message", None)
        if msg is None:
            return None
        name = getattr(msg, "name", None)
        if name:
            return name
        params = getattr(msg, "params", None)
        return getattr(params, "name", None) if params is not None else None
# --- Helpers ---------------------------------------------------------------
def _resolve_jwt_actor(claims: dict):
    """Look up the service account that acts for JWT-authenticated calls.

    The username comes from ``settings.MCP_JWT_SERVICE_USERNAME`` and
    defaults to ``daedalus-service``. JWTs are not tied to per-user
    accounts; the ``claims`` argument is accepted but not consulted here.

    Args:
        claims: The validated JWT claims (unused by this lookup).

    Returns:
        The active service user instance.

    Raises:
        MCPAuthError: If the service user does not exist or is disabled.
    """
    # Imported lazily, inside the function, as in the original.
    from django.contrib.auth import get_user_model

    account_model = get_user_model()
    service_username = getattr(
        settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service"
    )
    try:
        actor = account_model.objects.get(username=service_username)
    except account_model.DoesNotExist:
        raise MCPAuthError(
            f"JWT service user {service_username!r} does not exist; provision via management command."
        )
    if actor.is_active:
        return actor
    raise MCPAuthError(f"JWT service user {service_username!r} is disabled.")