feat(library): add workspace-scoped search and JWT auth for Daedalus
- Extend library list endpoint with `include_workspace` and `with_item_count` query params to support Daedalus registry mirroring
- Expand search scope clause to three modes: workspace-only, workspace plus allowed user libraries, and global
- Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims
- Introduce JWT-based actor resolution using a synthetic service user (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
This commit is contained in:
@@ -40,13 +40,44 @@ logger = logging.getLogger(__name__)
|
||||
@api_view(["GET", "POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def library_list_create(request):
|
||||
"""List all libraries or create a new one."""
|
||||
"""List all libraries or create a new one.
|
||||
|
||||
GET supports ``?include_workspace=false`` (default ``true``) to filter out
|
||||
libraries that belong to a Daedalus workspace — the Daedalus library
|
||||
registry uses this to mirror only the user-managed catalog.
|
||||
|
||||
GET supports ``?with_item_count=true`` (default ``false``) to attach
|
||||
a per-library ``item_count``. Off by default because the count is a
|
||||
Cypher aggregate; on for the Daedalus-side registry poll.
|
||||
"""
|
||||
from library.models import Library
|
||||
|
||||
if request.method == "GET":
|
||||
libraries = Library.nodes.order_by("name")
|
||||
serializer = LibrarySerializer(libraries, many=True)
|
||||
return Response(serializer.data)
|
||||
include_workspace = request.GET.get("include_workspace", "true").lower() != "false"
|
||||
with_count = request.GET.get("with_item_count", "false").lower() == "true"
|
||||
|
||||
if include_workspace:
|
||||
libraries = list(Library.nodes.order_by("name"))
|
||||
else:
|
||||
libraries = list(Library.nodes.filter(workspace_id__isnull=True).order_by("name"))
|
||||
|
||||
data = LibrarySerializer(libraries, many=True).data
|
||||
|
||||
if with_count and libraries:
|
||||
from neomodel import db
|
||||
|
||||
uids = [lib.uid for lib in libraries]
|
||||
rows, _ = db.cypher_query(
|
||||
"MATCH (l:Library) WHERE l.uid IN $uids "
|
||||
"OPTIONAL MATCH (l)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(i:Item) "
|
||||
"RETURN l.uid, count(i)",
|
||||
{"uids": uids},
|
||||
)
|
||||
counts = {uid: count for (uid, count) in rows}
|
||||
for entry in data:
|
||||
entry["item_count"] = counts.get(entry["uid"], 0)
|
||||
|
||||
return Response(data)
|
||||
|
||||
# POST — create
|
||||
serializer = LibrarySerializer(data=request.data)
|
||||
|
||||
@@ -26,14 +26,26 @@ from .fusion import ImageSearchResult, SearchCandidate, reciprocal_rank_fusion
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Workspace scoping clause appended to every search Cypher query.
|
||||
# Search-scope clause appended to every search Cypher query.
|
||||
#
|
||||
# A request with workspace_id set returns ONLY that workspace's content.
|
||||
# A request with workspace_id null returns ONLY global content (libraries
|
||||
# with no workspace_id). There is no third mode.
|
||||
# Three modes, picked structurally by which params are set:
|
||||
#
|
||||
# 1. ``workspace_id`` set, ``allowed_libraries`` empty → workspace-scoped.
|
||||
# Returns ONLY content from libraries whose workspace_id matches.
|
||||
# 2. ``workspace_id`` set + ``allowed_libraries`` non-empty → workspace
|
||||
# PLUS the listed user-managed libraries (typical Phase-2 chat turn).
|
||||
# 3. Both null → global. Returns ONLY libraries with no workspace_id
|
||||
# (legacy opaque-token callers / dashboard).
|
||||
#
|
||||
# When ``allowed_libraries`` is non-empty alone (no workspace_id), it
|
||||
# narrows results to those libraries.
|
||||
_WORKSPACE_SCOPE_CLAUSE = (
|
||||
" AND ($workspace_id IS NULL AND lib.workspace_id IS NULL OR "
|
||||
"lib.workspace_id = $workspace_id)"
|
||||
" AND ("
|
||||
"($workspace_id IS NOT NULL AND lib.workspace_id = $workspace_id) "
|
||||
"OR ($allowed_libraries IS NOT NULL AND lib.uid IN $allowed_libraries) "
|
||||
"OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
|
||||
" AND lib.workspace_id IS NULL)"
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
@@ -52,6 +64,10 @@ class SearchRequest:
|
||||
library_type: Optional[str] = None
|
||||
collection_uid: Optional[str] = None
|
||||
workspace_id: Optional[str] = None
|
||||
# Phase-2 token claim: user-managed libraries the caller may include
|
||||
# alongside their workspace's auto-library. Cypher uses ``IS NULL`` vs
|
||||
# non-empty list to gate the second branch of the scope clause.
|
||||
allowed_libraries: Optional[list[str]] = None
|
||||
search_types: list[str] = field(
|
||||
default_factory=lambda: ["vector", "fulltext", "graph"]
|
||||
)
|
||||
@@ -73,6 +89,11 @@ class SearchRequest:
|
||||
self.library_type = None
|
||||
if self.collection_uid == "":
|
||||
self.collection_uid = None
|
||||
# Empty list collapses to None so the Cypher branch reads
|
||||
# "$allowed_libraries IS NOT NULL" rather than "size > 0" — keeps
|
||||
# the parameter binding straightforward and the predicate sargable.
|
||||
if self.allowed_libraries is not None and len(self.allowed_libraries) == 0:
|
||||
self.allowed_libraries = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -300,6 +321,7 @@ class SearchService:
|
||||
"library_type": request.library_type,
|
||||
"collection_uid": request.collection_uid,
|
||||
"workspace_id": request.workspace_id,
|
||||
"allowed_libraries": request.allowed_libraries,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -411,6 +433,7 @@ class SearchService:
|
||||
"library_type": request.library_type,
|
||||
"collection_uid": request.collection_uid,
|
||||
"workspace_id": request.workspace_id,
|
||||
"allowed_libraries": request.allowed_libraries,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -471,6 +494,7 @@ class SearchService:
|
||||
"library_uid": request.library_uid,
|
||||
"library_type": request.library_type,
|
||||
"workspace_id": request.workspace_id,
|
||||
"allowed_libraries": request.allowed_libraries,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -546,6 +570,7 @@ class SearchService:
|
||||
"library_uid": request.library_uid,
|
||||
"library_type": request.library_type,
|
||||
"workspace_id": request.workspace_id,
|
||||
"allowed_libraries": request.allowed_libraries,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -630,6 +655,7 @@ class SearchService:
|
||||
"library_uid": request.library_uid,
|
||||
"library_type": request.library_type,
|
||||
"workspace_id": request.workspace_id,
|
||||
"allowed_libraries": request.allowed_libraries,
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,31 +1,68 @@
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth.
|
||||
|
||||
Two token shapes are supported:
|
||||
|
||||
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
|
||||
used by the dashboard / Claude Desktop / admin tooling. Plaintext
|
||||
hashes to a row in `mcp_token`.
|
||||
* **Signed JWT** — per-turn token minted by Daedalus. Carries
|
||||
`{ws, libs}` claims. Validated entirely off the signature + claims;
|
||||
no database lookup of the token itself, only of the signing key
|
||||
(`MCPSigningKey`) referenced by the JWT header's `kid`.
|
||||
|
||||
Detection: a bearer with three base64url segments separated by dots and
|
||||
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
|
||||
else falls through to the opaque path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import jwt as pyjwt
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from fastmcp.server.dependencies import get_http_request
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
|
||||
from .metrics import mcp_auth_failures_total
|
||||
from .models import MCPToken, hash_token
|
||||
from .models import MCPSigningKey, MCPToken, hash_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STATE_KEY_USER = "mcp_user"
|
||||
STATE_KEY_TOKEN = "mcp_token"
|
||||
STATE_KEY_CLAIMS = "mcp_claims"
|
||||
|
||||
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
|
||||
# symmetrically as ``leeway``.
|
||||
_JWT_LEEWAY_SECONDS = 30
|
||||
|
||||
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
|
||||
_JWT_ISS = "daedalus"
|
||||
|
||||
# Bounded LRU of recently-seen jti values to discourage replay within
|
||||
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
|
||||
# is best-effort and per-process.
|
||||
_JTI_CACHE_MAX = 4096
|
||||
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
||||
|
||||
|
||||
class MCPAuthError(Exception):
|
||||
"""Raised when a bearer token cannot be resolved to a valid user."""
|
||||
|
||||
|
||||
# --- Opaque token path -----------------------------------------------------
|
||||
|
||||
|
||||
def resolve_mcp_user(token_string: str):
|
||||
"""Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure.
|
||||
"""Resolve an opaque bearer token to (user, MCPToken).
|
||||
|
||||
Hashes the incoming bearer and looks up by the hash — plaintext is never
|
||||
stored or compared directly.
|
||||
@@ -50,6 +87,110 @@ def resolve_mcp_user(token_string: str):
|
||||
return token.user, token
|
||||
|
||||
|
||||
# --- JWT path --------------------------------------------------------------
|
||||
|
||||
|
||||
def looks_like_jwt(token_string: str) -> bool:
    """Report whether a bearer string is shaped like a JWT.

    This is a structural sniff test only, not validation: the string must
    split into exactly three dot-separated segments, and the first segment
    must base64url-decode to a JSON object that carries a string ``alg``
    field. No signature or claim is inspected here.
    """
    segments = token_string.split(".")
    if len(segments) != 3:
        return False

    encoded_header = segments[0]
    # Restore any "=" padding that was left off the segment before decoding.
    padded_header = encoded_header + "=" * (-len(encoded_header) % 4)
    try:
        header = json.loads(base64.urlsafe_b64decode(padded_header))
    except (ValueError, json.JSONDecodeError):
        return False

    if not isinstance(header, dict):
        return False
    return isinstance(header.get("alg"), str)
|
||||
|
||||
|
||||
def _b64url_decode(segment: str) -> bytes:
    """Decode a base64url segment whose trailing ``=`` padding may be absent."""
    missing_padding = -len(segment) % 4
    padded_segment = segment + "=" * missing_padding
    return base64.urlsafe_b64decode(padded_segment)
|
||||
|
||||
|
||||
def _remember_jti(jti: str) -> bool:
    """Record ``jti`` in the per-process cache and report prior sightings.

    Returns True when ``jti`` is already cached (a replay seen by this
    process) and False when it is new, in which case it is stored with the
    current time. Entries older than one hour are pruned first, and the
    oldest entry is evicted when the cache holds ``_JTI_CACHE_MAX`` items.
    """
    seen_at = time.time()
    stale_before = seen_at - 3600  # one hour of retention

    # This function only ever appends new keys, so the front of the
    # OrderedDict holds the oldest timestamp; prune from there until a
    # still-fresh entry is reached.
    while _JTI_CACHE:
        oldest_jti = next(iter(_JTI_CACHE))
        if _JTI_CACHE[oldest_jti] >= stale_before:
            break
        del _JTI_CACHE[oldest_jti]

    if jti in _JTI_CACHE:
        return True

    # At capacity: drop the oldest entry to make room for the new one.
    if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
        _JTI_CACHE.popitem(last=False)
    _JTI_CACHE[jti] = seen_at
    return False
|
||||
|
||||
|
||||
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a signed JWT and return its normalized claims dict.

    The ``kid`` in the token header selects an ``MCPSigningKey`` row; the
    signature is then verified with HS256, and the ``exp``, ``iat``, ``iss``,
    ``sub`` and ``jti`` claims are required. ``MCPToken`` is never consulted.

    Args:
        token_string: The raw bearer value.

    Returns:
        The decoded claims, with ``ws`` normalized to a non-empty string or
        ``None``, ``libs`` normalized to a list of strings (``[]`` when
        absent), and the header's ``kid`` copied in under ``"kid"``.

    Raises:
        MCPAuthError: On any failure — malformed header, missing, unknown or
            retired ``kid``, unsupported ``alg``, unusable stored secret,
            failed signature or claim validation, malformed ``ws``/``libs``,
            or a ``jti`` already seen by this process.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}") from exc

    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    # Refuse other algorithms before any key lookup; decode() below is also
    # pinned to HS256.
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    # NOTE(review): ``kid`` comes from the unverified header and is embedded
    # in the messages below. If callers use the message text as a metric
    # label, this is unbounded cardinality — confirm and consider static text.
    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")

    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError as exc:
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") from exc

    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": ["exp", "iat", "iss", "sub", "jti"]},
            issuer=_JWT_ISS,
        )
    except pyjwt.ExpiredSignatureError as exc:
        raise MCPAuthError("Token has expired.") from exc
    except pyjwt.InvalidIssuerError as exc:
        raise MCPAuthError("Invalid token issuer.") from exc
    except pyjwt.InvalidSignatureError as exc:
        raise MCPAuthError("Invalid MCP token.") from exc
    except pyjwt.MissingRequiredClaimError as exc:
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}") from exc
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Invalid JWT: {exc}") from exc

    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")

    # Normalize claim shapes: ws may be null/absent/empty, libs defaults to [].
    # Both end up as query parameters downstream, so their types are checked
    # here rather than trusted.
    ws = claims.get("ws") or None
    if ws is not None and not isinstance(ws, str):
        raise MCPAuthError("JWT ws must be a string.")

    libs = claims.get("libs")
    if libs is None:
        libs = []
    if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
        raise MCPAuthError("JWT libs must be a list of strings.")

    # Record the jti only after the claims are known to be well-formed.
    # NOTE(review): this rejects every reuse of a token within one process.
    # If one per-turn token is presented on several tool calls, the second
    # call fails here — confirm the issuer mints one token per call.
    if _remember_jti(jti):
        raise MCPAuthError("Token replay detected.")

    claims["ws"] = ws
    claims["libs"] = libs
    claims["kid"] = kid
    return claims
|
||||
|
||||
|
||||
# --- Middleware ------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPAuthMiddleware(Middleware):
|
||||
"""
|
||||
FastMCP middleware that authenticates tool calls via Bearer tokens.
|
||||
@@ -59,20 +200,27 @@ class MCPAuthMiddleware(Middleware):
|
||||
MCP_REQUIRE_AUTH=False.
|
||||
"""
|
||||
|
||||
async def on_call_tool(
|
||||
self, context: MiddlewareContext, call_next
|
||||
):
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
|
||||
token_string = self._extract_token()
|
||||
|
||||
user = None
|
||||
token = None
|
||||
claims: dict | None = None
|
||||
|
||||
if token_string:
|
||||
try:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
if looks_like_jwt(token_string):
|
||||
claims = await sync_to_async(
|
||||
resolve_mcp_jwt, thread_sensitive=True
|
||||
)(token_string)
|
||||
user = await sync_to_async(
|
||||
_resolve_jwt_actor, thread_sensitive=True
|
||||
)(claims)
|
||||
else:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
except MCPAuthError as exc:
|
||||
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
||||
if require_auth:
|
||||
@@ -89,9 +237,13 @@ class MCPAuthMiddleware(Middleware):
|
||||
)
|
||||
|
||||
fastmcp_ctx = getattr(context, "fastmcp_context", None)
|
||||
if fastmcp_ctx and user is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
if fastmcp_ctx is not None:
|
||||
if user is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||
if token is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
if claims is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
|
||||
|
||||
return await call_next(context)
|
||||
|
||||
@@ -111,3 +263,28 @@ class MCPAuthMiddleware(Middleware):
|
||||
msg = getattr(context, "message", None)
|
||||
params = getattr(msg, "params", None) if msg else None
|
||||
return getattr(params, "name", None)
|
||||
|
||||
|
||||
# --- Helpers ---------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_jwt_actor(claims: dict):
    """Resolve the synthetic actor for a JWT-authenticated turn.

    Looks up the service user named by ``settings.MCP_JWT_SERVICE_USERNAME``
    (default ``daedalus-service``). JWT tokens are not tied to per-user
    accounts, so the same user is returned for every token.

    Args:
        claims: The validated claims dict. It is not read here; every JWT
            resolves to the same service user.

    Returns:
        The active service user instance.

    Raises:
        MCPAuthError: If the service user does not exist or is inactive.
    """
    from django.contrib.auth import get_user_model

    User = get_user_model()
    username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
    try:
        user = User.objects.get(username=username)
    except User.DoesNotExist as exc:
        # Chain the lookup failure so the traceback shows the real cause.
        raise MCPAuthError(
            f"JWT service user {username!r} does not exist; provision via management command."
        ) from exc
    if not user.is_active:
        raise MCPAuthError(f"JWT service user {username!r} is disabled.")
    return user
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from fastmcp.server.context import Context
|
||||
|
||||
from .auth import STATE_KEY_TOKEN, STATE_KEY_USER
|
||||
from .auth import STATE_KEY_CLAIMS, STATE_KEY_TOKEN, STATE_KEY_USER
|
||||
|
||||
|
||||
async def get_mcp_user(ctx: Context | None):
|
||||
@@ -17,3 +17,10 @@ async def get_mcp_token(ctx: Context | None):
|
||||
if ctx is None:
|
||||
return None
|
||||
return await ctx.get_state(STATE_KEY_TOKEN)
|
||||
|
||||
|
||||
async def get_mcp_claims(ctx: Context | None) -> dict | None:
    """Fetch the JWT claims stored on the request context.

    Returns ``None`` when no context is supplied; otherwise returns whatever
    is stored under ``STATE_KEY_CLAIMS``.
    """
    return None if ctx is None else await ctx.get_state(STATE_KEY_CLAIMS)
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Idempotently create the service user that JWT-authenticated MCP requests act as.
|
||||
|
||||
Daedalus mints per-turn JWTs whose claims encode all authorization (workspace,
|
||||
allowed libraries). The Django ``user`` field on the request still needs to
|
||||
point at *something* — the service user is that something. It owns no data
|
||||
and does not log in via the dashboard.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
    """Create, or repair, the Django user that JWT-authenticated requests act as."""

    help = "Idempotently create or reactivate the JWT service user (default 'daedalus-service')."

    def add_arguments(self, parser):
        """Register ``--username`` and ``--email``.

        The username default follows ``settings.MCP_JWT_SERVICE_USERNAME``
        when that setting is defined, so a plain invocation provisions the
        user named by the setting instead of a hard-coded one.
        """
        from django.conf import settings

        parser.add_argument(
            "--username",
            default=getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service"),
        )
        parser.add_argument("--email", default="daedalus-service@local")

    def handle(self, *args, **options):
        """Create the user if missing; otherwise reactivate it and sync its email."""
        User = get_user_model()
        username = options["username"]
        email = options["email"]

        user, created = User.objects.get_or_create(
            username=username,
            defaults={"email": email, "is_active": True},
        )
        if created:
            # Set a random password the user cannot log in with via the UI.
            user.set_password(secrets.token_urlsafe(32))
            user.save(update_fields=["password"])
            self.stdout.write(self.style.SUCCESS(f"Created service user {username!r}"))
            return

        # Existing user: repair only the fields that drifted.
        changed_fields = []
        if not user.is_active:
            user.is_active = True
            changed_fields.append("is_active")
        if user.email != email:
            user.email = email
            changed_fields.append("email")

        if changed_fields:
            user.save(update_fields=changed_fields)
            self.stdout.write(self.style.SUCCESS(f"Updated service user {username!r}"))
        else:
            self.stdout.write(f"Service user {username!r} already provisioned")
|
||||
71
mnemosyne/mcp_server/management/commands/seed_signing_key.py
Normal file
71
mnemosyne/mcp_server/management/commands/seed_signing_key.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Mint a new MCPSigningKey for per-turn JWTs.
|
||||
|
||||
The secret is generated locally, printed once, and stored hex-encoded.
|
||||
Distribute the printed secret to the issuer (Daedalus) — after this
|
||||
command exits, the database row is the only copy on the Mnemosyne
|
||||
side, but copying the secret out of the row is allowed (admin model
|
||||
exposes it). Treat it like any other shared secret.
|
||||
|
||||
Use ``--retire-other`` to flip every other key to ``is_active=False``
|
||||
in the same transaction (typical rotation flow). Without it, the new
|
||||
key joins the active set and old keys keep validating until they are
|
||||
retired explicitly.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.db import transaction
|
||||
|
||||
from mcp_server.models import MCPSigningKey
|
||||
|
||||
|
||||
class Command(BaseCommand):
    """Management command that creates one ``MCPSigningKey`` and prints its secret."""

    help = "Generate a new HMAC signing key for per-turn JWTs and print it once."

    def add_arguments(self, parser):
        """Register ``--kid`` (required), ``--note`` and ``--retire-other``."""
        parser.add_argument(
            "--kid",
            required=True,
            help="Key ID (short, stable string — e.g. 'daedalus-1').",
        )
        parser.add_argument(
            "--note",
            default="",
            help="Optional note (e.g. 'rotated 2026-05-03 after staging incident').",
        )
        parser.add_argument(
            "--retire-other",
            action="store_true",
            help=(
                "Mark every other MCPSigningKey row inactive in the same "
                "transaction. Use during rotation."
            ),
        )

    @transaction.atomic
    def handle(self, *args, **options):
        """Validate the kid, optionally retire active keys, then create and print the key.

        Raises:
            CommandError: If the kid is empty, longer than the ``kid`` column
                allows, or already in use.
        """
        kid = options["kid"]

        # An empty kid would create a row no token can ever select, and an
        # over-long one would not fit the column; fail before writing anything.
        if not kid:
            raise CommandError("kid must be a non-empty string.")
        max_kid_length = MCPSigningKey._meta.get_field("kid").max_length
        if max_kid_length is not None and len(kid) > max_kid_length:
            raise CommandError(
                f"kid {kid!r} is longer than {max_kid_length} characters."
            )
        if MCPSigningKey.objects.filter(kid=kid).exists():
            raise CommandError(f"kid {kid!r} already exists; choose a unique value.")

        # 256-bit secret, hex-encoded for portability.
        secret = secrets.token_hex(32)

        if options["retire_other"]:
            from django.utils import timezone

            # Runs before the create below, so the new key is not affected.
            MCPSigningKey.objects.filter(is_active=True).update(
                is_active=False, retired_at=timezone.now()
            )

        key = MCPSigningKey.objects.create(
            kid=kid, secret_hex=secret, is_active=True, note=options["note"]
        )

        self.stdout.write(self.style.SUCCESS("MCP signing key created"))
        self.stdout.write(f"  kid:        {key.kid}")
        self.stdout.write(f"  active:     {key.is_active}")
        if options["retire_other"]:
            self.stdout.write("  (other active keys retired)")
        self.stdout.write(self.style.WARNING("  Secret (256-bit hex — share with Daedalus):"))
        self.stdout.write(f"    {secret}")
|
||||
38
mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py
Normal file
38
mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""HMAC signing keys for per-turn JWTs minted by Daedalus.
|
||||
|
||||
Adds the MCPSigningKey table. Per-turn tokens (workspace + library claims,
|
||||
exp <= 600s) are not stored — only the signing key, indexed by ``kid``,
|
||||
so the signature can be validated and rotated cleanly.
|
||||
"""
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Create the ``MCPSigningKey`` table."""

    dependencies = [("mcp_server", "0002_hash_token")]

    operations = [
        migrations.CreateModel(
            name="MCPSigningKey",
            options={"ordering": ["-created_at"]},
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        verbose_name="ID",
                        serialize=False,
                        auto_created=True,
                        primary_key=True,
                    ),
                ),
                # Key identifier looked up from the JWT header.
                ("kid", models.CharField(max_length=64, unique=True, db_index=True)),
                ("secret_hex", models.CharField(max_length=128)),
                ("is_active", models.BooleanField(default=True)),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("retired_at", models.DateTimeField(blank=True, null=True)),
                ("note", models.TextField(blank=True)),
            ],
        ),
    ]
|
||||
@@ -84,3 +84,46 @@ class MCPToken(models.Model):
|
||||
hash prefixed with `mcp_…`. Stable per token, never reveals plaintext.
|
||||
"""
|
||||
return f"mcp_…{self.token_hash[:8]}"
|
||||
|
||||
|
||||
class MCPSigningKeyManager(models.Manager):
    """Query helpers for ``MCPSigningKey`` rows."""

    def active(self):
        """Return keys flagged active, most recently created first.

        More than one key may be active at once while a rotation is underway.
        """
        active_keys = self.filter(is_active=True)
        return active_keys.order_by("-created_at")

    def by_kid(self, kid: str):
        """Return the first key whose ``kid`` matches, or ``None`` if there is none."""
        matches = self.filter(kid=kid)
        return matches.first()
|
||||
|
||||
|
||||
class MCPSigningKey(models.Model):
    """Persisted HMAC key used to verify the per-turn JWTs minted by Daedalus.

    Individual tokens are never stored: they carry workspace and library
    claims, expire within minutes, and are checked from signature and claims
    alone. What lives here is the signing key, addressed by its ``kid``.

    Rotation flow: create a new active key, hand its secret to Daedalus, then
    mark the previous key ``is_active=False``.
    """

    kid = models.CharField(max_length=64, unique=True, db_index=True)
    secret_hex = models.CharField(max_length=128)  # 256-bit secret = 64 hex
    is_active = models.BooleanField(default=True)
    created_at = models.DateTimeField(auto_now_add=True)
    retired_at = models.DateTimeField(null=True, blank=True)
    note = models.TextField(blank=True)

    objects = MCPSigningKeyManager()

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        state = "retired" if not self.is_active else "active"
        return f"{self.kid} ({state})"

    def retire(self):
        """Mark this key inactive, stamp ``retired_at`` with now, and save both fields."""
        self.retired_at = timezone.now()
        self.is_active = False
        self.save(update_fields=["is_active", "retired_at"])
|
||||
|
||||
@@ -5,13 +5,29 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastmcp.server.context import Context
|
||||
|
||||
from ..context import get_mcp_claims
|
||||
from ..metrics import record_tool_call
|
||||
|
||||
DEFAULT_LIMIT = 50
|
||||
MAX_LIMIT = 200
|
||||
|
||||
|
||||
def _scope_from_claims(claims: dict | None,
                       arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
    """Pick the ``(workspace_id, allowed_libraries)`` pair for a tool call.

    When a claims dict is present it wins outright and the tool argument is
    ignored — that is the security contract. Without claims (opaque-token
    callers) the caller-supplied ``arg_workspace_id`` is passed through and
    no library allow-list applies.

    An empty or missing ``libs`` claim is returned as ``None``.

    NOTE(review): a claims dict with neither ``ws`` nor ``libs`` yields
    ``(None, None)``, which the scope clause in this module treats as global
    scope — confirm that is intended for JWT callers.
    """
    if claims is None:
        return arg_workspace_id, None

    allowed = claims.get("libs")
    if not allowed:
        allowed = None
    return claims.get("ws"), allowed
|
||||
|
||||
|
||||
def _clamp(limit: int) -> int:
|
||||
if limit < 1:
|
||||
return 1
|
||||
@@ -25,6 +41,7 @@ def register_discovery_tools(mcp):
|
||||
offset: int = 0,
|
||||
# System-injected; deliberately absent from the docstring.
|
||||
workspace_id: str | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""List Mnemosyne libraries. Each library has a content-aware library_type
|
||||
(fiction, nonfiction, technical, music, film, art, journal, business,
|
||||
@@ -32,9 +49,11 @@ def register_discovery_tools(mcp):
|
||||
name, library_type, description for each library — use the uid or
|
||||
library_type to scope a subsequent search.
|
||||
"""
|
||||
claims = await get_mcp_claims(ctx)
|
||||
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||
with record_tool_call("list_libraries"):
|
||||
return await sync_to_async(_query_libraries, thread_sensitive=True)(
|
||||
_clamp(limit), max(offset, 0), workspace_id
|
||||
_clamp(limit), max(offset, 0), ws, libs
|
||||
)
|
||||
|
||||
@mcp.tool
|
||||
@@ -44,15 +63,18 @@ def register_discovery_tools(mcp):
|
||||
offset: int = 0,
|
||||
# System-injected; deliberately absent from the docstring.
|
||||
workspace_id: str | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""List collections, optionally filtered by parent library_uid.
|
||||
Collections group related items inside a library (e.g. a series of novels,
|
||||
a multi-volume manual). Returns uid, name, description, library_uid,
|
||||
library_name. Use the uid to scope a subsequent search to one collection.
|
||||
"""
|
||||
claims = await get_mcp_claims(ctx)
|
||||
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||
with record_tool_call("list_collections"):
|
||||
return await sync_to_async(_query_collections, thread_sensitive=True)(
|
||||
library_uid, _clamp(limit), max(offset, 0), workspace_id
|
||||
library_uid, _clamp(limit), max(offset, 0), ws, libs
|
||||
)
|
||||
|
||||
@mcp.tool
|
||||
@@ -63,6 +85,7 @@ def register_discovery_tools(mcp):
|
||||
offset: int = 0,
|
||||
# System-injected; deliberately absent from the docstring.
|
||||
workspace_id: str | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""List items (the indexed documents/files), optionally filtered by
|
||||
collection_uid or library_uid. Returns uid, title, item_type, file_type,
|
||||
@@ -70,20 +93,27 @@ def register_discovery_tools(mcp):
|
||||
document size; use embedding_status to skip items that are not yet
|
||||
searchable (only 'completed' items appear in search results).
|
||||
"""
|
||||
claims = await get_mcp_claims(ctx)
|
||||
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||
with record_tool_call("list_items"):
|
||||
return await sync_to_async(_query_items, thread_sensitive=True)(
|
||||
collection_uid, library_uid, _clamp(limit), max(offset, 0), workspace_id
|
||||
collection_uid, library_uid, _clamp(limit), max(offset, 0), ws, libs
|
||||
)
|
||||
|
||||
|
||||
_WORKSPACE_SCOPE = (
|
||||
"($workspace_id IS NULL AND l.workspace_id IS NULL OR "
|
||||
"l.workspace_id = $workspace_id)"
|
||||
"(($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) "
|
||||
"OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) "
|
||||
"OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
|
||||
" AND l.workspace_id IS NULL))"
|
||||
)
|
||||
|
||||
|
||||
def _query_libraries(
|
||||
limit: int, offset: int, workspace_id: str | None = None
|
||||
limit: int,
|
||||
offset: int,
|
||||
workspace_id: str | None = None,
|
||||
allowed_libraries: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
from neomodel import db
|
||||
|
||||
@@ -92,7 +122,11 @@ def _query_libraries(
|
||||
f"WHERE {_WORKSPACE_SCOPE} "
|
||||
"RETURN l.uid, l.name, l.library_type, l.description "
|
||||
"ORDER BY l.name SKIP $offset LIMIT $limit",
|
||||
{"offset": offset, "limit": limit, "workspace_id": workspace_id},
|
||||
{
|
||||
"offset": offset, "limit": limit,
|
||||
"workspace_id": workspace_id,
|
||||
"allowed_libraries": allowed_libraries,
|
||||
},
|
||||
)
|
||||
return {
|
||||
"libraries": [
|
||||
@@ -110,11 +144,19 @@ def _query_libraries(
|
||||
|
||||
|
||||
def _query_collections(
|
||||
library_uid: str | None, limit: int, offset: int,
|
||||
library_uid: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
workspace_id: str | None = None,
|
||||
allowed_libraries: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
from neomodel import db
|
||||
|
||||
base_params = {
|
||||
"offset": offset, "limit": limit,
|
||||
"workspace_id": workspace_id,
|
||||
"allowed_libraries": allowed_libraries,
|
||||
}
|
||||
if library_uid:
|
||||
cypher = (
|
||||
"MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) "
|
||||
@@ -122,10 +164,7 @@ def _query_collections(
|
||||
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
||||
"ORDER BY c.name SKIP $offset LIMIT $limit"
|
||||
)
|
||||
params = {
|
||||
"library_uid": library_uid, "offset": offset, "limit": limit,
|
||||
"workspace_id": workspace_id,
|
||||
}
|
||||
params = {**base_params, "library_uid": library_uid}
|
||||
else:
|
||||
cypher = (
|
||||
"MATCH (l:Library)-[:CONTAINS]->(c:Collection) "
|
||||
@@ -133,10 +172,7 @@ def _query_collections(
|
||||
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
||||
"ORDER BY l.name, c.name SKIP $offset LIMIT $limit"
|
||||
)
|
||||
params = {
|
||||
"offset": offset, "limit": limit,
|
||||
"workspace_id": workspace_id,
|
||||
}
|
||||
params = base_params
|
||||
|
||||
rows, _ = db.cypher_query(cypher, params)
|
||||
return {
|
||||
@@ -161,6 +197,7 @@ def _query_items(
|
||||
limit: int,
|
||||
offset: int,
|
||||
workspace_id: str | None = None,
|
||||
allowed_libraries: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
from neomodel import db
|
||||
|
||||
@@ -168,6 +205,7 @@ def _query_items(
|
||||
params: dict[str, Any] = {
|
||||
"offset": offset, "limit": limit,
|
||||
"workspace_id": workspace_id,
|
||||
"allowed_libraries": allowed_libraries,
|
||||
}
|
||||
if collection_uid:
|
||||
where.append("c.uid = $collection_uid")
|
||||
|
||||
@@ -10,9 +10,18 @@ from django.conf import settings
|
||||
from django.core.files.storage import default_storage
|
||||
from fastmcp.server.context import Context
|
||||
|
||||
from ..context import get_mcp_user
|
||||
from ..context import get_mcp_claims, get_mcp_user
|
||||
from ..metrics import record_tool_call
|
||||
|
||||
|
||||
def _scope_from_claims(claims: dict | None,
|
||||
arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
|
||||
"""Return (workspace_id, allowed_libraries) for a tool call. Token claims
|
||||
trump tool args when present."""
|
||||
if claims is not None:
|
||||
return claims.get("ws"), claims.get("libs") or None
|
||||
return arg_workspace_id, None
|
||||
|
||||
DEFAULT_SEARCH_TYPES = ["vector", "fulltext", "graph"]
|
||||
|
||||
|
||||
@@ -47,6 +56,8 @@ def register_search_tools(mcp):
|
||||
score, and source. Also returns matching images when include_images=True.
|
||||
"""
|
||||
types = search_types or DEFAULT_SEARCH_TYPES
|
||||
claims = await get_mcp_claims(ctx)
|
||||
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||
with record_tool_call("search"):
|
||||
user = await get_mcp_user(ctx)
|
||||
return await sync_to_async(_run_search, thread_sensitive=True)(
|
||||
@@ -55,7 +66,8 @@ def register_search_tools(mcp):
|
||||
library_uid=library_uid,
|
||||
library_type=library_type,
|
||||
collection_uid=collection_uid,
|
||||
workspace_id=workspace_id,
|
||||
workspace_id=ws,
|
||||
allowed_libraries=libs,
|
||||
limit=limit,
|
||||
rerank=rerank,
|
||||
include_images=include_images,
|
||||
@@ -75,14 +87,17 @@ def register_search_tools(mcp):
|
||||
item_uid, item_title, library_type, text. Use this when the 500-character
|
||||
text_preview from `search` isn't enough.
|
||||
"""
|
||||
claims = await get_mcp_claims(ctx)
|
||||
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||
with record_tool_call("get_chunk"):
|
||||
return await sync_to_async(_load_chunk, thread_sensitive=True)(
|
||||
chunk_uid, workspace_id
|
||||
chunk_uid, ws, libs
|
||||
)
|
||||
|
||||
|
||||
def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
||||
workspace_id, limit, rerank, include_images, search_types) -> dict[str, Any]:
|
||||
workspace_id, allowed_libraries, limit, rerank, include_images,
|
||||
search_types) -> dict[str, Any]:
|
||||
from library.services.search import SearchRequest, SearchService
|
||||
|
||||
req = SearchRequest(
|
||||
@@ -91,6 +106,7 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
||||
library_type=library_type,
|
||||
collection_uid=collection_uid,
|
||||
workspace_id=workspace_id,
|
||||
allowed_libraries=allowed_libraries,
|
||||
search_types=search_types,
|
||||
limit=limit,
|
||||
vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50),
|
||||
@@ -112,17 +128,27 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
||||
}
|
||||
|
||||
|
||||
def _load_chunk(chunk_uid: str, workspace_id: str | None = None) -> dict[str, Any]:
|
||||
def _load_chunk(
|
||||
chunk_uid: str,
|
||||
workspace_id: str | None = None,
|
||||
allowed_libraries: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
from neomodel import db
|
||||
|
||||
rows, _ = db.cypher_query(
|
||||
"MATCH (l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->"
|
||||
"(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) "
|
||||
"WHERE ($workspace_id IS NULL AND l.workspace_id IS NULL OR "
|
||||
" l.workspace_id = $workspace_id) "
|
||||
"WHERE (($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) "
|
||||
" OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) "
|
||||
" OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
|
||||
" AND l.workspace_id IS NULL)) "
|
||||
"RETURN c.uid, c.chunk_index, c.chunk_s3_key, "
|
||||
"i.uid, i.title, l.library_type LIMIT 1",
|
||||
{"uid": chunk_uid, "workspace_id": workspace_id},
|
||||
{
|
||||
"uid": chunk_uid,
|
||||
"workspace_id": workspace_id,
|
||||
"allowed_libraries": allowed_libraries,
|
||||
},
|
||||
)
|
||||
if not rows:
|
||||
raise ValueError(f"Chunk not found: {chunk_uid}")
|
||||
|
||||
Reference in New Issue
Block a user