feat(library): add workspace-scoped search and JWT auth for Daedalus
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 52s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m32s

- Extend library list endpoint with `include_workspace` and
  `with_item_count` query params to support Daedalus registry mirroring
- Expand search scope clause to three modes: workspace-only, workspace
  plus allowed user libraries, and global
- Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims
- Introduce JWT-based actor resolution using a synthetic service user
  (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
This commit is contained in:
2026-05-03 17:36:06 -04:00
parent e5618973fc
commit a2c885cf34
11 changed files with 555 additions and 48 deletions

View File

@@ -40,13 +40,44 @@ logger = logging.getLogger(__name__)
@api_view(["GET", "POST"])
@permission_classes([IsAuthenticated])
def library_list_create(request):
"""List all libraries or create a new one."""
"""List all libraries or create a new one.
GET supports ``?include_workspace=false`` (default ``true``) to filter out
libraries that belong to a Daedalus workspace — the Daedalus library
registry uses this to mirror only the user-managed catalog.
GET supports ``?with_item_count=true`` (default ``false``) to attach
a per-library ``item_count``. Off by default because the count is a
Cypher aggregate; on for the Daedalus-side registry poll.
"""
from library.models import Library
if request.method == "GET":
libraries = Library.nodes.order_by("name")
serializer = LibrarySerializer(libraries, many=True)
return Response(serializer.data)
include_workspace = request.GET.get("include_workspace", "true").lower() != "false"
with_count = request.GET.get("with_item_count", "false").lower() == "true"
if include_workspace:
libraries = list(Library.nodes.order_by("name"))
else:
libraries = list(Library.nodes.filter(workspace_id__isnull=True).order_by("name"))
data = LibrarySerializer(libraries, many=True).data
if with_count and libraries:
from neomodel import db
uids = [lib.uid for lib in libraries]
rows, _ = db.cypher_query(
"MATCH (l:Library) WHERE l.uid IN $uids "
"OPTIONAL MATCH (l)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(i:Item) "
"RETURN l.uid, count(i)",
{"uids": uids},
)
counts = {uid: count for (uid, count) in rows}
for entry in data:
entry["item_count"] = counts.get(entry["uid"], 0)
return Response(data)
# POST — create
serializer = LibrarySerializer(data=request.data)

View File

@@ -26,14 +26,26 @@ from .fusion import ImageSearchResult, SearchCandidate, reciprocal_rank_fusion
logger = logging.getLogger(__name__)
# Workspace scoping clause appended to every search Cypher query.
# Search-scope clause appended to every search Cypher query.
#
# A request with workspace_id set returns ONLY that workspace's content.
# A request with workspace_id null returns ONLY global content (libraries
# with no workspace_id). There is no third mode.
# Three modes, picked structurally by which params are set:
#
# 1. ``workspace_id`` set, ``allowed_libraries`` empty → workspace-scoped.
# Returns ONLY content from libraries whose workspace_id matches.
# 2. ``workspace_id`` set + ``allowed_libraries`` non-empty → workspace
# PLUS the listed user-managed libraries (typical Phase-2 chat turn).
# 3. Both null → global. Returns ONLY libraries with no workspace_id
# (legacy opaque-token callers / dashboard).
#
# When ``allowed_libraries`` is non-empty alone (no workspace_id), it
# narrows results to those libraries.
_WORKSPACE_SCOPE_CLAUSE = (
" AND ($workspace_id IS NULL AND lib.workspace_id IS NULL OR "
"lib.workspace_id = $workspace_id)"
" AND ("
"($workspace_id IS NOT NULL AND lib.workspace_id = $workspace_id) "
"OR ($allowed_libraries IS NOT NULL AND lib.uid IN $allowed_libraries) "
"OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
" AND lib.workspace_id IS NULL)"
")"
)
@@ -52,6 +64,10 @@ class SearchRequest:
library_type: Optional[str] = None
collection_uid: Optional[str] = None
workspace_id: Optional[str] = None
# Phase-2 token claim: user-managed libraries the caller may include
# alongside their workspace's auto-library. Cypher uses ``IS NULL`` vs
# non-empty list to gate the second branch of the scope clause.
allowed_libraries: Optional[list[str]] = None
search_types: list[str] = field(
default_factory=lambda: ["vector", "fulltext", "graph"]
)
@@ -73,6 +89,11 @@ class SearchRequest:
self.library_type = None
if self.collection_uid == "":
self.collection_uid = None
# Empty list collapses to None so the Cypher branch reads
# "$allowed_libraries IS NOT NULL" rather than "size > 0" — keeps
# the parameter binding straightforward and the predicate sargable.
if self.allowed_libraries is not None and len(self.allowed_libraries) == 0:
self.allowed_libraries = None
@dataclass
@@ -300,6 +321,7 @@ class SearchService:
"library_type": request.library_type,
"collection_uid": request.collection_uid,
"workspace_id": request.workspace_id,
"allowed_libraries": request.allowed_libraries,
}
try:
@@ -411,6 +433,7 @@ class SearchService:
"library_type": request.library_type,
"collection_uid": request.collection_uid,
"workspace_id": request.workspace_id,
"allowed_libraries": request.allowed_libraries,
}
try:
@@ -471,6 +494,7 @@ class SearchService:
"library_uid": request.library_uid,
"library_type": request.library_type,
"workspace_id": request.workspace_id,
"allowed_libraries": request.allowed_libraries,
}
try:
@@ -546,6 +570,7 @@ class SearchService:
"library_uid": request.library_uid,
"library_type": request.library_type,
"workspace_id": request.workspace_id,
"allowed_libraries": request.allowed_libraries,
}
try:
@@ -630,6 +655,7 @@ class SearchService:
"library_uid": request.library_uid,
"library_type": request.library_type,
"workspace_id": request.workspace_id,
"allowed_libraries": request.allowed_libraries,
}
try:

View File

@@ -1,31 +1,68 @@
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
"""MCP token resolution and FastMCP middleware for bearer-token auth.
Two token shapes are supported:
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
used by the dashboard / Claude Desktop / admin tooling. Plaintext
hashes to a row in `mcp_token`.
* **Signed JWT** — per-turn token minted by Daedalus. Carries
`{ws, libs}` claims. Validated entirely off the signature + claims;
no database lookup of the token itself, only of the signing key
(`MCPSigningKey`) referenced by the JWT header's `kid`.
Detection: a bearer with three base64url segments separated by dots and
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
else falls through to the opaque path.
"""
from __future__ import annotations
import base64
import hashlib
import json
import logging
import time
from collections import OrderedDict
import jwt as pyjwt
from asgiref.sync import sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model
from django.utils import timezone
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from .metrics import mcp_auth_failures_total
from .models import MCPToken, hash_token
from .models import MCPSigningKey, MCPToken, hash_token
logger = logging.getLogger(__name__)
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"
STATE_KEY_CLAIMS = "mcp_claims"
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
# symmetrically as ``leeway``.
_JWT_LEEWAY_SECONDS = 30
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
_JWT_ISS = "daedalus"
# Bounded LRU of recently-seen jti values to discourage replay within
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
# is best-effort and per-process.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
class MCPAuthError(Exception):
"""Raised when a bearer token cannot be resolved to a valid user."""
# --- Opaque token path -----------------------------------------------------
def resolve_mcp_user(token_string: str):
"""Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure.
"""Resolve an opaque bearer token to (user, MCPToken).
Hashes the incoming bearer and looks up by the hash — plaintext is never
stored or compared directly.
@@ -50,6 +87,110 @@ def resolve_mcp_user(token_string: str):
return token.user, token
# --- JWT path --------------------------------------------------------------
def looks_like_jwt(token_string: str) -> bool:
"""Cheap pre-check: 3 dot-separated segments and a valid base64url header
that decodes to JSON with an alg field."""
parts = token_string.split(".")
if len(parts) != 3:
return False
try:
header_bytes = _b64url_decode(parts[0])
header = json.loads(header_bytes)
except (ValueError, json.JSONDecodeError):
return False
return isinstance(header, dict) and isinstance(header.get("alg"), str)
def _b64url_decode(segment: str) -> bytes:
pad = "=" * (-len(segment) % 4)
return base64.urlsafe_b64decode(segment + pad)
def _remember_jti(jti: str) -> bool:
"""Return True if this jti has been seen before in this process."""
now = time.time()
# Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s.
cutoff = now - 3600
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff:
_JTI_CACHE.popitem(last=False)
if jti in _JTI_CACHE:
return True
if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
_JTI_CACHE.popitem(last=False)
_JTI_CACHE[jti] = now
return False
def resolve_mcp_jwt(token_string: str) -> dict:
"""Validate a signed JWT and return its claims dict.
Raises ``MCPAuthError`` on any failure. Does not touch ``MCPToken`` —
JWTs are stateless and stored only as their signing key (``MCPSigningKey``).
"""
try:
unverified_header = pyjwt.get_unverified_header(token_string)
except pyjwt.PyJWTError as exc:
raise MCPAuthError(f"Malformed JWT header: {exc}")
kid = unverified_header.get("kid")
if not kid:
raise MCPAuthError("JWT header missing kid.")
if unverified_header.get("alg") != "HS256":
raise MCPAuthError("Unsupported JWT alg.")
key = MCPSigningKey.objects.by_kid(kid)
if key is None:
raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
if not key.is_active:
raise MCPAuthError(f"Signing key {kid!r} has been retired.")
try:
secret = bytes.fromhex(key.secret_hex)
except ValueError:
raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.")
try:
claims = pyjwt.decode(
token_string,
secret,
algorithms=["HS256"],
leeway=_JWT_LEEWAY_SECONDS,
options={"require": ["exp", "iat", "iss", "sub", "jti"]},
issuer=_JWT_ISS,
)
except pyjwt.ExpiredSignatureError:
raise MCPAuthError("Token has expired.")
except pyjwt.InvalidIssuerError:
raise MCPAuthError("Invalid token issuer.")
except pyjwt.InvalidSignatureError:
raise MCPAuthError("Invalid MCP token.")
except pyjwt.MissingRequiredClaimError as exc:
raise MCPAuthError(f"JWT missing required claim: {exc.claim}")
except pyjwt.PyJWTError as exc:
raise MCPAuthError(f"Invalid JWT: {exc}")
jti = claims.get("jti")
if not isinstance(jti, str) or not jti:
raise MCPAuthError("JWT jti must be a non-empty string.")
if _remember_jti(jti):
raise MCPAuthError("Token replay detected.")
# Normalize claim shapes: ws may be null/absent, libs default to [].
claims["ws"] = claims.get("ws") or None
libs = claims.get("libs") or []
if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
raise MCPAuthError("JWT libs must be a list of strings.")
claims["libs"] = libs
claims["kid"] = kid
return claims
# --- Middleware ------------------------------------------------------------
class MCPAuthMiddleware(Middleware):
"""
FastMCP middleware that authenticates tool calls via Bearer tokens.
@@ -59,17 +200,24 @@ class MCPAuthMiddleware(Middleware):
MCP_REQUIRE_AUTH=False.
"""
async def on_call_tool(
self, context: MiddlewareContext, call_next
):
async def on_call_tool(self, context: MiddlewareContext, call_next):
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
token_string = self._extract_token()
user = None
token = None
claims: dict | None = None
if token_string:
try:
if looks_like_jwt(token_string):
claims = await sync_to_async(
resolve_mcp_jwt, thread_sensitive=True
)(token_string)
user = await sync_to_async(
_resolve_jwt_actor, thread_sensitive=True
)(claims)
else:
user, token = await sync_to_async(
resolve_mcp_user, thread_sensitive=True
)(token_string)
@@ -89,9 +237,13 @@ class MCPAuthMiddleware(Middleware):
)
fastmcp_ctx = getattr(context, "fastmcp_context", None)
if fastmcp_ctx and user is not None:
if fastmcp_ctx is not None:
if user is not None:
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
if token is not None:
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
if claims is not None:
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
return await call_next(context)
@@ -111,3 +263,28 @@ class MCPAuthMiddleware(Middleware):
msg = getattr(context, "message", None)
params = getattr(msg, "params", None) if msg else None
return getattr(params, "name", None)
# --- Helpers ---------------------------------------------------------------
def _resolve_jwt_actor(claims: dict):
"""Resolve the synthetic actor for a JWT-authenticated turn.
Returns the system service user (``MCP_JWT_SERVICE_USERNAME``, default
``daedalus-service``). The user must exist and be active. JWT tokens
are not tied to per-user accounts — claims encode all authorization.
"""
from django.contrib.auth import get_user_model
User = get_user_model()
username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
raise MCPAuthError(
f"JWT service user {username!r} does not exist; provision via management command."
)
if not user.is_active:
raise MCPAuthError(f"JWT service user {username!r} is disabled.")
return user

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from fastmcp.server.context import Context
from .auth import STATE_KEY_TOKEN, STATE_KEY_USER
from .auth import STATE_KEY_CLAIMS, STATE_KEY_TOKEN, STATE_KEY_USER
async def get_mcp_user(ctx: Context | None):
@@ -17,3 +17,10 @@ async def get_mcp_token(ctx: Context | None):
if ctx is None:
return None
return await ctx.get_state(STATE_KEY_TOKEN)
async def get_mcp_claims(ctx: Context | None) -> dict | None:
"""Return the JWT claims dict for this request, or None for opaque-token callers."""
if ctx is None:
return None
return await ctx.get_state(STATE_KEY_CLAIMS)

View File

@@ -0,0 +1,48 @@
"""Idempotently create the service user that JWT-authenticated MCP requests act as.
Daedalus mints per-turn JWTs whose claims encode all authorization (workspace,
allowed libraries). The Django ``user`` field on the request still needs to
point at *something* — the service user is that something. It owns no data
and does not log in via the dashboard.
"""
import secrets
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = "Idempotently create or reactivate the JWT service user (default 'daedalus-service')."
def add_arguments(self, parser):
parser.add_argument("--username", default="daedalus-service")
parser.add_argument("--email", default="daedalus-service@local")
def handle(self, *args, **options):
User = get_user_model()
username = options["username"]
email = options["email"]
user, created = User.objects.get_or_create(
username=username,
defaults={"email": email, "is_active": True},
)
if created:
# Set a random password the user cannot log in with via the UI.
user.set_password(secrets.token_urlsafe(32))
user.save(update_fields=["password"])
self.stdout.write(self.style.SUCCESS(f"Created service user {username!r}"))
else:
changed = False
if not user.is_active:
user.is_active = True
changed = True
if user.email != email:
user.email = email
changed = True
if changed:
user.save(update_fields=["is_active", "email"])
self.stdout.write(self.style.SUCCESS(f"Updated service user {username!r}"))
else:
self.stdout.write(f"Service user {username!r} already provisioned")

View File

@@ -0,0 +1,71 @@
"""Mint a new MCPSigningKey for per-turn JWTs.
The secret is generated locally, printed once, and stored hex-encoded.
Distribute the printed secret to the issuer (Daedalus) — after this
command exits, the database row is the only copy on the Mnemosyne
side, but copying the secret out of the row is allowed (admin model
exposes it). Treat it like any other shared secret.
Use ``--retire-other`` to flip every other key to ``is_active=False``
in the same transaction (typical rotation flow). Without it, the new
key joins the active set and old keys keep validating until they are
retired explicitly.
"""
import secrets
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from mcp_server.models import MCPSigningKey
class Command(BaseCommand):
help = "Generate a new HMAC signing key for per-turn JWTs and print it once."
def add_arguments(self, parser):
parser.add_argument(
"--kid",
required=True,
help="Key ID (short, stable string — e.g. 'daedalus-1').",
)
parser.add_argument(
"--note",
default="",
help="Optional note (e.g. 'rotated 2026-05-03 after staging incident').",
)
parser.add_argument(
"--retire-other",
action="store_true",
help=(
"Mark every other MCPSigningKey row inactive in the same "
"transaction. Use during rotation."
),
)
@transaction.atomic
def handle(self, *args, **options):
kid = options["kid"]
if MCPSigningKey.objects.filter(kid=kid).exists():
raise CommandError(f"kid {kid!r} already exists; choose a unique value.")
# 256-bit secret, hex-encoded for portability.
secret = secrets.token_hex(32)
if options["retire_other"]:
from django.utils import timezone
MCPSigningKey.objects.filter(is_active=True).update(
is_active=False, retired_at=timezone.now()
)
key = MCPSigningKey.objects.create(
kid=kid, secret_hex=secret, is_active=True, note=options["note"]
)
self.stdout.write(self.style.SUCCESS("MCP signing key created"))
self.stdout.write(f" kid: {key.kid}")
self.stdout.write(f" active: {key.is_active}")
if options["retire_other"]:
self.stdout.write(" (other active keys retired)")
self.stdout.write(self.style.WARNING(" Secret (256-bit hex — share with Daedalus):"))
self.stdout.write(f" {secret}")

View File

@@ -0,0 +1,38 @@
"""HMAC signing keys for per-turn JWTs minted by Daedalus.
Adds the MCPSigningKey table. Per-turn tokens (workspace + library claims,
exp <= 600s) are not stored — only the signing key, indexed by ``kid``,
so the signature can be validated and rotated cleanly.
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("mcp_server", "0002_hash_token"),
]
operations = [
migrations.CreateModel(
name="MCPSigningKey",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("kid", models.CharField(max_length=64, unique=True, db_index=True)),
("secret_hex", models.CharField(max_length=128)),
("is_active", models.BooleanField(default=True)),
("created_at", models.DateTimeField(auto_now_add=True)),
("retired_at", models.DateTimeField(blank=True, null=True)),
("note", models.TextField(blank=True)),
],
options={"ordering": ["-created_at"]},
),
]

View File

@@ -84,3 +84,46 @@ class MCPToken(models.Model):
hash prefixed with `mcp_…`. Stable per token, never reveals plaintext.
"""
return f"mcp_…{self.token_hash[:8]}"
class MCPSigningKeyManager(models.Manager):
def active(self):
"""Active keys, newest first. Multiple may overlap during rotation."""
return self.filter(is_active=True).order_by("-created_at")
def by_kid(self, kid: str):
return self.filter(kid=kid).first()
class MCPSigningKey(models.Model):
"""HMAC signing key for per-turn JWTs minted by Daedalus.
Per-turn tokens carry workspace + library claims and expire in minutes.
They are validated entirely off the signature + claims; no row is stored
per token. Only the *signing key* is persisted here, indexed by ``kid``.
Rotation: seed a new active key, distribute the secret to Daedalus,
flip the old one ``is_active=False``. In-flight tokens with the retired
``kid`` fail at ``exp`` (bounded by the per-turn TTL).
"""
kid = models.CharField(max_length=64, unique=True, db_index=True)
secret_hex = models.CharField(max_length=128) # 256-bit secret = 64 hex
is_active = models.BooleanField(default=True)
created_at = models.DateTimeField(auto_now_add=True)
retired_at = models.DateTimeField(null=True, blank=True)
note = models.TextField(blank=True)
objects = MCPSigningKeyManager()
class Meta:
ordering = ["-created_at"]
def __str__(self):
suffix = "active" if self.is_active else "retired"
return f"{self.kid} ({suffix})"
def retire(self):
self.is_active = False
self.retired_at = timezone.now()
self.save(update_fields=["is_active", "retired_at"])

View File

@@ -5,13 +5,29 @@ from __future__ import annotations
from typing import Any
from asgiref.sync import sync_to_async
from fastmcp.server.context import Context
from ..context import get_mcp_claims
from ..metrics import record_tool_call
DEFAULT_LIMIT = 50
MAX_LIMIT = 200
def _scope_from_claims(claims: dict | None,
arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
"""Return (workspace_id, allowed_libraries) for a tool call.
Token claims, when present, trump tool args — that's the security
contract. Opaque-token callers (no claims) keep the legacy behavior
where the caller may pass workspace_id explicitly (typically null,
yielding global scope).
"""
if claims is not None:
return claims.get("ws"), claims.get("libs") or None
return arg_workspace_id, None
def _clamp(limit: int) -> int:
if limit < 1:
return 1
@@ -25,6 +41,7 @@ def register_discovery_tools(mcp):
offset: int = 0,
# System-injected; deliberately absent from the docstring.
workspace_id: str | None = None,
ctx: Context | None = None,
) -> dict[str, Any]:
"""List Mnemosyne libraries. Each library has a content-aware library_type
(fiction, nonfiction, technical, music, film, art, journal, business,
@@ -32,9 +49,11 @@ def register_discovery_tools(mcp):
name, library_type, description for each library — use the uid or
library_type to scope a subsequent search.
"""
claims = await get_mcp_claims(ctx)
ws, libs = _scope_from_claims(claims, workspace_id)
with record_tool_call("list_libraries"):
return await sync_to_async(_query_libraries, thread_sensitive=True)(
_clamp(limit), max(offset, 0), workspace_id
_clamp(limit), max(offset, 0), ws, libs
)
@mcp.tool
@@ -44,15 +63,18 @@ def register_discovery_tools(mcp):
offset: int = 0,
# System-injected; deliberately absent from the docstring.
workspace_id: str | None = None,
ctx: Context | None = None,
) -> dict[str, Any]:
"""List collections, optionally filtered by parent library_uid.
Collections group related items inside a library (e.g. a series of novels,
a multi-volume manual). Returns uid, name, description, library_uid,
library_name. Use the uid to scope a subsequent search to one collection.
"""
claims = await get_mcp_claims(ctx)
ws, libs = _scope_from_claims(claims, workspace_id)
with record_tool_call("list_collections"):
return await sync_to_async(_query_collections, thread_sensitive=True)(
library_uid, _clamp(limit), max(offset, 0), workspace_id
library_uid, _clamp(limit), max(offset, 0), ws, libs
)
@mcp.tool
@@ -63,6 +85,7 @@ def register_discovery_tools(mcp):
offset: int = 0,
# System-injected; deliberately absent from the docstring.
workspace_id: str | None = None,
ctx: Context | None = None,
) -> dict[str, Any]:
"""List items (the indexed documents/files), optionally filtered by
collection_uid or library_uid. Returns uid, title, item_type, file_type,
@@ -70,20 +93,27 @@ def register_discovery_tools(mcp):
document size; use embedding_status to skip items that are not yet
searchable (only 'completed' items appear in search results).
"""
claims = await get_mcp_claims(ctx)
ws, libs = _scope_from_claims(claims, workspace_id)
with record_tool_call("list_items"):
return await sync_to_async(_query_items, thread_sensitive=True)(
collection_uid, library_uid, _clamp(limit), max(offset, 0), workspace_id
collection_uid, library_uid, _clamp(limit), max(offset, 0), ws, libs
)
_WORKSPACE_SCOPE = (
"($workspace_id IS NULL AND l.workspace_id IS NULL OR "
"l.workspace_id = $workspace_id)"
"(($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) "
"OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) "
"OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
" AND l.workspace_id IS NULL))"
)
def _query_libraries(
limit: int, offset: int, workspace_id: str | None = None
limit: int,
offset: int,
workspace_id: str | None = None,
allowed_libraries: list[str] | None = None,
) -> dict[str, Any]:
from neomodel import db
@@ -92,7 +122,11 @@ def _query_libraries(
f"WHERE {_WORKSPACE_SCOPE} "
"RETURN l.uid, l.name, l.library_type, l.description "
"ORDER BY l.name SKIP $offset LIMIT $limit",
{"offset": offset, "limit": limit, "workspace_id": workspace_id},
{
"offset": offset, "limit": limit,
"workspace_id": workspace_id,
"allowed_libraries": allowed_libraries,
},
)
return {
"libraries": [
@@ -110,11 +144,19 @@ def _query_libraries(
def _query_collections(
library_uid: str | None, limit: int, offset: int,
library_uid: str | None,
limit: int,
offset: int,
workspace_id: str | None = None,
allowed_libraries: list[str] | None = None,
) -> dict[str, Any]:
from neomodel import db
base_params = {
"offset": offset, "limit": limit,
"workspace_id": workspace_id,
"allowed_libraries": allowed_libraries,
}
if library_uid:
cypher = (
"MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) "
@@ -122,10 +164,7 @@ def _query_collections(
"RETURN c.uid, c.name, c.description, l.uid, l.name "
"ORDER BY c.name SKIP $offset LIMIT $limit"
)
params = {
"library_uid": library_uid, "offset": offset, "limit": limit,
"workspace_id": workspace_id,
}
params = {**base_params, "library_uid": library_uid}
else:
cypher = (
"MATCH (l:Library)-[:CONTAINS]->(c:Collection) "
@@ -133,10 +172,7 @@ def _query_collections(
"RETURN c.uid, c.name, c.description, l.uid, l.name "
"ORDER BY l.name, c.name SKIP $offset LIMIT $limit"
)
params = {
"offset": offset, "limit": limit,
"workspace_id": workspace_id,
}
params = base_params
rows, _ = db.cypher_query(cypher, params)
return {
@@ -161,6 +197,7 @@ def _query_items(
limit: int,
offset: int,
workspace_id: str | None = None,
allowed_libraries: list[str] | None = None,
) -> dict[str, Any]:
from neomodel import db
@@ -168,6 +205,7 @@ def _query_items(
params: dict[str, Any] = {
"offset": offset, "limit": limit,
"workspace_id": workspace_id,
"allowed_libraries": allowed_libraries,
}
if collection_uid:
where.append("c.uid = $collection_uid")

View File

@@ -10,9 +10,18 @@ from django.conf import settings
from django.core.files.storage import default_storage
from fastmcp.server.context import Context
from ..context import get_mcp_user
from ..context import get_mcp_claims, get_mcp_user
from ..metrics import record_tool_call
def _scope_from_claims(claims: dict | None,
arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
"""Return (workspace_id, allowed_libraries) for a tool call. Token claims
trump tool args when present."""
if claims is not None:
return claims.get("ws"), claims.get("libs") or None
return arg_workspace_id, None
DEFAULT_SEARCH_TYPES = ["vector", "fulltext", "graph"]
@@ -47,6 +56,8 @@ def register_search_tools(mcp):
score, and source. Also returns matching images when include_images=True.
"""
types = search_types or DEFAULT_SEARCH_TYPES
claims = await get_mcp_claims(ctx)
ws, libs = _scope_from_claims(claims, workspace_id)
with record_tool_call("search"):
user = await get_mcp_user(ctx)
return await sync_to_async(_run_search, thread_sensitive=True)(
@@ -55,7 +66,8 @@ def register_search_tools(mcp):
library_uid=library_uid,
library_type=library_type,
collection_uid=collection_uid,
workspace_id=workspace_id,
workspace_id=ws,
allowed_libraries=libs,
limit=limit,
rerank=rerank,
include_images=include_images,
@@ -75,14 +87,17 @@ def register_search_tools(mcp):
item_uid, item_title, library_type, text. Use this when the 500-character
text_preview from `search` isn't enough.
"""
claims = await get_mcp_claims(ctx)
ws, libs = _scope_from_claims(claims, workspace_id)
with record_tool_call("get_chunk"):
return await sync_to_async(_load_chunk, thread_sensitive=True)(
chunk_uid, workspace_id
chunk_uid, ws, libs
)
def _run_search(*, user, query, library_uid, library_type, collection_uid,
workspace_id, limit, rerank, include_images, search_types) -> dict[str, Any]:
workspace_id, allowed_libraries, limit, rerank, include_images,
search_types) -> dict[str, Any]:
from library.services.search import SearchRequest, SearchService
req = SearchRequest(
@@ -91,6 +106,7 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
library_type=library_type,
collection_uid=collection_uid,
workspace_id=workspace_id,
allowed_libraries=allowed_libraries,
search_types=search_types,
limit=limit,
vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50),
@@ -112,17 +128,27 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
}
def _load_chunk(chunk_uid: str, workspace_id: str | None = None) -> dict[str, Any]:
def _load_chunk(
chunk_uid: str,
workspace_id: str | None = None,
allowed_libraries: list[str] | None = None,
) -> dict[str, Any]:
from neomodel import db
rows, _ = db.cypher_query(
"MATCH (l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->"
"(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) "
"WHERE ($workspace_id IS NULL AND l.workspace_id IS NULL OR "
" l.workspace_id = $workspace_id) "
"WHERE (($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) "
" OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) "
" OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
" AND l.workspace_id IS NULL)) "
"RETURN c.uid, c.chunk_index, c.chunk_s3_key, "
"i.uid, i.title, l.library_type LIMIT 1",
{"uid": chunk_uid, "workspace_id": workspace_id},
{
"uid": chunk_uid,
"workspace_id": workspace_id,
"allowed_libraries": allowed_libraries,
},
)
if not rows:
raise ValueError(f"Chunk not found: {chunk_uid}")

View File

@@ -33,6 +33,8 @@ dependencies = [
# Phase 5: MCP Server
"fastmcp>=2.0,<3.0",
"uvicorn[standard]>=0.30,<1.0",
# Phase 6: Per-turn signed JWTs from Daedalus
"PyJWT>=2.8,<3.0",
]
[project.optional-dependencies]