- Extend library list endpoint with `include_workspace` and `with_item_count` query params to support Daedalus registry mirroring
- Expand search scope clause to three modes: workspace-only, workspace plus allowed user libraries, and global
- Add `allowed_libraries` field to `SearchRequest` for Phase-2 JWT claims
- Introduce JWT-based actor resolution using a synthetic service user (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
291 lines
10 KiB
Python
291 lines
10 KiB
Python
"""MCP token resolution and FastMCP middleware for bearer-token auth.

Two token shapes are supported:

* **Opaque** — the original `MCPToken` row. Long-lived and hashed at rest;
  used by the dashboard, Claude Desktop and admin tooling. The presented
  plaintext is hashed and matched against a row in `mcp_token`.

* **Signed JWT** — per-turn token minted by Daedalus, carrying `ws` and
  `libs` claims. Validated from the signature and claims alone: the token
  itself is never stored or looked up, only the signing key
  (`MCPSigningKey`) named by the JWT header's `kid`.

Detection: a bearer with exactly three dot-separated segments whose first
segment base64url-decodes to a JSON object with a string `alg` member is
routed to the JWT path; anything else falls through to the opaque path.
The `alg == "HS256"` and `kid` requirements are enforced later, during JWT
validation, so a token that passes detection but fails them is rejected as
a JWT rather than retried as an opaque token.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import OrderedDict
|
|
|
|
import jwt as pyjwt
|
|
from asgiref.sync import sync_to_async
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
from fastmcp.server.dependencies import get_http_request
|
|
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
|
|
from .metrics import mcp_auth_failures_total
|
|
from .models import MCPSigningKey, MCPToken, hash_token
|
|
|
|
# Context-state keys under which MCPAuthMiddleware publishes the
# authenticated identity for downstream tools.
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"

logger = logging.getLogger(__name__)

# --- JWT validation parameters ---
# The only `iss` value accepted on signed tokens.
_JWT_ISS = "daedalus"
# Clock-skew tolerance (seconds) for exp/iat, handed to PyJWT as ``leeway``.
_JWT_LEEWAY_SECONDS = 30

# --- Per-process replay guard ---
# Maps jti -> time.time() at first sighting. Bounded and insertion-ordered:
# entries are never refreshed, so the oldest first-seen jti is evicted first.
# Best-effort only (every worker process holds its own copy); short token
# lifetimes plus the HMAC signature are the real defense against replay.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
|
|
|
|
|
class MCPAuthError(Exception):
    """Bearer-token authentication failure.

    Signals that a presented token could not be turned into an active,
    authorized user (unknown, deactivated, expired, replayed or otherwise
    invalid). The exception message states the reason in plain language.
    """
|
|
|
|
|
|
# --- Opaque token path -----------------------------------------------------
|
|
|
|
|
|
def resolve_mcp_user(token_string: str):
    """Map an opaque bearer token to ``(user, MCPToken)``.

    The presented plaintext is passed through ``hash_token`` and the row is
    fetched by that hash, so plaintext is never stored or compared directly.
    On success ``record_usage()`` is called on the token before returning.

    Raises:
        MCPAuthError: when no row has that hash, the token is deactivated or
            past its ``expires_at``, or the owning user is inactive.
    """
    lookup = MCPToken.objects.select_related("user")
    try:
        record = lookup.get(token_hash=hash_token(token_string))
    except MCPToken.DoesNotExist:
        raise MCPAuthError("Invalid MCP token.")

    if not record.is_active:
        raise MCPAuthError("Token has been deactivated.")

    expires = record.expires_at
    if expires and expires < timezone.now():
        raise MCPAuthError("Token has expired.")

    owner = record.user
    if not owner.is_active:
        raise MCPAuthError("User account is disabled.")

    record.record_usage()
    return owner, record
|
|
|
|
|
|
# --- JWT path --------------------------------------------------------------
|
|
|
|
|
|
def looks_like_jwt(token_string: str) -> bool:
    """Cheaply decide whether a bearer should be routed to the JWT path.

    Returns True when *token_string* has exactly three dot-separated segments
    and the first one base64url-decodes to a JSON object whose ``alg`` member
    is a string. This is only a routing heuristic: signature, ``kid`` and
    claim checks all happen later in ``resolve_mcp_jwt``.

    Never raises for malformed input: anything that cannot be decoded or
    parsed yields False, so the caller falls through to the opaque path.
    """
    parts = token_string.split(".")
    if len(parts) != 3:
        return False
    try:
        header = json.loads(_b64url_decode(parts[0]))
    except (ValueError, RecursionError):
        # ValueError covers binascii.Error (bad base64), UnicodeDecodeError
        # and json.JSONDecodeError. RecursionError: the header is untrusted
        # input, and a deeply nested JSON array would otherwise escape as an
        # unhandled exception instead of a clean "not a JWT".
        return False
    return isinstance(header, dict) and isinstance(header.get("alg"), str)


def _b64url_decode(segment: str) -> bytes:
    """Decode an unpadded base64url segment, restoring the stripped ``=`` padding."""
    pad = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + pad)
|
|
|
|
|
|
def _remember_jti(jti: str) -> bool:
    """Record *jti* in the per-process cache and report whether it was already there.

    Returns True for a jti already present (a replay) and False for one that
    was just added. Before the lookup, entries first seen more than an hour
    ago are pruned from the oldest end; if the cache is still at
    ``_JTI_CACHE_MAX`` when a new jti arrives, the oldest entry is evicted.
    """
    seen_at = time.time()
    # One-hour retention; the original note calls this generous for tokens
    # whose exp is at most 600 s.
    horizon = seen_at - 3600
    # New keys are appended, so (assuming the wall clock does not go
    # backwards) the stalest entries sit at the front.
    while _JTI_CACHE:
        oldest_seen = next(iter(_JTI_CACHE.values()))
        if oldest_seen >= horizon:
            break
        _JTI_CACHE.popitem(last=False)

    if jti in _JTI_CACHE:
        return True

    if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
        _JTI_CACHE.popitem(last=False)
    _JTI_CACHE[jti] = seen_at
    return False
|
|
|
|
|
|
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a signed JWT and return its normalized claims dict.

    Validation order:

    1. Parse the unverified header; require a string ``kid`` and
       ``alg == "HS256"``.
    2. Load the ``MCPSigningKey`` for that ``kid``; reject unknown or
       inactive keys and stored secrets that are not valid hex.
    3. Verify the HS256 signature and claims with PyJWT: ``exp``, ``iat``,
       ``iss``, ``sub`` and ``jti`` must be present, ``iss`` must equal
       ``_JWT_ISS``, and ``_JWT_LEEWAY_SECONDS`` of clock skew is allowed.
    4. Reject a ``jti`` this process has already seen (``_remember_jti``).
    5. Normalize ``ws`` (falsy -> None) and ``libs`` (falsy -> [], otherwise
       it must be a list of strings), and add the ``kid`` under ``"kid"``.

    Does not touch ``MCPToken``: JWTs are stateless, and the only stored
    state involved is the signing key.

    Error messages deliberately contain no request-controlled text. They
    omit the presented ``kid`` of an unknown key and PyJWT parser output;
    only the ``kid`` of a key that exists in the database is ever
    interpolated. The middleware echoes the message to the client and uses
    it as a Prometheus label, so interpolating attacker input would make
    the label set unbounded. The details are logged instead.

    Raises:
        MCPAuthError: on any validation failure.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        logger.info("Rejected JWT with malformed header: %s", exc)
        raise MCPAuthError(f"Malformed JWT header ({type(exc).__name__}).") from exc

    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    if not isinstance(kid, str):
        # PyJWT may already reject this in get_unverified_header; checking
        # here keeps non-string values out of the ORM lookup regardless.
        raise MCPAuthError("JWT kid must be a string.")
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        logger.info("Rejected JWT signed with unknown kid=%r", kid)
        raise MCPAuthError("Unknown signing key.")
    if not key.is_active:
        # kid matched a stored key here, so interpolating it is bounded.
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")

    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError:
        logger.error("MCPSigningKey kid=%r has a non-hex secret", kid)
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") from None

    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": ["exp", "iat", "iss", "sub", "jti"]},
            issuer=_JWT_ISS,
        )
    except pyjwt.ExpiredSignatureError as exc:
        raise MCPAuthError("Token has expired.") from exc
    except pyjwt.InvalidIssuerError as exc:
        raise MCPAuthError("Invalid token issuer.") from exc
    except pyjwt.InvalidSignatureError as exc:
        raise MCPAuthError("Invalid MCP token.") from exc
    except pyjwt.MissingRequiredClaimError as exc:
        # exc.claim is always one of the names in the "require" list above.
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}") from exc
    except pyjwt.PyJWTError as exc:
        logger.info("Rejected JWT kid=%r: %s", kid, exc)
        raise MCPAuthError(f"Invalid JWT ({type(exc).__name__}).") from exc

    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")
    # NOTE(review): this check runs once per tool call. If Daedalus reuses a
    # single token for several tool calls (e.g. every call in one turn), each
    # call after the first is rejected as a replay. Confirm tokens are minted
    # per call.
    if _remember_jti(jti):
        raise MCPAuthError("Token replay detected.")

    # Normalize claim shapes: ws may be null/absent, libs default to [].
    # NOTE(review): unlike libs, ws is not type-checked. Confirm downstream
    # scoping accepts whatever type Daedalus sends.
    claims["ws"] = claims.get("ws") or None
    libs = claims.get("libs") or []
    if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
        raise MCPAuthError("JWT libs must be a list of strings.")
    claims["libs"] = libs
    claims["kid"] = kid
    return claims
|
|
|
|
|
|
# --- Middleware ------------------------------------------------------------
|
|
|
|
|
|
class MCPAuthMiddleware(Middleware):
    """FastMCP middleware that authenticates tool calls via Bearer tokens.

    Only ``on_call_tool`` is overridden. Listing tools/resources is therefore
    not gated, so clients can discover the surface. Calling a tool requires a
    valid token unless ``settings.MCP_REQUIRE_AUTH`` is False. In that
    permissive mode, a missing or rejected token downgrades the call to
    anonymous instead of failing it.

    On success the resolved identity is stored in the FastMCP context state
    under ``STATE_KEY_USER`` (the Django user), ``STATE_KEY_TOKEN`` (the
    ``MCPToken`` row, opaque path only) and ``STATE_KEY_CLAIMS`` (the
    normalized JWT claims, JWT path only).
    """

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """Authenticate the bearer, enforce per-token tool scopes, then continue.

        Raises:
            PermissionError: when authentication is required and the token is
                missing or invalid, or when an opaque token is not allowed to
                call the requested tool.
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
        token_string = self._extract_token()

        user = None
        token = None
        claims: dict | None = None

        if token_string:
            try:
                # The resolvers query the Django ORM (MCPToken, MCPSigningKey,
                # User), so they are run through sync_to_async.
                if looks_like_jwt(token_string):
                    claims = await sync_to_async(
                        resolve_mcp_jwt, thread_sensitive=True
                    )(token_string)
                    user = await sync_to_async(
                        _resolve_jwt_actor, thread_sensitive=True
                    )(claims)
                else:
                    user, token = await sync_to_async(
                        resolve_mcp_user, thread_sensitive=True
                    )(token_string)
            except MCPAuthError as exc:
                # The exception text is used verbatim as the metric label, so
                # MCPAuthError messages must stay free of request-controlled
                # data to keep label cardinality bounded.
                mcp_auth_failures_total.labels(reason=str(exc)).inc()
                if require_auth:
                    raise PermissionError(str(exc)) from exc
                # Permissive mode: proceed anonymously. Drop anything a
                # partially successful attempt produced (e.g. JWT claims that
                # verified before the service-user lookup failed), so an
                # anonymous call never carries authorization context.
                user = token = claims = None
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        # Per-tool restrictions only apply to opaque MCPToken rows; on the JWT
        # path token is None and no tool-level check happens here.
        # NOTE(review): can_use_tool runs directly on the event loop. That is
        # fine only if it does not query the database; confirm.
        tool_name = self._extract_tool_name(context)
        if token and tool_name and not token.can_use_tool(tool_name):
            mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
            raise PermissionError(
                f"Token does not have permission to call '{tool_name}'."
            )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx is not None:
            if user is not None:
                await fastmcp_ctx.set_state(STATE_KEY_USER, user)
            if token is not None:
                await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
            if claims is not None:
                await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)

        return await call_next(context)

    @staticmethod
    def _extract_token() -> str | None:
        """Return the bearer credential from the current HTTP request, if any.

        Returns None in four cases:

        * ``get_http_request`` raises RuntimeError (no HTTP request in scope,
          presumably a non-HTTP transport).
        * The ``Authorization`` header is absent.
        * The header uses a scheme other than Bearer.
        * The credential is blank.

        The scheme name is matched case-insensitively, as auth-scheme tokens
        are case-insensitive (RFC 7235 §2.1).
        """
        try:
            request = get_http_request()
        except RuntimeError:
            return None
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Return ``context.message.params.name``, or None if any link is missing."""
        msg = getattr(context, "message", None)
        params = getattr(msg, "params", None) if msg else None
        return getattr(params, "name", None)
|
|
|
|
|
|
# --- Helpers ---------------------------------------------------------------
|
|
|
|
|
|
def _resolve_jwt_actor(claims: dict):
    """Return the Django user that acts for a JWT-authenticated call.

    Every JWT call runs as one shared service account. Its username is read
    from ``settings.MCP_JWT_SERVICE_USERNAME`` and defaults to
    ``"daedalus-service"``. JWTs are not tied to per-user accounts: by design
    the claims carry the authorization. *claims* is accepted for interface
    stability but is not consulted here.

    Raises:
        MCPAuthError: if the service user does not exist or is inactive.
    """
    from django.contrib.auth import get_user_model

    service_username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
    user_model = get_user_model()
    try:
        actor = user_model.objects.get(username=service_username)
    except user_model.DoesNotExist:
        raise MCPAuthError(
            f"JWT service user {service_username!r} does not exist; provision via management command."
        )
    if actor.is_active:
        return actor
    raise MCPAuthError(f"JWT service user {service_username!r} is disabled.")
|