Files
mnemosyne/mnemosyne/mcp_server/auth.py
Robert Helewka a2c885cf34
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 52s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m32s
feat(library): add workspace-scoped search and JWT auth for Daedalus
- Extend library list endpoint with `include_workspace` and
  `with_item_count` query params to support Daedalus registry mirroring
- Expand search scope clause to three modes: workspace-only, workspace
  plus allowed user libraries, and global
- Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims
- Introduce JWT-based actor resolution using a synthetic service user
  (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
2026-05-03 17:36:06 -04:00

291 lines
10 KiB
Python

"""MCP token resolution and FastMCP middleware for bearer-token auth.
Two token shapes are supported:
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
used by the dashboard / Claude Desktop / admin tooling. Plaintext
hashes to a row in `mcp_token`.
* **Signed JWT** — per-turn token minted by Daedalus. Carries
`{ws, libs}` claims. Validated entirely off the signature + claims;
no database lookup of the token itself, only of the signing key
(`MCPSigningKey`) referenced by the JWT header's `kid`.
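Detection: a bearer with three dot-separated segments whose first segment
base64url-decodes to a JSON object with a string `alg` field is treated as
JWT; anything else falls through to the opaque path. Validation of a JWT
then additionally requires `alg` to be HS256 and a `kid` to be present.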
"""
from __future__ import annotations
import base64
import hashlib
import json
import logging
import time
from collections import OrderedDict
import jwt as pyjwt
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils import timezone
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from .metrics import mcp_auth_failures_total
from .models import MCPSigningKey, MCPToken, hash_token
logger = logging.getLogger(__name__)
# Keys under which MCPAuthMiddleware stores the authenticated user, the
# opaque MCPToken row and the JWT claims on the FastMCP context state.
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"
# Permitted clock skew (seconds) when validating JWT exp/iat. PyJWT applies
# this symmetrically as ``leeway``.
_JWT_LEEWAY_SECONDS = 30
# Only tokens whose ``iss`` claim equals this value are accepted (Daedalus is
# the sole issuer). Mnemosyne is the intended recipient, but note that no
# ``audience=`` is passed to ``pyjwt.decode``.
_JWT_ISS = "daedalus"
# Bounded, insertion-ordered (FIFO) cache of recently-seen jti values to
# discourage replay within a single Mnemosyne process. Hits are not moved to
# the end; the oldest entry is evicted when the cap is reached, and entries
# older than an hour are purged. Real defense is short ``exp`` + HMAC; this
# is best-effort and per-process.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
class MCPAuthError(Exception):
    """Signal that a bearer token could not be resolved to a usable identity.

    The exception message is human-readable; the auth middleware re-raises
    it as a ``PermissionError`` and uses it to derive the auth-failure
    metric label.
    """
# --- Opaque token path -----------------------------------------------------
def resolve_mcp_user(token_string: str):
    """Resolve an opaque bearer token to a ``(user, MCPToken)`` pair.

    The bearer is hashed with ``hash_token`` and the row is looked up by
    that hash, so the plaintext is never stored or compared directly. On
    success ``record_usage()`` is called on the token before returning.

    Raises:
        MCPAuthError: if no row matches the hash, the token is inactive or
            past ``expires_at``, or the owning user is inactive.
    """
    candidates = MCPToken.objects.select_related("user")
    try:
        token = candidates.get(token_hash=hash_token(token_string))
    except MCPToken.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    if not token.is_active:
        raise MCPAuthError("Token has been deactivated.")

    expiry = token.expires_at
    if expiry and expiry < timezone.now():
        raise MCPAuthError("Token has expired.")

    if not token.user.is_active:
        raise MCPAuthError("User account is disabled.")

    token.record_usage()
    return token.user, token
# --- JWT path --------------------------------------------------------------
def looks_like_jwt(token_string: str) -> bool:
    """Cheaply decide whether a bearer should take the JWT path.

    Returns True when *token_string* has exactly three dot-separated
    segments and the first one base64url-decodes to a JSON object whose
    ``alg`` field is a string. No signature, algorithm or claim validation
    happens here. Anything else (including undecodable input) returns
    False and is handled as an opaque token.
    """
    parts = token_string.split(".")
    if len(parts) != 3:
        return False
    try:
        header = json.loads(_b64url_decode(parts[0]))
    except (ValueError, RecursionError):
        # ValueError covers binascii.Error (bad base64), UnicodeDecodeError
        # and json.JSONDecodeError. RecursionError is raised by json.loads on
        # a deeply nested header; the bearer is untrusted input, so classify
        # it as "not a JWT" instead of letting the error escape the caller.
        return False
    return isinstance(header, dict) and isinstance(header.get("alg"), str)


def _b64url_decode(segment: str) -> bytes:
    """Decode one base64url JWT segment, restoring any stripped ``=`` padding."""
    pad = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + pad)
def _remember_jti(jti: str) -> bool:
    """Check *jti* against the per-process replay cache, recording it if new.

    Returns True when *jti* was already present (a replay). Otherwise it is
    stored with the current ``time.time()`` stamp and False is returned.
    Entries older than an hour are purged first, and when the cache holds
    ``_JTI_CACHE_MAX`` entries the oldest one is evicted to make room.
    """
    seen_at = time.time()
    # One hour of retention; this assumes Daedalus issues tokens with
    # exp <= 600s, which is not enforced here.
    oldest_allowed = seen_at - 3600
    # Insertion order tracks time, so stale entries sit at the front.
    while _JTI_CACHE:
        first_stamp = next(iter(_JTI_CACHE.values()))
        if first_stamp >= oldest_allowed:
            break
        _JTI_CACHE.popitem(last=False)

    if jti not in _JTI_CACHE:
        if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
            _JTI_CACHE.popitem(last=False)
        _JTI_CACHE[jti] = seen_at
        return False
    return True
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a signed JWT and return its normalized claims dict.

    The unverified header must name ``alg`` HS256 and a ``kid`` that maps to
    an active ``MCPSigningKey``; that key's hex-encoded secret then verifies
    the signature. ``exp``, ``iat``, ``iss``, ``sub`` and ``jti`` are
    required, ``iss`` must equal ``_JWT_ISS``, and a ``jti`` already seen by
    this process is rejected as a replay. Does not touch ``MCPToken`` —
    JWTs are stateless and stored only as their signing key
    (``MCPSigningKey``).

    Returns:
        The decoded claims, with ``ws`` coerced to ``None`` when falsy,
        ``libs`` defaulted to ``[]`` and checked to be a list of strings,
        and the header ``kid`` added under ``"kid"``.

    Raises:
        MCPAuthError: on any header, signing-key, signature, claim or
            replay failure.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}")
    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    # Reject early; decode() below also pins algorithms=["HS256"].
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")
    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")
    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError:
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.")
    if not secret:
        # bytes.fromhex("") is b"", and an empty HMAC key lets anyone who
        # knows the kid produce a "valid" signature. Treat it as a broken key.
        raise MCPAuthError(f"Stored secret for kid={kid!r} is empty.")
    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": ["exp", "iat", "iss", "sub", "jti"]},
            issuer=_JWT_ISS,
        )
    except pyjwt.ExpiredSignatureError:
        raise MCPAuthError("Token has expired.")
    except pyjwt.InvalidIssuerError:
        raise MCPAuthError("Invalid token issuer.")
    except pyjwt.InvalidSignatureError:
        raise MCPAuthError("Invalid MCP token.")
    except pyjwt.MissingRequiredClaimError as exc:
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}")
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Invalid JWT: {exc}")
    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")
    if _remember_jti(jti):
        raise MCPAuthError("Token replay detected.")
    # Normalize claim shapes: ws may be null/absent, libs default to [].
    claims["ws"] = claims.get("ws") or None
    libs = claims.get("libs") or []
    if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
        raise MCPAuthError("JWT libs must be a list of strings.")
    claims["libs"] = libs
    claims["kid"] = kid
    return claims
# --- Middleware ------------------------------------------------------------
class MCPAuthMiddleware(Middleware):
    """FastMCP middleware that authenticates tool calls via Bearer tokens.

    Only ``on_call_tool`` is intercepted, so listing tools/resources is
    permitted unauthenticated and clients can discover the surface; calling
    a tool requires a valid token unless ``MCP_REQUIRE_AUTH`` is False.

    A bearer that ``looks_like_jwt`` is validated with ``resolve_mcp_jwt``
    and mapped to the service user by ``_resolve_jwt_actor``; any other
    bearer is resolved as an opaque ``MCPToken`` via ``resolve_mcp_user``.
    The resolved user, opaque token and JWT claims (whichever apply) are
    stored on the FastMCP context state under ``STATE_KEY_USER``,
    ``STATE_KEY_TOKEN`` and ``STATE_KEY_CLAIMS``.
    """

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Authenticate the bearer, enforce tool permissions, then continue.

        Raises:
            PermissionError: when authentication is required and the token
                is missing or invalid, or when an opaque token is not
                allowed to call the requested tool.
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
        token_string = self._extract_token()
        user = None
        token = None
        claims: dict | None = None
        if token_string:
            try:
                if looks_like_jwt(token_string):
                    jwt_claims = await sync_to_async(
                        resolve_mcp_jwt, thread_sensitive=True
                    )(token_string)
                    user = await sync_to_async(
                        _resolve_jwt_actor, thread_sensitive=True
                    )(jwt_claims)
                    # Publish claims only once the actor resolved too, so a
                    # failed JWT (tolerated when auth is optional) never
                    # leaves scope claims on an otherwise anonymous request.
                    claims = jwt_claims
                else:
                    user, token = await sync_to_async(
                        resolve_mcp_user, thread_sensitive=True
                    )(token_string)
            except MCPAuthError as exc:
                mcp_auth_failures_total.labels(
                    reason=self._failure_reason(exc)
                ).inc()
                if require_auth:
                    raise PermissionError(str(exc))
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        tool_name = self._extract_tool_name(context)
        # JWT requests leave ``token`` as None, so per-tool restrictions apply
        # to opaque MCPToken rows only.
        if token and tool_name and not token.can_use_tool(tool_name):
            mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
            raise PermissionError(
                f"Token does not have permission to call '{tool_name}'."
            )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx is not None:
            if user is not None:
                await fastmcp_ctx.set_state(STATE_KEY_USER, user)
            if token is not None:
                await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
            if claims is not None:
                await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
        return await call_next(context)

    @staticmethod
    def _extract_token() -> str | None:
        """Return the bearer credential from the HTTP ``Authorization`` header.

        Returns None outside an HTTP request (``get_http_request`` raises
        ``RuntimeError``), or when the header is absent, uses another scheme
        or carries an empty credential. The scheme name is matched
        case-insensitively (RFC 7235 §2.1).
        """
        try:
            request = get_http_request()
        except RuntimeError:
            return None
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Return the name of the tool being called, or None if not found.

        Depending on how the hook is invoked, ``context.message`` may be a
        request wrapping the call parameters (``message.params.name``) or the
        call-tool parameters object itself (``message.name``). Both shapes
        are accepted so per-token tool restrictions are not silently skipped.
        """
        msg = getattr(context, "message", None)
        if not msg:
            return None
        params = getattr(msg, "params", None)
        if params is not None:
            return getattr(params, "name", None)
        return getattr(msg, "name", None)

    @staticmethod
    def _failure_reason(exc: Exception) -> str:
        """Return a bounded Prometheus label value for an auth failure.

        Static messages such as ``"Token has expired."`` pass through
        unchanged. Messages that interpolate request data (the JWT ``kid``,
        PyJWT error text) are cut at the first character that is not a
        letter, space, ``.`` or ``-``. This keeps attacker-chosen values out
        of the label, so they cannot create unbounded series cardinality.
        """
        prefix = []
        for char in str(exc):
            if not (char.isalpha() or char in " .-"):
                break
            prefix.append(char)
        return "".join(prefix).strip() or "invalid_token"
# --- Helpers ---------------------------------------------------------------
def _resolve_jwt_actor(claims: dict):
    """Return the synthetic service user that acts for a JWT-authenticated turn.

    The username comes from ``settings.MCP_JWT_SERVICE_USERNAME`` and
    defaults to ``"daedalus-service"``. JWT tokens are not tied to per-user
    accounts, so *claims* is accepted for the call signature but not
    consulted here.

    Raises:
        MCPAuthError: if the service user does not exist or is inactive.
    """
    from django.contrib.auth import get_user_model

    user_model = get_user_model()
    service_username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
    try:
        actor = user_model.objects.get(username=service_username)
    except user_model.DoesNotExist:
        raise MCPAuthError(
            f"JWT service user {service_username!r} does not exist; provision via management command."
        )
    if actor.is_active:
        return actor
    raise MCPAuthError(f"JWT service user {service_username!r} is disabled.")