feat(library): add workspace-scoped search and JWT auth for Daedalus
- Extend library list endpoint with `include_workspace` and `with_item_count` query params to support Daedalus registry mirroring - Expand search scope clause to three modes: workspace-only, workspace plus allowed user libraries, and global - Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims - Introduce JWT-based actor resolution using a synthetic service user (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
This commit is contained in:
@@ -1,31 +1,68 @@
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth.
|
||||
|
||||
Two token shapes are supported:
|
||||
|
||||
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
|
||||
used by the dashboard / Claude Desktop / admin tooling. Plaintext
|
||||
hashes to a row in `mcp_token`.
|
||||
* **Signed JWT** — per-turn token minted by Daedalus. Carries
|
||||
`{ws, libs}` claims. Validated entirely off the signature + claims;
|
||||
no database lookup of the token itself, only of the signing key
|
||||
(`MCPSigningKey`) referenced by the JWT header's `kid`.
|
||||
|
||||
Detection: a bearer with three base64url segments separated by dots and
|
||||
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
|
||||
else falls through to the opaque path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import jwt as pyjwt
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from fastmcp.server.dependencies import get_http_request
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
|
||||
from .metrics import mcp_auth_failures_total
|
||||
from .models import MCPToken, hash_token
|
||||
from .models import MCPSigningKey, MCPToken, hash_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STATE_KEY_USER = "mcp_user"
|
||||
STATE_KEY_TOKEN = "mcp_token"
|
||||
STATE_KEY_CLAIMS = "mcp_claims"
|
||||
|
||||
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
|
||||
# symmetrically as ``leeway``.
|
||||
_JWT_LEEWAY_SECONDS = 30
|
||||
|
||||
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
|
||||
_JWT_ISS = "daedalus"
|
||||
|
||||
# Bounded LRU of recently-seen jti values to discourage replay within
|
||||
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
|
||||
# is best-effort and per-process.
|
||||
_JTI_CACHE_MAX = 4096
|
||||
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
||||
|
||||
|
||||
class MCPAuthError(Exception):
|
||||
"""Raised when a bearer token cannot be resolved to a valid user."""
|
||||
|
||||
|
||||
# --- Opaque token path -----------------------------------------------------
|
||||
|
||||
|
||||
def resolve_mcp_user(token_string: str):
|
||||
"""Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure.
|
||||
"""Resolve an opaque bearer token to (user, MCPToken).
|
||||
|
||||
Hashes the incoming bearer and looks up by the hash — plaintext is never
|
||||
stored or compared directly.
|
||||
@@ -50,6 +87,110 @@ def resolve_mcp_user(token_string: str):
|
||||
return token.user, token
|
||||
|
||||
|
||||
# --- JWT path --------------------------------------------------------------
|
||||
|
||||
|
||||
def looks_like_jwt(token_string: str) -> bool:
|
||||
"""Cheap pre-check: 3 dot-separated segments and a valid base64url header
|
||||
that decodes to JSON with an alg field."""
|
||||
parts = token_string.split(".")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
try:
|
||||
header_bytes = _b64url_decode(parts[0])
|
||||
header = json.loads(header_bytes)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return False
|
||||
return isinstance(header, dict) and isinstance(header.get("alg"), str)
|
||||
|
||||
|
||||
def _b64url_decode(segment: str) -> bytes:
|
||||
pad = "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(segment + pad)
|
||||
|
||||
|
||||
def _remember_jti(jti: str) -> bool:
|
||||
"""Return True if this jti has been seen before in this process."""
|
||||
now = time.time()
|
||||
# Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s.
|
||||
cutoff = now - 3600
|
||||
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff:
|
||||
_JTI_CACHE.popitem(last=False)
|
||||
if jti in _JTI_CACHE:
|
||||
return True
|
||||
if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
|
||||
_JTI_CACHE.popitem(last=False)
|
||||
_JTI_CACHE[jti] = now
|
||||
return False
|
||||
|
||||
|
||||
def resolve_mcp_jwt(token_string: str) -> dict:
|
||||
"""Validate a signed JWT and return its claims dict.
|
||||
|
||||
Raises ``MCPAuthError`` on any failure. Does not touch ``MCPToken`` —
|
||||
JWTs are stateless and stored only as their signing key (``MCPSigningKey``).
|
||||
"""
|
||||
try:
|
||||
unverified_header = pyjwt.get_unverified_header(token_string)
|
||||
except pyjwt.PyJWTError as exc:
|
||||
raise MCPAuthError(f"Malformed JWT header: {exc}")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise MCPAuthError("JWT header missing kid.")
|
||||
if unverified_header.get("alg") != "HS256":
|
||||
raise MCPAuthError("Unsupported JWT alg.")
|
||||
|
||||
key = MCPSigningKey.objects.by_kid(kid)
|
||||
if key is None:
|
||||
raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
|
||||
if not key.is_active:
|
||||
raise MCPAuthError(f"Signing key {kid!r} has been retired.")
|
||||
|
||||
try:
|
||||
secret = bytes.fromhex(key.secret_hex)
|
||||
except ValueError:
|
||||
raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.")
|
||||
|
||||
try:
|
||||
claims = pyjwt.decode(
|
||||
token_string,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
leeway=_JWT_LEEWAY_SECONDS,
|
||||
options={"require": ["exp", "iat", "iss", "sub", "jti"]},
|
||||
issuer=_JWT_ISS,
|
||||
)
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
raise MCPAuthError("Token has expired.")
|
||||
except pyjwt.InvalidIssuerError:
|
||||
raise MCPAuthError("Invalid token issuer.")
|
||||
except pyjwt.InvalidSignatureError:
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
except pyjwt.MissingRequiredClaimError as exc:
|
||||
raise MCPAuthError(f"JWT missing required claim: {exc.claim}")
|
||||
except pyjwt.PyJWTError as exc:
|
||||
raise MCPAuthError(f"Invalid JWT: {exc}")
|
||||
|
||||
jti = claims.get("jti")
|
||||
if not isinstance(jti, str) or not jti:
|
||||
raise MCPAuthError("JWT jti must be a non-empty string.")
|
||||
if _remember_jti(jti):
|
||||
raise MCPAuthError("Token replay detected.")
|
||||
|
||||
# Normalize claim shapes: ws may be null/absent, libs default to [].
|
||||
claims["ws"] = claims.get("ws") or None
|
||||
libs = claims.get("libs") or []
|
||||
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
|
||||
raise MCPAuthError("JWT libs must be a list of strings.")
|
||||
claims["libs"] = libs
|
||||
claims["kid"] = kid
|
||||
return claims
|
||||
|
||||
|
||||
# --- Middleware ------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPAuthMiddleware(Middleware):
|
||||
"""
|
||||
FastMCP middleware that authenticates tool calls via Bearer tokens.
|
||||
@@ -59,20 +200,27 @@ class MCPAuthMiddleware(Middleware):
|
||||
MCP_REQUIRE_AUTH=False.
|
||||
"""
|
||||
|
||||
async def on_call_tool(
|
||||
self, context: MiddlewareContext, call_next
|
||||
):
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
|
||||
token_string = self._extract_token()
|
||||
|
||||
user = None
|
||||
token = None
|
||||
claims: dict | None = None
|
||||
|
||||
if token_string:
|
||||
try:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
if looks_like_jwt(token_string):
|
||||
claims = await sync_to_async(
|
||||
resolve_mcp_jwt, thread_sensitive=True
|
||||
)(token_string)
|
||||
user = await sync_to_async(
|
||||
_resolve_jwt_actor, thread_sensitive=True
|
||||
)(claims)
|
||||
else:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
except MCPAuthError as exc:
|
||||
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
||||
if require_auth:
|
||||
@@ -89,9 +237,13 @@ class MCPAuthMiddleware(Middleware):
|
||||
)
|
||||
|
||||
fastmcp_ctx = getattr(context, "fastmcp_context", None)
|
||||
if fastmcp_ctx and user is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
if fastmcp_ctx is not None:
|
||||
if user is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||
if token is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
if claims is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
|
||||
|
||||
return await call_next(context)
|
||||
|
||||
@@ -111,3 +263,28 @@ class MCPAuthMiddleware(Middleware):
|
||||
msg = getattr(context, "message", None)
|
||||
params = getattr(msg, "params", None) if msg else None
|
||||
return getattr(params, "name", None)
|
||||
|
||||
|
||||
# --- Helpers ---------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_jwt_actor(claims: dict):
|
||||
"""Resolve the synthetic actor for a JWT-authenticated turn.
|
||||
|
||||
Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default
|
||||
``daedalus-service``). The user must exist and be active. JWT tokens
|
||||
are not tied to per-user accounts — claims encode all authorization.
|
||||
"""
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
|
||||
try:
|
||||
user = User.objects.get(username=username)
|
||||
except User.DoesNotExist:
|
||||
raise MCPAuthError(
|
||||
f"JWT service user {username!r} does not exist; provision via management command."
|
||||
)
|
||||
if not user.is_active:
|
||||
raise MCPAuthError(f"JWT service user {username!r} is disabled.")
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user