Files
mnemosyne/mnemosyne/mcp_server/auth.py
Robert Helewka 93639188d3
Some checks failed
CVE Scan & Docker Build / build-and-push (push) Has been cancelled
CVE Scan & Docker Build / security-scan (push) Has been cancelled
Build & Deploy Docs / build-and-deploy (push) Successful in 1m10s
feat: rework auth model with UserToken and Daedalus/Pallas integration
- Rename MCPToken to UserToken across models, views, and tests
- Update URL names from mcp-token-* to token-*
- Add Daedalus/Pallas integration design doc (v2)
- Switch docker-compose to build local mnemosyne:local image via shared
  build config instead of pulling from git.helu.ca
2026-05-23 19:50:29 -04:00

605 lines
24 KiB
Python

"""MCP token resolution and FastMCP middleware for bearer-token auth.

Three credential shapes are recognised — see
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2 for the full model:

1. **Opaque ``UserToken``** (long-lived, hashed at rest). Authorization
   scope is its ``allowed_libraries`` JSON list.
2. **Per-turn signed JWT** (``iss=daedalus``, ≤10 min, legacy — retires
   in Phase 4 when Daedalus chat itself becomes a Pallas Team). Its scope
   would be the ``libs`` claim. Note: ``resolve_mcp_jwt`` still validates
   these tokens, but ``_resolve_jwt_actor`` rejects every non-team JWT,
   so the middleware currently refuses them.
3. **Team JWT** (``iss=mnemosyne``, ``typ=team``, 10-year lifetime).
   Scope is resolved live by joining ``TeamWorkspaceAssignment`` rows
   to Neo4j ``Library.workspace_id``.

Every branch populates a single :data:`STATE_KEY_RESOLVED_LIBRARIES`
value on the FastMCP context — a ``list[str]`` of Library UIDs the
downstream tools are permitted to read (``None`` when no credential
resolved, which tools treat as fail-closed). Tools never consult claim
shapes; they read this list via
``mcp_server.context.get_mcp_resolved_libraries``.

Detection: a bearer is treated as a JWT when it has three dot-separated
segments and its first segment base64url-decodes to a JSON object with a
string ``alg`` field. Anything else falls through to the opaque path.
The ``alg=HS256`` and ``kid`` requirements are enforced later, during
validation.
"""
from __future__ import annotations
import base64
import json
import logging
import time
import uuid
from collections import OrderedDict
import jwt as pyjwt
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils import timezone
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from .metrics import mcp_auth_failures_total
from .models import MCPSigningKey, UserToken, Team, hash_token
logger = logging.getLogger(__name__)
# Keys under which the middleware stores auth results on the FastMCP
# ``Context`` (via ``set_state``) for downstream tools to read.
STATE_KEY_USER = "mcp_user"  # acting Django user
STATE_KEY_TOKEN = "mcp_token"  # UserToken row (opaque bearers only)
STATE_KEY_CLAIMS = "mcp_claims"  # validated JWT claims (JWT bearers only)
STATE_KEY_RESOLVED_LIBRARIES = "mcp_resolved_libraries"  # Library UIDs or None
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
# symmetrically as ``leeway``; ``_remember_jti`` reuses the same value
# for its own expiry comparison.
_JWT_LEEWAY_SECONDS = 30
# Accepted JWT issuers (passed to PyJWT's ``issuer=`` check).
#
# ``daedalus`` — per-turn tokens minted by Daedalus chat (legacy path,
# retires with Phase 4).
# ``mnemosyne`` — team tokens minted by this service. ``typ=team``
# distinguishes them from any future self-issued credential.
_JWT_ISS_VALUES = {"daedalus", "mnemosyne"}
# Bounded, per-process LRU of recently seen per-turn ``jti`` values,
# stored as ``jti -> exp`` (epoch seconds). This is best-effort only;
# the real defence is the short ``exp`` plus the HMAC signature.
#
# Daedalus mints one JWT per *chat turn*, and a single turn routinely
# drives several tool calls (list_libraries → search → get_document …)
# that all re-use that bearer. A repeated ``jti`` on its own is therefore
# NOT treated as replay. ``_remember_jti`` only flags a repeat once the
# token's own ``exp`` (+ leeway) has passed. PyJWT's ``exp`` check should
# already have rejected that case; this is belt-and-braces against clock
# skew or a resurrected captured token.
#
# Team tokens (``typ=team``) bypass this cache entirely — they are
# reused on every request by design. Revocation for those tokens runs
# against the live ``Team`` row (``active`` + ``active_jti``).
#
# ``_JTI_CACHE_MAX`` caps the entry count; when full, the least recently
# seen entry is evicted.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
class MCPAuthError(Exception):
    """A bearer credential could not be resolved to a usable principal.

    Raised by the resolvers in this module. The message is a short,
    human-readable reason that the middleware forwards to the caller.
    """
# --- Opaque token path -----------------------------------------------------
def resolve_mcp_user(token_string: str):
    """Map an opaque bearer string to ``(user, UserToken)``.

    The bearer is run through ``hash_token`` and matched on
    ``token_hash``; the plaintext is never stored or compared.

    Checks run in this order: the token exists, is active, has not passed
    ``expires_at`` (when one is set), and belongs to an active user. On
    success ``record_usage()`` is called before returning.

    Raises:
        MCPAuthError: if any of the checks above fails.
    """
    rows = UserToken.objects.select_related("user")
    try:
        record = rows.get(token_hash=hash_token(token_string))
    except UserToken.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    if not record.is_active:
        raise MCPAuthError("Token has been deactivated.")

    expiry = record.expires_at
    if expiry and expiry < timezone.now():
        raise MCPAuthError("Token has expired.")

    if not record.user.is_active:
        raise MCPAuthError("User account is disabled.")

    record.record_usage()
    return record.user, record
# --- JWT path --------------------------------------------------------------
def looks_like_jwt(token_string: str) -> bool:
    """Cheap pre-check: is *token_string* shaped like a JWT?

    Returns True iff the string has exactly three dot-separated segments
    and the first one decodes (base64url, padding optional) to a JSON
    object with a string ``alg`` field. The signature, ``kid`` and the
    ``alg`` value are NOT checked here — ``resolve_mcp_jwt`` does that.

    Never raises: an undecodable header just means "not a JWT", so the
    caller falls through to the opaque-token path.
    """
    parts = token_string.split(".")
    if len(parts) != 3:
        return False
    try:
        header = json.loads(_b64url_decode(parts[0]))
    except (ValueError, RecursionError):
        # ValueError covers binascii.Error (bad base64), UnicodeDecodeError
        # and json.JSONDecodeError. RecursionError is raised for a header of
        # deeply nested arrays/objects. Any caller can send such a header,
        # and it must not escape the middleware as a server error.
        return False
    return isinstance(header, dict) and isinstance(header.get("alg"), str)


def _b64url_decode(segment: str) -> bytes:
    """Decode a base64url segment whose ``=`` padding may be stripped.

    The missing padding is restored before decoding. Malformed input may
    raise ``ValueError`` (e.g. ``binascii.Error``).
    """
    pad = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + pad)
def _remember_jti(jti: str, exp: float) -> bool:
    """Record a per-turn ``jti``, or re-check it; True means replay.

    A repeat sighting counts as replay only when the current time is
    past the stored ``exp`` plus ``_JWT_LEEWAY_SECONDS``. Such a token
    should no longer be presentable; PyJWT's own ``exp`` check would
    normally have rejected it already, so this is belt-and-braces against
    clock skew or a captured token resurrected just after its lifetime.

    Seeing the same ``jti`` again *inside* its validity window is the
    normal case: Daedalus mints one token per chat turn, and a turn often
    fires several Mnemosyne tool calls with that bearer. That returns
    False.
    """
    clock = time.time()

    # Purge from the LRU head while the head entry's exp is more than an
    # hour old (generous given exp <= 600s). Only the head is inspected,
    # so an old entry sitting behind a fresher one survives until it
    # reaches the head or is evicted by the size cap.
    purge_before = clock - 3600
    while _JTI_CACHE:
        head_exp = next(iter(_JTI_CACHE.values()))
        if not head_exp < purge_before:
            break
        _JTI_CACHE.popitem(last=False)

    seen_exp = _JTI_CACHE.get(jti)
    if seen_exp is None:
        # First sighting: make room if the cache is full, then record it.
        if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
            _JTI_CACHE.popitem(last=False)
        _JTI_CACHE[jti] = float(exp)
        return False

    # Repeat sighting: refresh its LRU slot, then apply the same symmetric
    # leeway PyJWT used when it checked the signature.
    _JTI_CACHE.move_to_end(jti)
    return clock > seen_exp + _JWT_LEEWAY_SECONDS
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a signed JWT and return its claims dict.

    Accepts both the legacy per-turn issuer (``iss=daedalus``) and the
    team issuer (``iss=mnemosyne``, ``typ=team``). The returned claims
    dict is normalized so the middleware doesn't have to guess:

    * ``claims["iss"]`` — as presented (``daedalus`` or ``mnemosyne``).
    * ``claims["typ"]`` — ``"team"`` for team tokens, otherwise absent.
    * ``claims["libs"]`` — per-turn only; normalized to ``list[str]``.
    * ``claims["ws"]`` — per-turn only; may be ``None``. Not consulted
      for authorization (kept for diagnostics).
    * ``claims["team_id"]`` — team only; ``UUID`` parsed from
      ``sub == "team:<uuid>"``.
    * ``claims["kid"]`` — copy of the JWT header's ``kid``.

    Validation steps:

    * The header must carry a non-empty string ``kid`` and ``alg=HS256``.
    * ``kid`` must name an active ``MCPSigningKey``, and that key's hex
      secret must verify the HMAC.
    * ``exp``, ``iat``, ``iss``, ``sub`` and ``jti`` must be present, and
      ``iss`` must be in ``_JWT_ISS_VALUES``.
    * A ``typ=team`` token must also be issued by ``mnemosyne``, so
      ``iss`` and ``typ`` really do partition the two token populations.

    The per-turn path runs the ``_remember_jti`` replay check. The team
    path skips it, because team JWTs are intentionally reused across the
    token's lifetime.

    Raises:
        MCPAuthError: on any validation failure.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}") from exc

    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    if not isinstance(kid, str):
        # The header is untrusted JSON. Reject a number/object/array here
        # rather than hand an unexpected type to the key lookup.
        raise MCPAuthError("JWT header kid must be a string.")
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")
    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError:
        raise MCPAuthError(
            f"Stored secret for kid={kid!r} is not valid hex."
        ) from None

    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={
                "require": ["exp", "iat", "iss", "sub", "jti"],
                # Team JWTs carry ``aud=mnemosyne`` for informational
                # purposes; per-turn JWTs omit ``aud`` entirely. Audience
                # is not enforced; ``iss`` + ``typ`` (checked below)
                # partition the two token populations instead.
                "verify_aud": False,
            },
            # ``issuer=`` accepts ``str | Iterable[str]`` and raises
            # ``InvalidIssuerError`` if the claim is outside the set.
            issuer=list(_JWT_ISS_VALUES),
        )
    except pyjwt.ExpiredSignatureError as exc:
        raise MCPAuthError("Token has expired.") from exc
    except pyjwt.InvalidIssuerError as exc:
        raise MCPAuthError("Invalid token issuer.") from exc
    except pyjwt.InvalidSignatureError as exc:
        raise MCPAuthError("Invalid MCP token.") from exc
    except pyjwt.MissingRequiredClaimError as exc:
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}") from exc
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Invalid JWT: {exc}") from exc

    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")
    exp = claims.get("exp")
    if not isinstance(exp, (int, float)):
        # ``require=["exp", ...]`` above guarantees presence; this is
        # defence in depth against future PyJWT changes.
        raise MCPAuthError("JWT exp must be numeric.")

    typ = claims.get("typ")
    if typ == "team":
        # Team credentials are minted only by this service. A token from
        # any other issuer claiming the team path is rejected, even though
        # it passed the generic issuer allow-list above.
        if claims.get("iss") != "mnemosyne":
            raise MCPAuthError("Invalid token issuer.")
        # Team tokens: no replay cache, no ``libs`` or ``ws`` claims.
        # Verify the ``sub`` shape and parse the embedded team UUID so
        # the middleware doesn't have to re-parse it later.
        sub = claims.get("sub")
        if not isinstance(sub, str) or not sub.startswith("team:"):
            raise MCPAuthError("Invalid MCP token.")
        try:
            claims["team_id"] = uuid.UUID(sub[len("team:"):])
        except ValueError:
            raise MCPAuthError("Invalid MCP token.") from None
    else:
        # Per-turn (legacy) path: replay-cache gate + normalize claims.
        if _remember_jti(jti, float(exp)):
            raise MCPAuthError("Token replay detected.")
        claims["ws"] = claims.get("ws") or None
        libs = claims.get("libs") or []
        if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
            raise MCPAuthError("JWT libs must be a list of strings.")
        claims["libs"] = libs

    claims["kid"] = kid
    return claims
# --- Team-JWT library resolution ------------------------------------------
def _libraries_for_team(team_id: uuid.UUID, jti: str) -> list[str]:
    """Return the Library UIDs a team token may read.

    Step 1: load the ``Team`` by primary key. Reject the token if the team
    is missing, inactive, or its ``active_jti`` (compared as a string)
    differs from *jti*. This is how rotation or soft-delete revokes a
    token on its very next request.

    Step 2: collect the team's ``workspace_assignments.workspace_id``
    values and map them to ``Library.uid`` with one Cypher query against
    Neo4j. Rows with an empty uid are dropped. NOTE(review): the
    workspace ids are passed to Neo4j exactly as Django returns them —
    confirm their type matches ``Library.workspace_id``.

    An empty list is returned when the team has no workspace assignments
    (fail-closed: no workspaces means no libraries).

    Raises:
        MCPAuthError: if the team is missing, inactive or rotated.
    """
    try:
        team = Team.objects.get(pk=team_id)
    except Team.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    if not team.active:
        raise MCPAuthError("Token has been deactivated.")
    current_jti = team.active_jti
    if current_jti is None or str(current_jti) != jti:
        raise MCPAuthError("Invalid MCP token.")

    assigned = [
        *team.workspace_assignments.values_list("workspace_id", flat=True)
    ]
    if not assigned:
        return []

    # Imported lazily, as in the original, so Neo4j is only touched when
    # a team actually has workspaces.
    from neomodel import db

    records, _summary = db.cypher_query(
        "MATCH (l:Library) WHERE l.workspace_id IN $ws RETURN l.uid",
        {"ws": assigned},
    )
    return [record[0] for record in records if record and record[0]]
# --- Middleware ------------------------------------------------------------
class MCPAuthMiddleware(Middleware):
    """
    FastMCP middleware that authenticates tool calls via Bearer tokens.

    Listing tools/resources is permitted unauthenticated so clients can
    discover the surface. Calling a tool requires a valid token unless
    ``MCP_REQUIRE_AUTH=False``. When auth is required, tools in
    ``_PUBLIC_TOOLS`` are dispatched without any authentication.

    Every call that goes through authentication publishes values to the
    FastMCP ``Context`` state, for downstream tools to consume via
    :mod:`mcp_server.context`:

    * ``STATE_KEY_USER`` — Django user.
    * ``STATE_KEY_TOKEN`` — UserToken row (opaque callers only).
    * ``STATE_KEY_CLAIMS`` — JWT claims dict (JWT callers only).
    * ``STATE_KEY_RESOLVED_LIBRARIES`` — authorization-resolved Library
      UID list. It is always published, as ``None`` when no credential
      resolved. Tools read this; they never read ``STATE_KEY_CLAIMS``
      for authorization.

    If a presented credential fails resolution while auth is optional,
    the call proceeds fully anonymously. Partial results from earlier
    resolution steps are discarded, not published.
    """

    # Tools that don't touch user data and must be callable without a token
    # (e.g. Pallas health pollers, agent startup probes).
    _PUBLIC_TOOLS = {"get_health"}

    # Characters that introduce variable detail in an ``MCPAuthError``
    # message (``kid=...``, ``{value!r}``, ``": {exc}"``). Failure-metric
    # labels are cut at the first of these, so callers cannot mint
    # unbounded Prometheus label values.
    _REASON_DETAIL_CHARS = frozenset(":='\"")

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Authenticate a tool call, publish the principal, then dispatch.

        Raises:
            PermissionError: if auth is required and the bearer is missing
                or invalid, or if an opaque token's tool ACL does not
                allow the requested tool.
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
        tool_name = self._extract_tool_name(context)
        logger.info(
            "mcp_auth.on_call_tool tool=%s require_auth=%s",
            tool_name,
            require_auth,
        )

        if require_auth and tool_name in self._PUBLIC_TOOLS:
            return await self._call_next_with_trace(
                tool_name, call_next, context
            )

        token_string = self._extract_token()
        user = None
        token = None
        claims: dict | None = None
        resolved_libraries: list[str] | None = None

        if token_string:
            try:
                if looks_like_jwt(token_string):
                    claims = await sync_to_async(
                        resolve_mcp_jwt, thread_sensitive=True
                    )(token_string)
                    user = await sync_to_async(
                        _resolve_jwt_actor, thread_sensitive=True
                    )(claims)
                    resolved_libraries = await sync_to_async(
                        _resolved_libraries_for_jwt, thread_sensitive=True
                    )(claims)
                else:
                    user, token = await sync_to_async(
                        resolve_mcp_user, thread_sensitive=True
                    )(token_string)
                    # Opaque tokens store the Library UID list directly.
                    # Empty list = fail-closed; not "everything".
                    resolved_libraries = list(token.allowed_libraries or [])
            except MCPAuthError as exc:
                mcp_auth_failures_total.labels(
                    reason=self._failure_reason(exc)
                ).inc()
                if require_auth:
                    raise PermissionError(str(exc)) from exc
                # Auth is optional, but a credential that failed part-way
                # must not leave a half-authenticated principal behind.
                # Example: a team JWT whose owner resolved before its
                # revoked ``active_jti`` was rejected. Proceed anonymously.
                user = token = claims = resolved_libraries = None
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        if token is not None and tool_name and not token.can_use_tool(tool_name):
            mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
            raise PermissionError(
                f"Token does not have permission to call '{tool_name}'."
            )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx is not None:
            # ``Context.set_state`` is synchronous in FastMCP: it stores into
            # ``self._state`` and returns ``None``. Awaiting that return
            # value raises ``TypeError: object NoneType can't be used in
            # 'await' expression``. FastMCP's dispatch turns that into an
            # opaque string-valued ``CallToolResult`` — exactly the symptom
            # documented in ``pallas._fastagent_patch``. Call synchronously.
            if user is not None:
                fastmcp_ctx.set_state(STATE_KEY_USER, user)
            if token is not None:
                fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
            if claims is not None:
                fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
            # Always publish resolved_libraries. None means "no auth
            # information", and the tools treat that as fail-closed.
            fastmcp_ctx.set_state(
                STATE_KEY_RESOLVED_LIBRARIES, resolved_libraries
            )

        logger.info(
            "mcp_auth.resolved tool=%s principal=%s lib_count=%s",
            tool_name,
            self._describe_principal(user, token, claims),
            "none" if resolved_libraries is None else len(resolved_libraries),
        )
        return await self._call_next_with_trace(tool_name, call_next, context)

    @classmethod
    def _failure_reason(cls, exc: Exception) -> str:
        """Map an auth error to a bounded ``mcp_auth_failures_total`` label.

        Messages without variable detail (e.g. ``"Token has expired."``)
        pass through unchanged, so existing dashboards keep matching.

        Messages that embed detail — an attacker-chosen ``kid``, a PyJWT
        error string, a username — are cut just before it. For example,
        ``"Unknown signing key kid='x'."`` becomes
        ``"Unknown signing key kid"``.
        """
        message = str(exc)
        cut = next(
            (i for i, ch in enumerate(message) if ch in cls._REASON_DETAIL_CHARS),
            None,
        )
        if cut is None:
            return message
        return message[:cut].rstrip() or "invalid_token"

    @staticmethod
    async def _call_next_with_trace(tool_name, call_next, context):
        """Run ``call_next`` and log any exception with a full traceback.

        During the Pallas↔Mnemosyne shakedown, tool results came back to
        fast-agent as the string ``"object NoneType can't be used in
        'await' expression"``, with no trace in either process. That string
        is Python's ``TypeError`` for ``await X`` where ``X`` is ``None``
        (i.e. someone awaited a non-coroutine).

        If that TypeError is raised inside the FastMCP dispatch, we want the
        full traceback in Mnemosyne's own log. Otherwise it propagates back
        to Pallas, where the aggregator turns it into a
        ``CallToolResult(isError=True)`` and the frame is lost.

        The *type* of a successful return is also logged, to verify that
        FastMCP returns ``ToolResult`` / ``CallToolResult`` as expected.
        Keep this at INFO until green, then demote to DEBUG.
        """
        try:
            result = await call_next(context)
        except Exception:
            logger.exception(
                "mcp_auth.call_next_failed tool=%s", tool_name
            )
            raise
        logger.info(
            "mcp_auth.call_next_ok tool=%s result_type=%s",
            tool_name,
            type(result).__name__,
        )
        return result

    @staticmethod
    def _describe_principal(user, token, claims) -> str:
        """Compact, log-friendly principal summary. No PII beyond usernames."""
        if claims is not None:
            typ = claims.get("typ")
            if typ == "team":
                return f"team:{claims.get('team_id')}"
            return f"jwt:{claims.get('sub')}"
        if token is not None:
            return f"mcptoken:{token.get_masked_token()}"
        if user is not None:
            return f"user:{user.username}"
        return "anonymous"

    @staticmethod
    def _extract_token() -> str | None:
        """Pull the Bearer token off the current HTTP request, if any.

        The MCP SDK stores the Starlette ``Request`` on ``request_ctx`` for
        every tool-call dispatch, including follow-up calls on a stateful
        session, so ``get_http_request()`` should succeed here. If it
        *doesn't* — e.g. in a background task or a pre-session initialize
        hook — return ``None`` and let the caller decide.

        The auth scheme is matched case-insensitively (RFC 7235 §2.1), so
        ``Bearer``, ``bearer`` and ``BEARER`` are all accepted.

        Only the scheme is logged, never any part of the credential.
        INFO-level logging is intentional until bearer-forwarding from
        Daedalus/Pallas is fully shaken out; demote to DEBUG once green.
        """
        try:
            request = get_http_request()
        except RuntimeError as exc:
            logger.warning(
                "mcp_auth.extract outcome=no_http_request reason=%r",
                str(exc),
            )
            return None

        # Starlette ``Headers`` lookups are already case-insensitive; the
        # lowercase fallback is a harmless safety net for other header
        # containers.
        auth_header = request.headers.get("Authorization", "")
        if not auth_header:
            auth_header = request.headers.get("authorization", "")

        scheme, separator, credentials = auth_header.partition(" ")
        # With no separator the whole header may be a bare secret, so log
        # nothing from it. Otherwise log the scheme only, bounded in length.
        logged_scheme = scheme[:16] if separator else ""
        header_names = sorted(request.headers.keys())
        logger.info(
            "mcp_auth.extract outcome=%s len=%d scheme=%r path=%s header_names=%s",
            "present" if auth_header else "missing",
            len(auth_header),
            logged_scheme,
            str(getattr(request, "url", "")),
            header_names,
        )
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Pull the tool name off a FastMCP ``on_call_tool`` context.

        In FastMCP middleware, ``context.message`` inside ``on_call_tool``
        is already a ``CallToolRequestParams``. The signature of
        ``fastmcp.server.middleware.middleware.Middleware.on_call_tool`` is
        ``MiddlewareContext[mt.CallToolRequestParams]``, so ``name`` lives
        directly on ``context.message`` — there is no nested ``.params``.

        The older ``message.params.name`` access always returned ``None``.
        That made the public-tools bypass silently miss ``get_health`` and
        made the per-tool ACL short-circuit. ``.params.name`` is kept only
        as a legacy fallback in case the shape ever diverges.
        """
        msg = getattr(context, "message", None)
        if msg is None:
            return None
        name = getattr(msg, "name", None)
        if name:
            return name
        params = getattr(msg, "params", None)
        return getattr(params, "name", None) if params is not None else None
# --- Helpers ---------------------------------------------------------------
def _resolve_jwt_actor(claims: dict):
    """Return the Django user that a validated JWT acts on behalf of.

    Only team JWTs (``typ == "team"``) carry a user binding. The acting
    user is the team's ``owner`` — the Mnemosyne user who created the
    team — and usage accounting and the audit trail are attributed to them.

    Legacy per-turn JWTs (``iss=daedalus``, retiring in Phase 4) have no
    user binding in their claims, so they are refused outright. A
    deployment that still needs them must add that binding first.

    Raises:
        MCPAuthError: for a non-team JWT, a missing ``team_id``, or a team
            that no longer exists, is inactive, or has a disabled owner.
    """
    is_team_token = claims.get("typ") == "team"
    if not is_team_token:
        raise MCPAuthError(
            "Per-turn JWTs are no longer accepted; mint a team JWT."
        )

    # resolve_mcp_jwt stores the UUID parsed from ``sub`` as ``team_id``.
    team_id = claims.get("team_id")
    if team_id is None:
        raise MCPAuthError("Team JWT missing team_id claim.")

    with_owner = Team.objects.select_related("owner")
    try:
        team = with_owner.get(pk=team_id)
    except Team.DoesNotExist:
        raise MCPAuthError("Team JWT references a team that no longer exists.")
    if not team.active:
        raise MCPAuthError("Team JWT references an inactive team.")

    owner = team.owner
    if owner.is_active:
        return owner
    raise MCPAuthError(
        f"Team owner {owner.username!r} is disabled."
    )
def _resolved_libraries_for_jwt(claims: dict) -> list[str]:
"""Pick the right resolver branch for a validated JWT claims dict.
* ``typ == "team"`` → live lookup via :func:`_libraries_for_team`.
* otherwise → legacy ``claims["libs"]`` (per-turn JWT).
"""
if claims.get("typ") == "team":
return _libraries_for_team(claims["team_id"], claims["jti"])
return list(claims.get("libs") or [])