docs: clarify Daedalus-Pallas integration auth model
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 51s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m27s

Refine the phase-2 integration spec to reflect implementation details:

- Change `resolved_libraries` from `set[str]` to ordered `list[str]`
- Document `MCPToken.allowed_libraries` as JSONField (not M2M) since
  Library lives in Neo4j, not Django's ORM
- Clarify that `Library.workspace_id` is a content-routing attribute,
  not an authorization axis
- Describe retirement of the three-branch `_WORKSPACE_SCOPE_CLAUSE` in
  favor of a single `lib.uid IN $resolved_libraries` check
- Specify team JWT resolution via `TeamWorkspaceAssignment` DB join
- Note admin UI materializes full Library UID list explicitly
This commit is contained in:
2026-05-10 11:59:44 -04:00
parent e9f6eeb1a3
commit 16fb7ff4dc
35 changed files with 1839 additions and 2035 deletions

View File

@@ -1,27 +1,35 @@
"""MCP token resolution and FastMCP middleware for bearer-token auth.
Two token shapes are supported:
Three credential types are accepted — see
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2 for the full model:
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
used by the dashboard / Claude Desktop / admin tooling. Plaintext
hashes to a row in `mcp_token`.
* **Signed JWT** — per-turn token minted by Daedalus. Carries
`{ws, libs}` claims. Validated entirely off the signature + claims;
no database lookup of the token itself, only of the signing key
(`MCPSigningKey`) referenced by the JWT header's `kid`.
1. **Opaque ``MCPToken``** (long-lived, hashed at rest). Authorization
scope is its ``allowed_libraries`` JSON list.
2. **Per-turn signed JWT** (``iss=daedalus``, ≤10 min, legacy — retires
in Phase 4 when Daedalus chat itself becomes a Pallas Team). Scope
is the ``libs`` claim.
3. **Team JWT** (``iss=mnemosyne``, ``typ=team``, 10-year lifetime).
Scope is resolved live by joining ``TeamWorkspaceAssignment`` rows
to Neo4j ``Library.workspace_id``.
Every branch populates a single :data:`STATE_KEY_RESOLVED_LIBRARIES`
value on the FastMCP context — a ``list[str]`` of Library UIDs the
downstream tools are permitted to read. Tools never consult claim
shapes; they read this list via
``mcp_server.context.get_mcp_resolved_libraries``.
Detection: a bearer with three base64url segments separated by dots and
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
a parseable ``{"alg":"HS256","kid":...}`` header is treated as JWT; anything
else falls through to the opaque path.
"""
from __future__ import annotations
import base64
import hashlib
import json
import logging
import time
import uuid
from collections import OrderedDict
import jwt as pyjwt
@@ -32,20 +40,26 @@ from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from .metrics import mcp_auth_failures_total
from .models import MCPSigningKey, MCPToken, hash_token
from .models import MCPSigningKey, MCPToken, Team, hash_token
logger = logging.getLogger(__name__)
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"
STATE_KEY_RESOLVED_LIBRARIES = "mcp_resolved_libraries"
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
# symmetrically as ``leeway``.
_JWT_LEEWAY_SECONDS = 30
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
_JWT_ISS = "daedalus"
# Accepted JWT issuers.
#
# ``daedalus`` — per-turn tokens minted by Daedalus chat (legacy path,
# retires with Phase 4).
# ``mnemosyne`` — team tokens minted by this service. ``typ=team``
# distinguishes them from any future self-issued credential.
_JWT_ISS_VALUES = {"daedalus", "mnemosyne"}
# Bounded LRU of recently-seen jti values to discourage replay within
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
@@ -59,6 +73,10 @@ _JWT_ISS = "daedalus"
# ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check
# would have already rejected, this is belt-and-braces for clock skew
# or a resurrected captured token.
#
# Team tokens (``typ=team``) bypass this cache entirely — they are
# reused on every request by design. Revocation for those tokens runs
# against the live ``Team`` row (``active`` + ``active_jti``).
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
@@ -159,8 +177,22 @@ def _remember_jti(jti: str, exp: float) -> bool:
def resolve_mcp_jwt(token_string: str) -> dict:
"""Validate a signed JWT and return its claims dict.
Raises ``MCPAuthError`` on any failure. Does not touch ``MCPToken`` —
JWTs are stateless and stored only as their signing key (``MCPSigningKey``).
Accepts both the legacy per-turn issuer (``iss=daedalus``) and the
new team issuer (``iss=mnemosyne``, ``typ=team``). The returned
claims dict is normalized so the middleware doesn't have to guess:
* ``claims["iss"]`` — as presented (``daedalus`` or ``mnemosyne``).
* ``claims["typ"]`` — ``"team"`` for team tokens, otherwise absent.
* ``claims["libs"]`` — per-turn only; normalized to ``list[str]``.
* ``claims["ws"]`` — per-turn only; may be ``None``. Not consulted
for authorization (kept for diagnostics).
* ``claims["team_id"]`` — team only; ``UUID`` parsed from
``sub == "team:<uuid>"``.
* ``claims["kid"]`` — copy of the JWT header's ``kid``.
Raises :class:`MCPAuthError` on any failure. The per-turn path runs
the ``_remember_jti`` replay check; the team path skips it (team
JWTs are intentionally reused across the token's lifetime).
"""
try:
unverified_header = pyjwt.get_unverified_header(token_string)
@@ -191,7 +223,9 @@ def resolve_mcp_jwt(token_string: str) -> dict:
algorithms=["HS256"],
leeway=_JWT_LEEWAY_SECONDS,
options={"require": ["exp", "iat", "iss", "sub", "jti"]},
issuer=_JWT_ISS,
# ``issuer=`` accepts ``str | Iterable[str]`` and raises
# ``InvalidIssuerError`` if the claim is outside the set.
issuer=list(_JWT_ISS_VALUES),
)
except pyjwt.ExpiredSignatureError:
raise MCPAuthError("Token has expired.")
@@ -212,19 +246,79 @@ def resolve_mcp_jwt(token_string: str) -> dict:
# ``require=["exp", ...]`` above guarantees presence + numeric; this
# is defence in depth against future PyJWT changes.
raise MCPAuthError("JWT exp must be numeric.")
if _remember_jti(jti, float(exp)):
raise MCPAuthError("Token replay detected.")
# Normalize claim shapes: ws may be null/absent, libs default to [].
claims["ws"] = claims.get("ws") or None
libs = claims.get("libs") or []
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
raise MCPAuthError("JWT libs must be a list of strings.")
claims["libs"] = libs
typ = claims.get("typ")
if typ == "team":
# Team tokens: no replay cache, no ``libs`` or ``ws`` claims.
# Verify the ``sub`` shape and parse the embedded team UUID so
# the middleware doesn't have to re-parse it later.
sub = claims.get("sub")
if not isinstance(sub, str) or not sub.startswith("team:"):
raise MCPAuthError("Invalid MCP token.")
try:
claims["team_id"] = uuid.UUID(sub[len("team:"):])
except ValueError:
raise MCPAuthError("Invalid MCP token.")
else:
# Per-turn (legacy) path: replay-cache gate + normalize claims.
if _remember_jti(jti, float(exp)):
raise MCPAuthError("Token replay detected.")
claims["ws"] = claims.get("ws") or None
libs = claims.get("libs") or []
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
raise MCPAuthError("JWT libs must be a list of strings.")
claims["libs"] = libs
claims["kid"] = kid
return claims
# --- Team-JWT library resolution ------------------------------------------
def _libraries_for_team(team_id: uuid.UUID, jti: str) -> list[str]:
"""Resolve a team token to the Library UIDs it may read.
Runs two cheap queries in sequence:
1. Fetch the ``Team`` row by UUID. Reject if it doesn't exist, is
inactive, or its ``active_jti`` doesn't match the incoming
``jti`` — this is how rotation / soft-delete revocation becomes
effective on the *next* request.
2. If active: read every ``TeamWorkspaceAssignment.workspace_id`` for
the team and translate them into Library UIDs via a single
Cypher query against Neo4j.
Returns an empty list when the team has no workspace assignments
(fail-closed — a team pointing at no workspaces sees no libraries).
"""
try:
team = Team.objects.get(pk=team_id)
except Team.DoesNotExist:
raise MCPAuthError("Invalid MCP token.")
if not team.active:
raise MCPAuthError("Token has been deactivated.")
if team.active_jti is None or str(team.active_jti) != jti:
raise MCPAuthError("Invalid MCP token.")
workspace_ids = list(
team.workspace_assignments.values_list("workspace_id", flat=True)
)
if not workspace_ids:
return []
from neomodel import db
rows, _ = db.cypher_query(
"MATCH (l:Library) WHERE l.workspace_id IN $ws RETURN l.uid",
{"ws": workspace_ids},
)
return [row[0] for row in rows if row and row[0]]
# --- Middleware ------------------------------------------------------------
@@ -234,7 +328,18 @@ class MCPAuthMiddleware(Middleware):
Listing tools/resources is permitted unauthenticated so clients can
discover the surface; calling a tool requires a valid token unless
MCP_REQUIRE_AUTH=False.
``MCP_REQUIRE_AUTH=False``.
On every authenticated call the middleware attaches four values to
the FastMCP ``Context`` state for downstream tools to consume via
:mod:`mcp_server.context`:
* ``STATE_KEY_USER`` — Django user.
* ``STATE_KEY_TOKEN`` — MCPToken row (opaque callers only).
* ``STATE_KEY_CLAIMS`` — JWT claims dict (JWT callers only).
* ``STATE_KEY_RESOLVED_LIBRARIES`` — authorization-resolved Library
UID list. Tools read this; they never read ``STATE_KEY_CLAIMS``
for authorization.
"""
# Tools that don't touch user data and must be callable without a token
@@ -261,6 +366,7 @@ class MCPAuthMiddleware(Middleware):
user = None
token = None
claims: dict | None = None
resolved_libraries: list[str] | None = None
if token_string:
try:
@@ -271,10 +377,16 @@ class MCPAuthMiddleware(Middleware):
user = await sync_to_async(
_resolve_jwt_actor, thread_sensitive=True
)(claims)
resolved_libraries = await sync_to_async(
_resolved_libraries_for_jwt, thread_sensitive=True
)(claims)
else:
user, token = await sync_to_async(
resolve_mcp_user, thread_sensitive=True
)(token_string)
# Opaque tokens store the Library UID list directly.
# Empty list = fail-closed; not "everything".
resolved_libraries = list(token.allowed_libraries or [])
except MCPAuthError as exc:
mcp_auth_failures_total.labels(reason=str(exc)).inc()
if require_auth:
@@ -283,7 +395,6 @@ class MCPAuthMiddleware(Middleware):
mcp_auth_failures_total.labels(reason="missing_token").inc()
raise PermissionError("Authentication required. Provide a Bearer token.")
tool_name = self._extract_tool_name(context)
if token and tool_name and not token.can_use_tool(tool_name):
mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
raise PermissionError(
@@ -298,6 +409,18 @@ class MCPAuthMiddleware(Middleware):
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
if claims is not None:
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
# Always publish resolved_libraries — None means "no auth
# information" and the tools treat that as fail-closed.
await fastmcp_ctx.set_state(
STATE_KEY_RESOLVED_LIBRARIES, resolved_libraries
)
logger.info(
"mcp_auth.resolved tool=%s principal=%s lib_count=%s",
tool_name,
self._describe_principal(user, token, claims),
"none" if resolved_libraries is None else len(resolved_libraries),
)
return await self._call_next_with_trace(tool_name, call_next, context)
@@ -334,6 +457,20 @@ class MCPAuthMiddleware(Middleware):
)
return result
@staticmethod
def _describe_principal(user, token, claims) -> str:
"""Compact, log-friendly principal summary. No PII beyond usernames."""
if claims is not None:
typ = claims.get("typ")
if typ == "team":
return f"team:{claims.get('team_id')}"
return f"jwt:{claims.get('sub')}"
if token is not None:
return f"mcptoken:{token.get_masked_token()}"
if user is not None:
return f"user:{user.username}"
return "anonymous"
@staticmethod
def _extract_token() -> str | None:
"""Pull the Bearer token off the current HTTP request, if any.
@@ -412,6 +549,10 @@ def _resolve_jwt_actor(claims: dict):
Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default
``daedalus-service``). The user must exist and be active. JWT tokens
are not tied to per-user accounts — claims encode all authorization.
Used for both per-turn and team JWTs. The service user is a hook for
usage accounting (LLMUsage / search metrics) and for the audit trail;
authorization does not depend on it.
"""
from django.contrib.auth import get_user_model
@@ -426,3 +567,14 @@ def _resolve_jwt_actor(claims: dict):
if not user.is_active:
raise MCPAuthError(f"JWT service user {username!r} is disabled.")
return user
def _resolved_libraries_for_jwt(claims: dict) -> list[str]:
"""Pick the right resolver branch for a validated JWT claims dict.
* ``typ == "team"`` → live lookup via :func:`_libraries_for_team`.
* otherwise → legacy ``claims["libs"]`` (per-turn JWT).
"""
if claims.get("typ") == "team":
return _libraries_for_team(claims["team_id"], claims["jti"])
return list(claims.get("libs") or [])