docs: clarify Daedalus-Pallas integration auth model
Refine the phase-2 integration spec to reflect implementation details: - Change `resolved_libraries` from `set[str]` to ordered `list[str]` - Document `MCPToken.allowed_libraries` as JSONField (not M2M) since Library lives in Neo4j, not Django's ORM - Clarify that `Library.workspace_id` is a content-routing attribute, not an authorization axis - Describe retirement of the three-branch `_WORKSPACE_SCOPE_CLAUSE` in favor of a single `lib.uid IN $resolved_libraries` check - Specify team JWT resolution via `TeamWorkspaceAssignment` DB join - Note admin UI materializes full Library UID list explicitly
This commit is contained in:
@@ -1,27 +1,35 @@
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth.
|
||||
|
||||
Two token shapes are supported:
|
||||
Three credential types are accepted — see
|
||||
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2 for the full model:
|
||||
|
||||
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
|
||||
used by the dashboard / Claude Desktop / admin tooling. Plaintext
|
||||
hashes to a row in `mcp_token`.
|
||||
* **Signed JWT** — per-turn token minted by Daedalus. Carries
|
||||
`{ws, libs}` claims. Validated entirely off the signature + claims;
|
||||
no database lookup of the token itself, only of the signing key
|
||||
(`MCPSigningKey`) referenced by the JWT header's `kid`.
|
||||
1. **Opaque ``MCPToken``** (long-lived, hashed at rest). Authorization
|
||||
scope is its ``allowed_libraries`` JSON list.
|
||||
2. **Per-turn signed JWT** (``iss=daedalus``, ≤10 min, legacy — retires
|
||||
in Phase 4 when Daedalus chat itself becomes a Pallas Team). Scope
|
||||
is the ``libs`` claim.
|
||||
3. **Team JWT** (``iss=mnemosyne``, ``typ=team``, 10-year lifetime).
|
||||
Scope is resolved live by joining ``TeamWorkspaceAssignment`` rows
|
||||
to Neo4j ``Library.workspace_id``.
|
||||
|
||||
Every branch populates a single :data:`STATE_KEY_RESOLVED_LIBRARIES`
|
||||
value on the FastMCP context — a ``list[str]`` of Library UIDs the
|
||||
downstream tools are permitted to read. Tools never consult claim
|
||||
shapes; they read this list via
|
||||
``mcp_server.context.get_mcp_resolved_libraries``.
|
||||
|
||||
Detection: a bearer with three base64url segments separated by dots and
|
||||
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
|
||||
a parseable ``{"alg":"HS256","kid":...}`` header is treated as JWT; anything
|
||||
else falls through to the opaque path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
|
||||
import jwt as pyjwt
|
||||
@@ -32,20 +40,26 @@ from fastmcp.server.dependencies import get_http_request
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
|
||||
from .metrics import mcp_auth_failures_total
|
||||
from .models import MCPSigningKey, MCPToken, hash_token
|
||||
from .models import MCPSigningKey, MCPToken, Team, hash_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STATE_KEY_USER = "mcp_user"
|
||||
STATE_KEY_TOKEN = "mcp_token"
|
||||
STATE_KEY_CLAIMS = "mcp_claims"
|
||||
STATE_KEY_RESOLVED_LIBRARIES = "mcp_resolved_libraries"
|
||||
|
||||
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
|
||||
# symmetrically as ``leeway``.
|
||||
_JWT_LEEWAY_SECONDS = 30
|
||||
|
||||
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
|
||||
_JWT_ISS = "daedalus"
|
||||
# Accepted JWT issuers.
|
||||
#
|
||||
# ``daedalus`` — per-turn tokens minted by Daedalus chat (legacy path,
|
||||
# retires with Phase 4).
|
||||
# ``mnemosyne`` — team tokens minted by this service. ``typ=team``
|
||||
# distinguishes them from any future self-issued credential.
|
||||
_JWT_ISS_VALUES = {"daedalus", "mnemosyne"}
|
||||
|
||||
# Bounded LRU of recently-seen jti values to discourage replay within
|
||||
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
|
||||
@@ -59,6 +73,10 @@ _JWT_ISS = "daedalus"
|
||||
# ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check
|
||||
# would have already rejected, this is belt-and-braces for clock skew
|
||||
# or a resurrected captured token.
|
||||
#
|
||||
# Team tokens (``typ=team``) bypass this cache entirely — they are
|
||||
# reused on every request by design. Revocation for those tokens runs
|
||||
# against the live ``Team`` row (``active`` + ``active_jti``).
|
||||
_JTI_CACHE_MAX = 4096
|
||||
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
||||
|
||||
@@ -159,8 +177,22 @@ def _remember_jti(jti: str, exp: float) -> bool:
|
||||
def resolve_mcp_jwt(token_string: str) -> dict:
|
||||
"""Validate a signed JWT and return its claims dict.
|
||||
|
||||
Raises ``MCPAuthError`` on any failure. Does not touch ``MCPToken`` —
|
||||
JWTs are stateless and stored only as their signing key (``MCPSigningKey``).
|
||||
Accepts both the legacy per-turn issuer (``iss=daedalus``) and the
|
||||
new team issuer (``iss=mnemosyne``, ``typ=team``). The returned
|
||||
claims dict is normalized so the middleware doesn't have to guess:
|
||||
|
||||
* ``claims["iss"]`` — as presented (``daedalus`` or ``mnemosyne``).
|
||||
* ``claims["typ"]`` — ``"team"`` for team tokens, otherwise absent.
|
||||
* ``claims["libs"]`` — per-turn only; normalized to ``list[str]``.
|
||||
* ``claims["ws"]`` — per-turn only; may be ``None``. Not consulted
|
||||
for authorization (kept for diagnostics).
|
||||
* ``claims["team_id"]`` — team only; ``UUID`` parsed from
|
||||
``sub == "team:<uuid>"``.
|
||||
* ``claims["kid"]`` — copy of the JWT header's ``kid``.
|
||||
|
||||
Raises :class:`MCPAuthError` on any failure. The per-turn path runs
|
||||
the ``_remember_jti`` replay check; the team path skips it (team
|
||||
JWTs are intentionally reused across the token's lifetime).
|
||||
"""
|
||||
try:
|
||||
unverified_header = pyjwt.get_unverified_header(token_string)
|
||||
@@ -191,7 +223,9 @@ def resolve_mcp_jwt(token_string: str) -> dict:
|
||||
algorithms=["HS256"],
|
||||
leeway=_JWT_LEEWAY_SECONDS,
|
||||
options={"require": ["exp", "iat", "iss", "sub", "jti"]},
|
||||
issuer=_JWT_ISS,
|
||||
# ``issuer=`` accepts ``str | Iterable[str]`` and raises
|
||||
# ``InvalidIssuerError`` if the claim is outside the set.
|
||||
issuer=list(_JWT_ISS_VALUES),
|
||||
)
|
||||
except pyjwt.ExpiredSignatureError:
|
||||
raise MCPAuthError("Token has expired.")
|
||||
@@ -212,19 +246,79 @@ def resolve_mcp_jwt(token_string: str) -> dict:
|
||||
# ``require=["exp", ...]`` above guarantees presence + numeric; this
|
||||
# is defence in depth against future PyJWT changes.
|
||||
raise MCPAuthError("JWT exp must be numeric.")
|
||||
if _remember_jti(jti, float(exp)):
|
||||
raise MCPAuthError("Token replay detected.")
|
||||
|
||||
# Normalize claim shapes: ws may be null/absent, libs default to [].
|
||||
claims["ws"] = claims.get("ws") or None
|
||||
libs = claims.get("libs") or []
|
||||
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
|
||||
raise MCPAuthError("JWT libs must be a list of strings.")
|
||||
claims["libs"] = libs
|
||||
typ = claims.get("typ")
|
||||
|
||||
if typ == "team":
|
||||
# Team tokens: no replay cache, no ``libs`` or ``ws`` claims.
|
||||
# Verify the ``sub`` shape and parse the embedded team UUID so
|
||||
# the middleware doesn't have to re-parse it later.
|
||||
sub = claims.get("sub")
|
||||
if not isinstance(sub, str) or not sub.startswith("team:"):
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
try:
|
||||
claims["team_id"] = uuid.UUID(sub[len("team:"):])
|
||||
except ValueError:
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
else:
|
||||
# Per-turn (legacy) path: replay-cache gate + normalize claims.
|
||||
if _remember_jti(jti, float(exp)):
|
||||
raise MCPAuthError("Token replay detected.")
|
||||
|
||||
claims["ws"] = claims.get("ws") or None
|
||||
libs = claims.get("libs") or []
|
||||
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
|
||||
raise MCPAuthError("JWT libs must be a list of strings.")
|
||||
claims["libs"] = libs
|
||||
|
||||
claims["kid"] = kid
|
||||
return claims
|
||||
|
||||
|
||||
# --- Team-JWT library resolution ------------------------------------------
|
||||
|
||||
|
||||
def _libraries_for_team(team_id: uuid.UUID, jti: str) -> list[str]:
|
||||
"""Resolve a team token to the Library UIDs it may read.
|
||||
|
||||
Runs two cheap queries in sequence:
|
||||
|
||||
1. Fetch the ``Team`` row by UUID. Reject if it doesn't exist, is
|
||||
inactive, or its ``active_jti`` doesn't match the incoming
|
||||
``jti`` — this is how rotation / soft-delete revocation becomes
|
||||
effective on the *next* request.
|
||||
2. If active: read every ``TeamWorkspaceAssignment.workspace_id`` for
|
||||
the team and translate them into Library UIDs via a single
|
||||
Cypher query against Neo4j.
|
||||
|
||||
Returns an empty list when the team has no workspace assignments
|
||||
(fail-closed — a team pointing at no workspaces sees no libraries).
|
||||
"""
|
||||
try:
|
||||
team = Team.objects.get(pk=team_id)
|
||||
except Team.DoesNotExist:
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
|
||||
if not team.active:
|
||||
raise MCPAuthError("Token has been deactivated.")
|
||||
if team.active_jti is None or str(team.active_jti) != jti:
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
|
||||
workspace_ids = list(
|
||||
team.workspace_assignments.values_list("workspace_id", flat=True)
|
||||
)
|
||||
if not workspace_ids:
|
||||
return []
|
||||
|
||||
from neomodel import db
|
||||
|
||||
rows, _ = db.cypher_query(
|
||||
"MATCH (l:Library) WHERE l.workspace_id IN $ws RETURN l.uid",
|
||||
{"ws": workspace_ids},
|
||||
)
|
||||
return [row[0] for row in rows if row and row[0]]
|
||||
|
||||
|
||||
# --- Middleware ------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -234,7 +328,18 @@ class MCPAuthMiddleware(Middleware):
|
||||
|
||||
Listing tools/resources is permitted unauthenticated so clients can
|
||||
discover the surface; calling a tool requires a valid token unless
|
||||
MCP_REQUIRE_AUTH=False.
|
||||
``MCP_REQUIRE_AUTH=False``.
|
||||
|
||||
On every authenticated call the middleware attaches four values to
|
||||
the FastMCP ``Context`` state for downstream tools to consume via
|
||||
:mod:`mcp_server.context`:
|
||||
|
||||
* ``STATE_KEY_USER`` — Django user.
|
||||
* ``STATE_KEY_TOKEN`` — MCPToken row (opaque callers only).
|
||||
* ``STATE_KEY_CLAIMS`` — JWT claims dict (JWT callers only).
|
||||
* ``STATE_KEY_RESOLVED_LIBRARIES`` — authorization-resolved Library
|
||||
UID list. Tools read this; they never read ``STATE_KEY_CLAIMS``
|
||||
for authorization.
|
||||
"""
|
||||
|
||||
# Tools that don't touch user data and must be callable without a token
|
||||
@@ -261,6 +366,7 @@ class MCPAuthMiddleware(Middleware):
|
||||
user = None
|
||||
token = None
|
||||
claims: dict | None = None
|
||||
resolved_libraries: list[str] | None = None
|
||||
|
||||
if token_string:
|
||||
try:
|
||||
@@ -271,10 +377,16 @@ class MCPAuthMiddleware(Middleware):
|
||||
user = await sync_to_async(
|
||||
_resolve_jwt_actor, thread_sensitive=True
|
||||
)(claims)
|
||||
resolved_libraries = await sync_to_async(
|
||||
_resolved_libraries_for_jwt, thread_sensitive=True
|
||||
)(claims)
|
||||
else:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
# Opaque tokens store the Library UID list directly.
|
||||
# Empty list = fail-closed; not "everything".
|
||||
resolved_libraries = list(token.allowed_libraries or [])
|
||||
except MCPAuthError as exc:
|
||||
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
||||
if require_auth:
|
||||
@@ -283,7 +395,6 @@ class MCPAuthMiddleware(Middleware):
|
||||
mcp_auth_failures_total.labels(reason="missing_token").inc()
|
||||
raise PermissionError("Authentication required. Provide a Bearer token.")
|
||||
|
||||
tool_name = self._extract_tool_name(context)
|
||||
if token and tool_name and not token.can_use_tool(tool_name):
|
||||
mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
|
||||
raise PermissionError(
|
||||
@@ -298,6 +409,18 @@ class MCPAuthMiddleware(Middleware):
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
if claims is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
|
||||
# Always publish resolved_libraries — None means "no auth
|
||||
# information" and the tools treat that as fail-closed.
|
||||
await fastmcp_ctx.set_state(
|
||||
STATE_KEY_RESOLVED_LIBRARIES, resolved_libraries
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"mcp_auth.resolved tool=%s principal=%s lib_count=%s",
|
||||
tool_name,
|
||||
self._describe_principal(user, token, claims),
|
||||
"none" if resolved_libraries is None else len(resolved_libraries),
|
||||
)
|
||||
|
||||
return await self._call_next_with_trace(tool_name, call_next, context)
|
||||
|
||||
@@ -334,6 +457,20 @@ class MCPAuthMiddleware(Middleware):
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _describe_principal(user, token, claims) -> str:
|
||||
"""Compact, log-friendly principal summary. No PII beyond usernames."""
|
||||
if claims is not None:
|
||||
typ = claims.get("typ")
|
||||
if typ == "team":
|
||||
return f"team:{claims.get('team_id')}"
|
||||
return f"jwt:{claims.get('sub')}"
|
||||
if token is not None:
|
||||
return f"mcptoken:{token.get_masked_token()}"
|
||||
if user is not None:
|
||||
return f"user:{user.username}"
|
||||
return "anonymous"
|
||||
|
||||
@staticmethod
|
||||
def _extract_token() -> str | None:
|
||||
"""Pull the Bearer token off the current HTTP request, if any.
|
||||
@@ -412,6 +549,10 @@ def _resolve_jwt_actor(claims: dict):
|
||||
Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default
|
||||
``daedalus-service``). The user must exist and be active. JWT tokens
|
||||
are not tied to per-user accounts — claims encode all authorization.
|
||||
|
||||
Used for both per-turn and team JWTs. The service user is a hook for
|
||||
usage accounting (LLMUsage / search metrics) and for the audit trail;
|
||||
authorization does not depend on it.
|
||||
"""
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
@@ -426,3 +567,14 @@ def _resolve_jwt_actor(claims: dict):
|
||||
if not user.is_active:
|
||||
raise MCPAuthError(f"JWT service user {username!r} is disabled.")
|
||||
return user
|
||||
|
||||
|
||||
def _resolved_libraries_for_jwt(claims: dict) -> list[str]:
|
||||
"""Pick the right resolver branch for a validated JWT claims dict.
|
||||
|
||||
* ``typ == "team"`` → live lookup via :func:`_libraries_for_team`.
|
||||
* otherwise → legacy ``claims["libs"]`` (per-turn JWT).
|
||||
"""
|
||||
if claims.get("typ") == "team":
|
||||
return _libraries_for_team(claims["team_id"], claims["jti"])
|
||||
return list(claims.get("libs") or [])
|
||||
|
||||
Reference in New Issue
Block a user