From a2c885cf34e73f05b2c2ba862f7cea94a6eae2e5 Mon Sep 17 00:00:00 2001 From: Robert Helewka Date: Sun, 3 May 2026 17:36:06 -0400 Subject: [PATCH] feat(library): add workspace-scoped search and JWT auth for Daedalus - Extend library list endpoint with `include_workspace` and `with_item_count` query params to support Daedalus registry mirroring - Expand search scope clause to three modes: workspace-only, workspace plus allowed user libraries, and global - Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims - Introduce JWT-based actor resolution using a synthetic service user (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests --- mnemosyne/library/api/views.py | 39 +++- mnemosyne/library/services/search.py | 38 +++- mnemosyne/mcp_server/auth.py | 203 ++++++++++++++++-- mnemosyne/mcp_server/context.py | 9 +- .../commands/ensure_service_user.py | 48 +++++ .../management/commands/seed_signing_key.py | 71 ++++++ .../migrations/0003_mcpsigningkey.py | 38 ++++ mnemosyne/mcp_server/models.py | 43 ++++ mnemosyne/mcp_server/tools/discovery.py | 70 ++++-- mnemosyne/mcp_server/tools/search.py | 42 +++- pyproject.toml | 2 + 11 files changed, 555 insertions(+), 48 deletions(-) create mode 100644 mnemosyne/mcp_server/management/commands/ensure_service_user.py create mode 100644 mnemosyne/mcp_server/management/commands/seed_signing_key.py create mode 100644 mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py diff --git a/mnemosyne/library/api/views.py b/mnemosyne/library/api/views.py index 26ad318..748d652 100644 --- a/mnemosyne/library/api/views.py +++ b/mnemosyne/library/api/views.py @@ -40,13 +40,44 @@ logger = logging.getLogger(__name__) @api_view(["GET", "POST"]) @permission_classes([IsAuthenticated]) def library_list_create(request): - """List all libraries or create a new one.""" + """List all libraries or create a new one. 
+ + GET supports ``?include_workspace=false`` (default ``true``) to filter out + libraries that belong to a Daedalus workspace — the Daedalus library + registry uses this to mirror only the user-managed catalog. + + GET supports ``?with_item_count=true`` (default ``false``) to attach + a per-library ``item_count``. Off by default because the count is a + Cypher aggregate; on for the Daedalus-side registry poll. + """ from library.models import Library if request.method == "GET": - libraries = Library.nodes.order_by("name") - serializer = LibrarySerializer(libraries, many=True) - return Response(serializer.data) + include_workspace = request.GET.get("include_workspace", "true").lower() != "false" + with_count = request.GET.get("with_item_count", "false").lower() == "true" + + if include_workspace: + libraries = list(Library.nodes.order_by("name")) + else: + libraries = list(Library.nodes.filter(workspace_id__isnull=True).order_by("name")) + + data = LibrarySerializer(libraries, many=True).data + + if with_count and libraries: + from neomodel import db + + uids = [lib.uid for lib in libraries] + rows, _ = db.cypher_query( + "MATCH (l:Library) WHERE l.uid IN $uids " + "OPTIONAL MATCH (l)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(i:Item) " + "RETURN l.uid, count(i)", + {"uids": uids}, + ) + counts = {uid: count for (uid, count) in rows} + for entry in data: + entry["item_count"] = counts.get(entry["uid"], 0) + + return Response(data) # POST — create serializer = LibrarySerializer(data=request.data) diff --git a/mnemosyne/library/services/search.py b/mnemosyne/library/services/search.py index bd114f2..c4d4ea0 100644 --- a/mnemosyne/library/services/search.py +++ b/mnemosyne/library/services/search.py @@ -26,14 +26,26 @@ from .fusion import ImageSearchResult, SearchCandidate, reciprocal_rank_fusion logger = logging.getLogger(__name__) -# Workspace scoping clause appended to every search Cypher query. +# Search-scope clause appended to every search Cypher query. 
# -# A request with workspace_id set returns ONLY that workspace's content. -# A request with workspace_id null returns ONLY global content (libraries -# with no workspace_id). There is no third mode. +# Three modes, picked structurally by which params are set: +# +# 1. ``workspace_id`` set, ``allowed_libraries`` empty → workspace-scoped. +# Returns ONLY content from libraries whose workspace_id matches. +# 2. ``workspace_id`` set + ``allowed_libraries`` non-empty → workspace +# PLUS the listed user-managed libraries (typical Phase-2 chat turn). +# 3. Both null → global. Returns ONLY libraries with no workspace_id +# (legacy opaque-token callers / dashboard). +# +# When ``allowed_libraries`` is non-empty alone (no workspace_id), it +# narrows results to those libraries. _WORKSPACE_SCOPE_CLAUSE = ( - " AND ($workspace_id IS NULL AND lib.workspace_id IS NULL OR " - "lib.workspace_id = $workspace_id)" + " AND (" + "($workspace_id IS NOT NULL AND lib.workspace_id = $workspace_id) " + "OR ($allowed_libraries IS NOT NULL AND lib.uid IN $allowed_libraries) " + "OR ($workspace_id IS NULL AND $allowed_libraries IS NULL " + " AND lib.workspace_id IS NULL)" + ")" ) @@ -52,6 +64,10 @@ class SearchRequest: library_type: Optional[str] = None collection_uid: Optional[str] = None workspace_id: Optional[str] = None + # Phase-2 token claim: user-managed libraries the caller may include + # alongside their workspace's auto-library. Cypher uses ``IS NULL`` vs + # non-empty list to gate the second branch of the scope clause. + allowed_libraries: Optional[list[str]] = None search_types: list[str] = field( default_factory=lambda: ["vector", "fulltext", "graph"] ) @@ -73,6 +89,11 @@ class SearchRequest: self.library_type = None if self.collection_uid == "": self.collection_uid = None + # Empty list collapses to None so the Cypher branch reads + # "$allowed_libraries IS NOT NULL" rather than "size > 0" — keeps + # the parameter binding straightforward and the predicate sargable. 
+ if self.allowed_libraries is not None and len(self.allowed_libraries) == 0: + self.allowed_libraries = None @dataclass @@ -300,6 +321,7 @@ class SearchService: "library_type": request.library_type, "collection_uid": request.collection_uid, "workspace_id": request.workspace_id, + "allowed_libraries": request.allowed_libraries, } try: @@ -411,6 +433,7 @@ class SearchService: "library_type": request.library_type, "collection_uid": request.collection_uid, "workspace_id": request.workspace_id, + "allowed_libraries": request.allowed_libraries, } try: @@ -471,6 +494,7 @@ class SearchService: "library_uid": request.library_uid, "library_type": request.library_type, "workspace_id": request.workspace_id, + "allowed_libraries": request.allowed_libraries, } try: @@ -546,6 +570,7 @@ class SearchService: "library_uid": request.library_uid, "library_type": request.library_type, "workspace_id": request.workspace_id, + "allowed_libraries": request.allowed_libraries, } try: @@ -630,6 +655,7 @@ class SearchService: "library_uid": request.library_uid, "library_type": request.library_type, "workspace_id": request.workspace_id, + "allowed_libraries": request.allowed_libraries, } try: diff --git a/mnemosyne/mcp_server/auth.py b/mnemosyne/mcp_server/auth.py index 075e287..36dc989 100644 --- a/mnemosyne/mcp_server/auth.py +++ b/mnemosyne/mcp_server/auth.py @@ -1,31 +1,68 @@ -"""MCP token resolution and FastMCP middleware for bearer-token auth.""" +"""MCP token resolution and FastMCP middleware for bearer-token auth. + +Two token shapes are supported: + +* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest, + used by the dashboard / Claude Desktop / admin tooling. Plaintext + hashes to a row in `mcp_token`. +* **Signed JWT** — per-turn token minted by Daedalus. Carries + `{ws, libs}` claims. Validated entirely off the signature + claims; + no database lookup of the token itself, only of the signing key + (`MCPSigningKey`) referenced by the JWT header's `kid`. 
+ +Detection: a bearer with three dot-separated segments whose header decodes +to a JSON object with a string `alg` is treated as JWT (HS256 and `kid` are +enforced later, in `resolve_mcp_jwt`); anything else takes the opaque path. +""" from __future__ import annotations +import base64 +import hashlib +import json import logging +import time +from collections import OrderedDict +import jwt as pyjwt from asgiref.sync import sync_to_async from django.conf import settings -from django.contrib.auth import get_user_model from django.utils import timezone from fastmcp.server.dependencies import get_http_request from fastmcp.server.middleware import Middleware, MiddlewareContext from .metrics import mcp_auth_failures_total -from .models import MCPToken, hash_token +from .models import MCPSigningKey, MCPToken, hash_token logger = logging.getLogger(__name__) STATE_KEY_USER = "mcp_user" STATE_KEY_TOKEN = "mcp_token" +STATE_KEY_CLAIMS = "mcp_claims" + +# Permitted clock skew when validating JWT exp/iat. PyJWT applies this +# symmetrically as ``leeway``. +_JWT_LEEWAY_SECONDS = 30 + +# Mnemosyne is the audience; Daedalus is the only accepted issuer. +_JWT_ISS = "daedalus" + +# Bounded LRU of recently-seen jti values to discourage replay within +# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this +# is best-effort and per-process. +_JTI_CACHE_MAX = 4096 +_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict() class MCPAuthError(Exception): """Raised when a bearer token cannot be resolved to a valid user.""" +# --- Opaque token path ----------------------------------------------------- + + def resolve_mcp_user(token_string: str): - """Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure. + """Resolve an opaque bearer token to (user, MCPToken). Hashes the incoming bearer and looks up by the hash — plaintext is never stored or compared directly. 
@@ -50,6 +87,110 @@ def resolve_mcp_user(token_string: str): return token.user, token +# --- JWT path -------------------------------------------------------------- + + +def looks_like_jwt(token_string: str) -> bool: + """Cheap pre-check: 3 dot-separated segments and a valid base64url header + that decodes to JSON with an alg field.""" + parts = token_string.split(".") + if len(parts) != 3: + return False + try: + header_bytes = _b64url_decode(parts[0]) + header = json.loads(header_bytes) + except (ValueError, json.JSONDecodeError): + return False + return isinstance(header, dict) and isinstance(header.get("alg"), str) + + +def _b64url_decode(segment: str) -> bytes: + pad = "=" * (-len(segment) % 4) + return base64.urlsafe_b64decode(segment + pad) + + +def _remember_jti(jti: str) -> bool: + """Return True if this jti has been seen before in this process.""" + now = time.time() + # Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s. + cutoff = now - 3600 + while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff: + _JTI_CACHE.popitem(last=False) + if jti in _JTI_CACHE: + return True + if len(_JTI_CACHE) >= _JTI_CACHE_MAX: + _JTI_CACHE.popitem(last=False) + _JTI_CACHE[jti] = now + return False + + +def resolve_mcp_jwt(token_string: str) -> dict: + """Validate a signed JWT and return its claims dict. + + Raises ``MCPAuthError`` on any failure. Does not touch ``MCPToken`` — + JWTs are stateless and stored only as their signing key (``MCPSigningKey``). 
+ """ + try: + unverified_header = pyjwt.get_unverified_header(token_string) + except pyjwt.PyJWTError as exc: + raise MCPAuthError(f"Malformed JWT header: {exc}") + + kid = unverified_header.get("kid") + if not kid: + raise MCPAuthError("JWT header missing kid.") + if unverified_header.get("alg") != "HS256": + raise MCPAuthError("Unsupported JWT alg.") + + key = MCPSigningKey.objects.by_kid(kid) + if key is None: + raise MCPAuthError(f"Unknown signing key kid={kid!r}.") + if not key.is_active: + raise MCPAuthError(f"Signing key {kid!r} has been retired.") + + try: + secret = bytes.fromhex(key.secret_hex) + except ValueError: + raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") + + try: + claims = pyjwt.decode( + token_string, + secret, + algorithms=["HS256"], + leeway=_JWT_LEEWAY_SECONDS, + options={"require": ["exp", "iat", "iss", "sub", "jti"]}, + issuer=_JWT_ISS, + ) + except pyjwt.ExpiredSignatureError: + raise MCPAuthError("Token has expired.") + except pyjwt.InvalidIssuerError: + raise MCPAuthError("Invalid token issuer.") + except pyjwt.InvalidSignatureError: + raise MCPAuthError("Invalid MCP token.") + except pyjwt.MissingRequiredClaimError as exc: + raise MCPAuthError(f"JWT missing required claim: {exc.claim}") + except pyjwt.PyJWTError as exc: + raise MCPAuthError(f"Invalid JWT: {exc}") + + jti = claims.get("jti") + if not isinstance(jti, str) or not jti: + raise MCPAuthError("JWT jti must be a non-empty string.") + if _remember_jti(jti): + raise MCPAuthError("Token replay detected.") + + # Normalize claim shapes: ws may be null/absent, libs default to []. 
+ claims["ws"] = claims.get("ws") or None + libs = claims.get("libs") or [] + if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs): + raise MCPAuthError("JWT libs must be a list of strings.") + claims["libs"] = libs + claims["kid"] = kid + return claims + + +# --- Middleware ------------------------------------------------------------ + + class MCPAuthMiddleware(Middleware): """ FastMCP middleware that authenticates tool calls via Bearer tokens. @@ -59,20 +200,27 @@ class MCPAuthMiddleware(Middleware): MCP_REQUIRE_AUTH=False. """ - async def on_call_tool( - self, context: MiddlewareContext, call_next - ): + async def on_call_tool(self, context: MiddlewareContext, call_next): require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True) token_string = self._extract_token() user = None token = None + claims: dict | None = None if token_string: try: - user, token = await sync_to_async( - resolve_mcp_user, thread_sensitive=True - )(token_string) + if looks_like_jwt(token_string): + claims = await sync_to_async( + resolve_mcp_jwt, thread_sensitive=True + )(token_string) + user = await sync_to_async( + _resolve_jwt_actor, thread_sensitive=True + )(claims) + else: + user, token = await sync_to_async( + resolve_mcp_user, thread_sensitive=True + )(token_string) except MCPAuthError as exc: mcp_auth_failures_total.labels(reason=str(exc)).inc() if require_auth: @@ -89,9 +237,13 @@ class MCPAuthMiddleware(Middleware): ) fastmcp_ctx = getattr(context, "fastmcp_context", None) - if fastmcp_ctx and user is not None: - await fastmcp_ctx.set_state(STATE_KEY_USER, user) - await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token) + if fastmcp_ctx is not None: + if user is not None: + await fastmcp_ctx.set_state(STATE_KEY_USER, user) + if token is not None: + await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token) + if claims is not None: + await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims) return await call_next(context) @@ -111,3 +263,28 @@ class 
MCPAuthMiddleware(Middleware): msg = getattr(context, "message", None) params = getattr(msg, "params", None) if msg else None return getattr(params, "name", None) + + +# --- Helpers --------------------------------------------------------------- + + +def _resolve_jwt_actor(claims: dict): + """Resolve the synthetic actor for a JWT-authenticated turn. + + Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default + ``daedalus-service``). The user must exist and be active. JWT tokens + are not tied to per-user accounts — claims encode all authorization. + """ + from django.contrib.auth import get_user_model + + User = get_user_model() + username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service") + try: + user = User.objects.get(username=username) + except User.DoesNotExist: + raise MCPAuthError( + f"JWT service user {username!r} does not exist; provision via management command." + ) + if not user.is_active: + raise MCPAuthError(f"JWT service user {username!r} is disabled.") + return user diff --git a/mnemosyne/mcp_server/context.py b/mnemosyne/mcp_server/context.py index 9a8d11f..dbf1c52 100644 --- a/mnemosyne/mcp_server/context.py +++ b/mnemosyne/mcp_server/context.py @@ -4,7 +4,7 @@ from __future__ import annotations from fastmcp.server.context import Context -from .auth import STATE_KEY_TOKEN, STATE_KEY_USER +from .auth import STATE_KEY_CLAIMS, STATE_KEY_TOKEN, STATE_KEY_USER async def get_mcp_user(ctx: Context | None): @@ -17,3 +17,10 @@ async def get_mcp_token(ctx: Context | None): if ctx is None: return None return await ctx.get_state(STATE_KEY_TOKEN) + + +async def get_mcp_claims(ctx: Context | None) -> dict | None: + """Return the JWT claims dict for this request, or None for opaque-token callers.""" + if ctx is None: + return None + return await ctx.get_state(STATE_KEY_CLAIMS) diff --git a/mnemosyne/mcp_server/management/commands/ensure_service_user.py b/mnemosyne/mcp_server/management/commands/ensure_service_user.py new file mode 
100644 index 0000000..d42147b --- /dev/null +++ b/mnemosyne/mcp_server/management/commands/ensure_service_user.py @@ -0,0 +1,48 @@ +"""Idempotently create the service user that JWT-authenticated MCP requests act as. + +Daedalus mints per-turn JWTs whose claims encode all authorization (workspace, +allowed libraries). The Django ``user`` field on the request still needs to +point at *something* — the service user is that something. It owns no data +and does not log in via the dashboard. +""" + +import secrets + +from django.contrib.auth import get_user_model +from django.core.management.base import BaseCommand + + +class Command(BaseCommand): + help = "Idempotently create or reactivate the JWT service user (default 'daedalus-service')." + + def add_arguments(self, parser): + parser.add_argument("--username", default="daedalus-service") + parser.add_argument("--email", default="daedalus-service@local") + + def handle(self, *args, **options): + User = get_user_model() + username = options["username"] + email = options["email"] + + user, created = User.objects.get_or_create( + username=username, + defaults={"email": email, "is_active": True}, + ) + if created: + # Set a random password the user cannot log in with via the UI. 
+ user.set_password(secrets.token_urlsafe(32)) + user.save(update_fields=["password"]) + self.stdout.write(self.style.SUCCESS(f"Created service user {username!r}")) + else: + changed = False + if not user.is_active: + user.is_active = True + changed = True + if user.email != email: + user.email = email + changed = True + if changed: + user.save(update_fields=["is_active", "email"]) + self.stdout.write(self.style.SUCCESS(f"Updated service user {username!r}")) + else: + self.stdout.write(f"Service user {username!r} already provisioned") diff --git a/mnemosyne/mcp_server/management/commands/seed_signing_key.py b/mnemosyne/mcp_server/management/commands/seed_signing_key.py new file mode 100644 index 0000000..745dcbf --- /dev/null +++ b/mnemosyne/mcp_server/management/commands/seed_signing_key.py @@ -0,0 +1,71 @@ +"""Mint a new MCPSigningKey for per-turn JWTs. + +The secret is generated locally, printed once, and stored hex-encoded. +Distribute the printed secret to the issuer (Daedalus) — after this +command exits, the database row is the only copy on the Mnemosyne +side, but copying the secret out of the row is allowed (admin model +exposes it). Treat it like any other shared secret. + +Use ``--retire-other`` to flip every other key to ``is_active=False`` +in the same transaction (typical rotation flow). Without it, the new +key joins the active set and old keys keep validating until they are +retired explicitly. +""" + +import secrets + +from django.core.management.base import BaseCommand, CommandError +from django.db import transaction + +from mcp_server.models import MCPSigningKey + + +class Command(BaseCommand): + help = "Generate a new HMAC signing key for per-turn JWTs and print it once." + + def add_arguments(self, parser): + parser.add_argument( + "--kid", + required=True, + help="Key ID (short, stable string — e.g. 'daedalus-1').", + ) + parser.add_argument( + "--note", + default="", + help="Optional note (e.g. 
'rotated 2026-05-03 after staging incident').", + ) + parser.add_argument( + "--retire-other", + action="store_true", + help=( + "Mark every other MCPSigningKey row inactive in the same " + "transaction. Use during rotation." + ), + ) + + @transaction.atomic + def handle(self, *args, **options): + kid = options["kid"] + if MCPSigningKey.objects.filter(kid=kid).exists(): + raise CommandError(f"kid {kid!r} already exists; choose a unique value.") + + # 256-bit secret, hex-encoded for portability. + secret = secrets.token_hex(32) + + if options["retire_other"]: + from django.utils import timezone + MCPSigningKey.objects.filter(is_active=True).update( + is_active=False, retired_at=timezone.now() + ) + + key = MCPSigningKey.objects.create( + kid=kid, secret_hex=secret, is_active=True, note=options["note"] + ) + + self.stdout.write(self.style.SUCCESS("MCP signing key created")) + self.stdout.write(f" kid: {key.kid}") + self.stdout.write(f" active: {key.is_active}") + if options["retire_other"]: + self.stdout.write(" (other active keys retired)") + self.stdout.write(self.style.WARNING(" Secret (256-bit hex — share with Daedalus):")) + self.stdout.write(f" {secret}") diff --git a/mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py b/mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py new file mode 100644 index 0000000..8eeb3c5 --- /dev/null +++ b/mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py @@ -0,0 +1,38 @@ +"""HMAC signing keys for per-turn JWTs minted by Daedalus. + +Adds the MCPSigningKey table. Per-turn tokens (workspace + library claims, +exp <= 600s) are not stored — only the signing key, indexed by ``kid``, +so the signature can be validated and rotated cleanly. 
+""" + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("mcp_server", "0002_hash_token"), + ] + + operations = [ + migrations.CreateModel( + name="MCPSigningKey", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("kid", models.CharField(max_length=64, unique=True, db_index=True)), + ("secret_hex", models.CharField(max_length=128)), + ("is_active", models.BooleanField(default=True)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("retired_at", models.DateTimeField(blank=True, null=True)), + ("note", models.TextField(blank=True)), + ], + options={"ordering": ["-created_at"]}, + ), + ] diff --git a/mnemosyne/mcp_server/models.py b/mnemosyne/mcp_server/models.py index aaeeb57..d5426dd 100644 --- a/mnemosyne/mcp_server/models.py +++ b/mnemosyne/mcp_server/models.py @@ -84,3 +84,46 @@ class MCPToken(models.Model): hash prefixed with `mcp_…`. Stable per token, never reveals plaintext. """ return f"mcp_…{self.token_hash[:8]}" + + +class MCPSigningKeyManager(models.Manager): + def active(self): + """Active keys, newest first. Multiple may overlap during rotation.""" + return self.filter(is_active=True).order_by("-created_at") + + def by_kid(self, kid: str): + return self.filter(kid=kid).first() + + +class MCPSigningKey(models.Model): + """HMAC signing key for per-turn JWTs minted by Daedalus. + + Per-turn tokens carry workspace + library claims and expire in minutes. + They are validated entirely off the signature + claims; no row is stored + per token. Only the *signing key* is persisted here, indexed by ``kid``. + + Rotation: seed a new active key, distribute the secret to Daedalus, + flip the old one ``is_active=False``. In-flight tokens with the retired + ``kid`` are rejected as soon as the key is inactive, not at ``exp``. 
+ """ + + kid = models.CharField(max_length=64, unique=True, db_index=True) + secret_hex = models.CharField(max_length=128) # 256-bit secret = 64 hex + is_active = models.BooleanField(default=True) + created_at = models.DateTimeField(auto_now_add=True) + retired_at = models.DateTimeField(null=True, blank=True) + note = models.TextField(blank=True) + + objects = MCPSigningKeyManager() + + class Meta: + ordering = ["-created_at"] + + def __str__(self): + suffix = "active" if self.is_active else "retired" + return f"{self.kid} ({suffix})" + + def retire(self): + self.is_active = False + self.retired_at = timezone.now() + self.save(update_fields=["is_active", "retired_at"]) diff --git a/mnemosyne/mcp_server/tools/discovery.py b/mnemosyne/mcp_server/tools/discovery.py index 413c2f5..0a08d6d 100644 --- a/mnemosyne/mcp_server/tools/discovery.py +++ b/mnemosyne/mcp_server/tools/discovery.py @@ -5,13 +5,29 @@ from __future__ import annotations from typing import Any from asgiref.sync import sync_to_async +from fastmcp.server.context import Context +from ..context import get_mcp_claims from ..metrics import record_tool_call DEFAULT_LIMIT = 50 MAX_LIMIT = 200 +def _scope_from_claims(claims: dict | None, + arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]: + """Return (workspace_id, allowed_libraries) for a tool call. + + Token claims, when present, trump tool args — that's the security + contract. Opaque-token callers (no claims) keep the legacy behavior + where the caller may pass workspace_id explicitly (typically null, + yielding global scope). + """ + if claims is not None: + return claims.get("ws"), claims.get("libs") or None + return arg_workspace_id, None + + def _clamp(limit: int) -> int: if limit < 1: return 1 @@ -25,6 +41,7 @@ def register_discovery_tools(mcp): offset: int = 0, # System-injected; deliberately absent from the docstring. workspace_id: str | None = None, + ctx: Context | None = None, ) -> dict[str, Any]: """List Mnemosyne libraries. 
Each library has a content-aware library_type (fiction, nonfiction, technical, music, film, art, journal, business, @@ -32,9 +49,11 @@ def register_discovery_tools(mcp): name, library_type, description for each library — use the uid or library_type to scope a subsequent search. """ + claims = await get_mcp_claims(ctx) + ws, libs = _scope_from_claims(claims, workspace_id) with record_tool_call("list_libraries"): return await sync_to_async(_query_libraries, thread_sensitive=True)( - _clamp(limit), max(offset, 0), workspace_id + _clamp(limit), max(offset, 0), ws, libs ) @mcp.tool @@ -44,15 +63,18 @@ def register_discovery_tools(mcp): offset: int = 0, # System-injected; deliberately absent from the docstring. workspace_id: str | None = None, + ctx: Context | None = None, ) -> dict[str, Any]: """List collections, optionally filtered by parent library_uid. Collections group related items inside a library (e.g. a series of novels, a multi-volume manual). Returns uid, name, description, library_uid, library_name. Use the uid to scope a subsequent search to one collection. """ + claims = await get_mcp_claims(ctx) + ws, libs = _scope_from_claims(claims, workspace_id) with record_tool_call("list_collections"): return await sync_to_async(_query_collections, thread_sensitive=True)( - library_uid, _clamp(limit), max(offset, 0), workspace_id + library_uid, _clamp(limit), max(offset, 0), ws, libs ) @mcp.tool @@ -63,6 +85,7 @@ def register_discovery_tools(mcp): offset: int = 0, # System-injected; deliberately absent from the docstring. workspace_id: str | None = None, + ctx: Context | None = None, ) -> dict[str, Any]: """List items (the indexed documents/files), optionally filtered by collection_uid or library_uid. Returns uid, title, item_type, file_type, @@ -70,20 +93,27 @@ def register_discovery_tools(mcp): document size; use embedding_status to skip items that are not yet searchable (only 'completed' items appear in search results). 
""" + claims = await get_mcp_claims(ctx) + ws, libs = _scope_from_claims(claims, workspace_id) with record_tool_call("list_items"): return await sync_to_async(_query_items, thread_sensitive=True)( - collection_uid, library_uid, _clamp(limit), max(offset, 0), workspace_id + collection_uid, library_uid, _clamp(limit), max(offset, 0), ws, libs ) _WORKSPACE_SCOPE = ( - "($workspace_id IS NULL AND l.workspace_id IS NULL OR " - "l.workspace_id = $workspace_id)" + "(($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) " + "OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) " + "OR ($workspace_id IS NULL AND $allowed_libraries IS NULL " + " AND l.workspace_id IS NULL))" ) def _query_libraries( - limit: int, offset: int, workspace_id: str | None = None + limit: int, + offset: int, + workspace_id: str | None = None, + allowed_libraries: list[str] | None = None, ) -> dict[str, Any]: from neomodel import db @@ -92,7 +122,11 @@ def _query_libraries( f"WHERE {_WORKSPACE_SCOPE} " "RETURN l.uid, l.name, l.library_type, l.description " "ORDER BY l.name SKIP $offset LIMIT $limit", - {"offset": offset, "limit": limit, "workspace_id": workspace_id}, + { + "offset": offset, "limit": limit, + "workspace_id": workspace_id, + "allowed_libraries": allowed_libraries, + }, ) return { "libraries": [ @@ -110,11 +144,19 @@ def _query_libraries( def _query_collections( - library_uid: str | None, limit: int, offset: int, + library_uid: str | None, + limit: int, + offset: int, workspace_id: str | None = None, + allowed_libraries: list[str] | None = None, ) -> dict[str, Any]: from neomodel import db + base_params = { + "offset": offset, "limit": limit, + "workspace_id": workspace_id, + "allowed_libraries": allowed_libraries, + } if library_uid: cypher = ( "MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) " @@ -122,10 +164,7 @@ def _query_collections( "RETURN c.uid, c.name, c.description, l.uid, l.name " "ORDER BY c.name SKIP $offset LIMIT $limit" ) - 
params = { - "library_uid": library_uid, "offset": offset, "limit": limit, - "workspace_id": workspace_id, - } + params = {**base_params, "library_uid": library_uid} else: cypher = ( "MATCH (l:Library)-[:CONTAINS]->(c:Collection) " @@ -133,10 +172,7 @@ def _query_collections( "RETURN c.uid, c.name, c.description, l.uid, l.name " "ORDER BY l.name, c.name SKIP $offset LIMIT $limit" ) - params = { - "offset": offset, "limit": limit, - "workspace_id": workspace_id, - } + params = base_params rows, _ = db.cypher_query(cypher, params) return { @@ -161,6 +197,7 @@ def _query_items( limit: int, offset: int, workspace_id: str | None = None, + allowed_libraries: list[str] | None = None, ) -> dict[str, Any]: from neomodel import db @@ -168,6 +205,7 @@ def _query_items( params: dict[str, Any] = { "offset": offset, "limit": limit, "workspace_id": workspace_id, + "allowed_libraries": allowed_libraries, } if collection_uid: where.append("c.uid = $collection_uid") diff --git a/mnemosyne/mcp_server/tools/search.py b/mnemosyne/mcp_server/tools/search.py index d428eb9..3647b5b 100644 --- a/mnemosyne/mcp_server/tools/search.py +++ b/mnemosyne/mcp_server/tools/search.py @@ -10,9 +10,18 @@ from django.conf import settings from django.core.files.storage import default_storage from fastmcp.server.context import Context -from ..context import get_mcp_user +from ..context import get_mcp_claims, get_mcp_user from ..metrics import record_tool_call + +def _scope_from_claims(claims: dict | None, + arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]: + """Return (workspace_id, allowed_libraries) for a tool call. Token claims + trump tool args when present.""" + if claims is not None: + return claims.get("ws"), claims.get("libs") or None + return arg_workspace_id, None + DEFAULT_SEARCH_TYPES = ["vector", "fulltext", "graph"] @@ -47,6 +56,8 @@ def register_search_tools(mcp): score, and source. Also returns matching images when include_images=True. 
""" types = search_types or DEFAULT_SEARCH_TYPES + claims = await get_mcp_claims(ctx) + ws, libs = _scope_from_claims(claims, workspace_id) with record_tool_call("search"): user = await get_mcp_user(ctx) return await sync_to_async(_run_search, thread_sensitive=True)( @@ -55,7 +66,8 @@ def register_search_tools(mcp): library_uid=library_uid, library_type=library_type, collection_uid=collection_uid, - workspace_id=workspace_id, + workspace_id=ws, + allowed_libraries=libs, limit=limit, rerank=rerank, include_images=include_images, @@ -75,14 +87,17 @@ def register_search_tools(mcp): item_uid, item_title, library_type, text. Use this when the 500-character text_preview from `search` isn't enough. """ + claims = await get_mcp_claims(ctx) + ws, libs = _scope_from_claims(claims, workspace_id) with record_tool_call("get_chunk"): return await sync_to_async(_load_chunk, thread_sensitive=True)( - chunk_uid, workspace_id + chunk_uid, ws, libs ) def _run_search(*, user, query, library_uid, library_type, collection_uid, - workspace_id, limit, rerank, include_images, search_types) -> dict[str, Any]: + workspace_id, allowed_libraries, limit, rerank, include_images, + search_types) -> dict[str, Any]: from library.services.search import SearchRequest, SearchService req = SearchRequest( @@ -91,6 +106,7 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid, library_type=library_type, collection_uid=collection_uid, workspace_id=workspace_id, + allowed_libraries=allowed_libraries, search_types=search_types, limit=limit, vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50), @@ -112,17 +128,27 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid, } -def _load_chunk(chunk_uid: str, workspace_id: str | None = None) -> dict[str, Any]: +def _load_chunk( + chunk_uid: str, + workspace_id: str | None = None, + allowed_libraries: list[str] | None = None, +) -> dict[str, Any]: from neomodel import db rows, _ = db.cypher_query( "MATCH 
(l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->" "(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) " - "WHERE ($workspace_id IS NULL AND l.workspace_id IS NULL OR " - " l.workspace_id = $workspace_id) " + "WHERE (($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) " + " OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) " + " OR ($workspace_id IS NULL AND $allowed_libraries IS NULL " + " AND l.workspace_id IS NULL)) " "RETURN c.uid, c.chunk_index, c.chunk_s3_key, " "i.uid, i.title, l.library_type LIMIT 1", - {"uid": chunk_uid, "workspace_id": workspace_id}, + { + "uid": chunk_uid, + "workspace_id": workspace_id, + "allowed_libraries": allowed_libraries, + }, ) if not rows: raise ValueError(f"Chunk not found: {chunk_uid}") diff --git a/pyproject.toml b/pyproject.toml index 7f67a16..86ae9a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ dependencies = [ # Phase 5: MCP Server "fastmcp>=2.0,<3.0", "uvicorn[standard]>=0.30,<1.0", + # Phase 6: Per-turn signed JWTs from Daedalus + "PyJWT>=2.8,<3.0", ] [project.optional-dependencies]