feat(library): add workspace-scoped search and JWT auth for Daedalus
- Extend library list endpoint with `include_workspace` and `with_item_count` query params to support Daedalus registry mirroring - Expand search scope clause to three modes: workspace-only, workspace plus allowed user libraries, and global - Add `allowed_libraries` field to SearchRequest for Phase-2 JWT claims - Introduce JWT-based actor resolution using a synthetic service user (`MCP_JWT_SERVICE_USERNAME`) for Daedalus-originated requests
This commit is contained in:
@@ -40,13 +40,44 @@ logger = logging.getLogger(__name__)
|
|||||||
@api_view(["GET", "POST"])
|
@api_view(["GET", "POST"])
|
||||||
@permission_classes([IsAuthenticated])
|
@permission_classes([IsAuthenticated])
|
||||||
def library_list_create(request):
|
def library_list_create(request):
|
||||||
"""List all libraries or create a new one."""
|
"""List all libraries or create a new one.
|
||||||
|
|
||||||
|
GET supports ``?include_workspace=false`` (default ``true``) to filter out
|
||||||
|
libraries that belong to a Daedalus workspace — the Daedalus library
|
||||||
|
registry uses this to mirror only the user-managed catalog.
|
||||||
|
|
||||||
|
GET supports ``?with_item_count=true`` (default ``false``) to attach
|
||||||
|
a per-library ``item_count``. Off by default because the count is a
|
||||||
|
Cypher aggregate; on for the Daedalus-side registry poll.
|
||||||
|
"""
|
||||||
from library.models import Library
|
from library.models import Library
|
||||||
|
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
libraries = Library.nodes.order_by("name")
|
include_workspace = request.GET.get("include_workspace", "true").lower() != "false"
|
||||||
serializer = LibrarySerializer(libraries, many=True)
|
with_count = request.GET.get("with_item_count", "false").lower() == "true"
|
||||||
return Response(serializer.data)
|
|
||||||
|
if include_workspace:
|
||||||
|
libraries = list(Library.nodes.order_by("name"))
|
||||||
|
else:
|
||||||
|
libraries = list(Library.nodes.filter(workspace_id__isnull=True).order_by("name"))
|
||||||
|
|
||||||
|
data = LibrarySerializer(libraries, many=True).data
|
||||||
|
|
||||||
|
if with_count and libraries:
|
||||||
|
from neomodel import db
|
||||||
|
|
||||||
|
uids = [lib.uid for lib in libraries]
|
||||||
|
rows, _ = db.cypher_query(
|
||||||
|
"MATCH (l:Library) WHERE l.uid IN $uids "
|
||||||
|
"OPTIONAL MATCH (l)-[:CONTAINS]->(:Collection)-[:CONTAINS]->(i:Item) "
|
||||||
|
"RETURN l.uid, count(i)",
|
||||||
|
{"uids": uids},
|
||||||
|
)
|
||||||
|
counts = {uid: count for (uid, count) in rows}
|
||||||
|
for entry in data:
|
||||||
|
entry["item_count"] = counts.get(entry["uid"], 0)
|
||||||
|
|
||||||
|
return Response(data)
|
||||||
|
|
||||||
# POST — create
|
# POST — create
|
||||||
serializer = LibrarySerializer(data=request.data)
|
serializer = LibrarySerializer(data=request.data)
|
||||||
|
|||||||
@@ -26,14 +26,26 @@ from .fusion import ImageSearchResult, SearchCandidate, reciprocal_rank_fusion
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Workspace scoping clause appended to every search Cypher query.
|
# Search-scope clause appended to every search Cypher query.
|
||||||
#
|
#
|
||||||
# A request with workspace_id set returns ONLY that workspace's content.
|
# Three modes, picked structurally by which params are set:
|
||||||
# A request with workspace_id null returns ONLY global content (libraries
|
#
|
||||||
# with no workspace_id). There is no third mode.
|
# 1. ``workspace_id`` set, ``allowed_libraries`` empty → workspace-scoped.
|
||||||
|
# Returns ONLY content from libraries whose workspace_id matches.
|
||||||
|
# 2. ``workspace_id`` set + ``allowed_libraries`` non-empty → workspace
|
||||||
|
# PLUS the listed user-managed libraries (typical Phase-2 chat turn).
|
||||||
|
# 3. Both null → global. Returns ONLY libraries with no workspace_id
|
||||||
|
# (legacy opaque-token callers / dashboard).
|
||||||
|
#
|
||||||
|
# When ``allowed_libraries`` is non-empty alone (no workspace_id), it
|
||||||
|
# narrows results to those libraries.
|
||||||
_WORKSPACE_SCOPE_CLAUSE = (
|
_WORKSPACE_SCOPE_CLAUSE = (
|
||||||
" AND ($workspace_id IS NULL AND lib.workspace_id IS NULL OR "
|
" AND ("
|
||||||
"lib.workspace_id = $workspace_id)"
|
"($workspace_id IS NOT NULL AND lib.workspace_id = $workspace_id) "
|
||||||
|
"OR ($allowed_libraries IS NOT NULL AND lib.uid IN $allowed_libraries) "
|
||||||
|
"OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
|
||||||
|
" AND lib.workspace_id IS NULL)"
|
||||||
|
")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -52,6 +64,10 @@ class SearchRequest:
|
|||||||
library_type: Optional[str] = None
|
library_type: Optional[str] = None
|
||||||
collection_uid: Optional[str] = None
|
collection_uid: Optional[str] = None
|
||||||
workspace_id: Optional[str] = None
|
workspace_id: Optional[str] = None
|
||||||
|
# Phase-2 token claim: user-managed libraries the caller may include
|
||||||
|
# alongside their workspace's auto-library. Cypher uses ``IS NULL`` vs
|
||||||
|
# non-empty list to gate the second branch of the scope clause.
|
||||||
|
allowed_libraries: Optional[list[str]] = None
|
||||||
search_types: list[str] = field(
|
search_types: list[str] = field(
|
||||||
default_factory=lambda: ["vector", "fulltext", "graph"]
|
default_factory=lambda: ["vector", "fulltext", "graph"]
|
||||||
)
|
)
|
||||||
@@ -73,6 +89,11 @@ class SearchRequest:
|
|||||||
self.library_type = None
|
self.library_type = None
|
||||||
if self.collection_uid == "":
|
if self.collection_uid == "":
|
||||||
self.collection_uid = None
|
self.collection_uid = None
|
||||||
|
# Empty list collapses to None so the Cypher branch reads
|
||||||
|
# "$allowed_libraries IS NOT NULL" rather than "size > 0" — keeps
|
||||||
|
# the parameter binding straightforward and the predicate sargable.
|
||||||
|
if self.allowed_libraries is not None and len(self.allowed_libraries) == 0:
|
||||||
|
self.allowed_libraries = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -300,6 +321,7 @@ class SearchService:
|
|||||||
"library_type": request.library_type,
|
"library_type": request.library_type,
|
||||||
"collection_uid": request.collection_uid,
|
"collection_uid": request.collection_uid,
|
||||||
"workspace_id": request.workspace_id,
|
"workspace_id": request.workspace_id,
|
||||||
|
"allowed_libraries": request.allowed_libraries,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -411,6 +433,7 @@ class SearchService:
|
|||||||
"library_type": request.library_type,
|
"library_type": request.library_type,
|
||||||
"collection_uid": request.collection_uid,
|
"collection_uid": request.collection_uid,
|
||||||
"workspace_id": request.workspace_id,
|
"workspace_id": request.workspace_id,
|
||||||
|
"allowed_libraries": request.allowed_libraries,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -471,6 +494,7 @@ class SearchService:
|
|||||||
"library_uid": request.library_uid,
|
"library_uid": request.library_uid,
|
||||||
"library_type": request.library_type,
|
"library_type": request.library_type,
|
||||||
"workspace_id": request.workspace_id,
|
"workspace_id": request.workspace_id,
|
||||||
|
"allowed_libraries": request.allowed_libraries,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -546,6 +570,7 @@ class SearchService:
|
|||||||
"library_uid": request.library_uid,
|
"library_uid": request.library_uid,
|
||||||
"library_type": request.library_type,
|
"library_type": request.library_type,
|
||||||
"workspace_id": request.workspace_id,
|
"workspace_id": request.workspace_id,
|
||||||
|
"allowed_libraries": request.allowed_libraries,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -630,6 +655,7 @@ class SearchService:
|
|||||||
"library_uid": request.library_uid,
|
"library_uid": request.library_uid,
|
||||||
"library_type": request.library_type,
|
"library_type": request.library_type,
|
||||||
"workspace_id": request.workspace_id,
|
"workspace_id": request.workspace_id,
|
||||||
|
"allowed_libraries": request.allowed_libraries,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,31 +1,68 @@
|
|||||||
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
|
"""MCP token resolution and FastMCP middleware for bearer-token auth.
|
||||||
|
|
||||||
|
Two token shapes are supported:
|
||||||
|
|
||||||
|
* **Opaque** — the original `MCPToken` row. Long-lived, hashed at rest,
|
||||||
|
used by the dashboard / Claude Desktop / admin tooling. Plaintext
|
||||||
|
hashes to a row in `mcp_token`.
|
||||||
|
* **Signed JWT** — per-turn token minted by Daedalus. Carries
|
||||||
|
`{ws, libs}` claims. Validated entirely off the signature + claims;
|
||||||
|
no database lookup of the token itself, only of the signing key
|
||||||
|
(`MCPSigningKey`) referenced by the JWT header's `kid`.
|
||||||
|
|
||||||
|
Detection: a bearer with three base64url segments separated by dots and
|
||||||
|
a parseable `{"alg":"HS256","kid":...}` header is treated as JWT; anything
|
||||||
|
else falls through to the opaque path.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import get_user_model
|
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from fastmcp.server.dependencies import get_http_request
|
from fastmcp.server.dependencies import get_http_request
|
||||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||||
|
|
||||||
from .metrics import mcp_auth_failures_total
|
from .metrics import mcp_auth_failures_total
|
||||||
from .models import MCPToken, hash_token
|
from .models import MCPSigningKey, MCPToken, hash_token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
STATE_KEY_USER = "mcp_user"
|
STATE_KEY_USER = "mcp_user"
|
||||||
STATE_KEY_TOKEN = "mcp_token"
|
STATE_KEY_TOKEN = "mcp_token"
|
||||||
|
STATE_KEY_CLAIMS = "mcp_claims"
|
||||||
|
|
||||||
|
# Permitted clock skew when validating JWT exp/iat. PyJWT applies this
|
||||||
|
# symmetrically as ``leeway``.
|
||||||
|
_JWT_LEEWAY_SECONDS = 30
|
||||||
|
|
||||||
|
# Mnemosyne is the audience; Daedalus is the only accepted issuer.
|
||||||
|
_JWT_ISS = "daedalus"
|
||||||
|
|
||||||
|
# Bounded, insertion-ordered cache of recently-seen jti values to discourage replay within
|
||||||
|
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
|
||||||
|
# is best-effort and per-process.
|
||||||
|
_JTI_CACHE_MAX = 4096
|
||||||
|
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
class MCPAuthError(Exception):
|
class MCPAuthError(Exception):
|
||||||
"""Raised when a bearer token cannot be resolved to a valid user."""
|
"""Raised when a bearer token cannot be resolved to a valid user."""
|
||||||
|
|
||||||
|
|
||||||
|
# --- Opaque token path -----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def resolve_mcp_user(token_string: str):
|
def resolve_mcp_user(token_string: str):
|
||||||
"""Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure.
|
"""Resolve an opaque bearer token to (user, MCPToken).
|
||||||
|
|
||||||
Hashes the incoming bearer and looks up by the hash — plaintext is never
|
Hashes the incoming bearer and looks up by the hash — plaintext is never
|
||||||
stored or compared directly.
|
stored or compared directly.
|
||||||
@@ -50,6 +87,110 @@ def resolve_mcp_user(token_string: str):
|
|||||||
return token.user, token
|
return token.user, token
|
||||||
|
|
||||||
|
|
||||||
|
# --- JWT path --------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def looks_like_jwt(token_string: str) -> bool:
    """Cheaply decide whether a bearer token has the outward shape of a JWT.

    Returns True only when the token has exactly three dot-separated segments
    and the first segment base64url-decodes to a JSON object whose ``alg``
    value is a string. Nothing is verified here — no signature check and no
    claim validation.
    """
    # Exactly two dots <=> exactly three segments.
    if token_string.count(".") != 2:
        return False

    encoded_header = token_string.partition(".")[0]
    try:
        header = json.loads(_b64url_decode(encoded_header))
    except (ValueError, json.JSONDecodeError):
        # Not base64url, not UTF-8, or not JSON: treat as a non-JWT bearer.
        return False

    if not isinstance(header, dict):
        return False
    return isinstance(header.get("alg"), str)
|
||||||
|
|
||||||
|
|
||||||
|
def _b64url_decode(segment: str) -> bytes:
|
||||||
|
pad = "=" * (-len(segment) % 4)
|
||||||
|
return base64.urlsafe_b64decode(segment + pad)
|
||||||
|
|
||||||
|
|
||||||
|
def _remember_jti(jti: str) -> bool:
    """Record ``jti`` in the per-process cache and report whether it was already there.

    Returns True when the jti is already cached (a repeat within this
    process) and False after caching it for the first time.
    """
    seen_at = time.time()
    # Entries older than one hour are discarded before the lookup.
    expiry_threshold = seen_at - 3600

    # New entries are only ever appended, so the oldest one is at the front.
    while _JTI_CACHE:
        oldest_jti = next(iter(_JTI_CACHE))
        if _JTI_CACHE[oldest_jti] >= expiry_threshold:
            break
        del _JTI_CACHE[oldest_jti]

    if jti in _JTI_CACHE:
        return True

    # At capacity: evict the oldest entry to make room for the new one.
    if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
        _JTI_CACHE.popitem(last=False)
    _JTI_CACHE[jti] = seen_at
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_mcp_jwt(token_string: str) -> dict:
    """Validate a signed JWT and return its claims dict.

    The ``kid`` in the (not yet verified) header selects an ``MCPSigningKey``
    row. The token is then verified against that key's secret with HS256,
    issuer ``_JWT_ISS``, ``_JWT_LEEWAY_SECONDS`` of leeway, and the required
    claims ``exp``, ``iat``, ``iss``, ``sub`` and ``jti``. Does not touch
    ``MCPToken`` — the only database lookup is for the signing key.

    Args:
        token_string: The raw bearer value.

    Returns:
        The decoded claims, normalized so that ``ws`` is ``None`` when absent
        or falsy, ``libs`` is always a list of strings (``[]`` when absent or
        falsy), and ``kid`` is set to the header's key id.

    Raises:
        MCPAuthError: On any failure — malformed header, missing ``kid``,
            non-HS256 ``alg``, unknown or inactive key, non-hex stored secret,
            failed signature/claim validation, a ``jti`` already seen by this
            process, or a ``libs`` claim that is not a list of strings. When
            the failure was triggered by another exception, that exception is
            chained as ``__cause__``.
    """
    try:
        unverified_header = pyjwt.get_unverified_header(token_string)
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Malformed JWT header: {exc}") from exc

    kid = unverified_header.get("kid")
    if not kid:
        raise MCPAuthError("JWT header missing kid.")
    # Reject anything but HS256 up front, before a key lookup is attempted.
    if unverified_header.get("alg") != "HS256":
        raise MCPAuthError("Unsupported JWT alg.")

    key = MCPSigningKey.objects.by_kid(kid)
    if key is None:
        # NOTE(review): this message embeds the caller-supplied kid, and the
        # middleware uses str(exc) as a metrics label — confirm the resulting
        # label cardinality is acceptable.
        raise MCPAuthError(f"Unknown signing key kid={kid!r}.")
    if not key.is_active:
        raise MCPAuthError(f"Signing key {kid!r} has been retired.")

    try:
        secret = bytes.fromhex(key.secret_hex)
    except ValueError as exc:
        raise MCPAuthError(f"Stored secret for kid={kid!r} is not valid hex.") from exc

    try:
        claims = pyjwt.decode(
            token_string,
            secret,
            algorithms=["HS256"],
            leeway=_JWT_LEEWAY_SECONDS,
            options={"require": ["exp", "iat", "iss", "sub", "jti"]},
            issuer=_JWT_ISS,
        )
    # The specific handlers must stay ahead of the PyJWTError catch-all.
    except pyjwt.ExpiredSignatureError as exc:
        raise MCPAuthError("Token has expired.") from exc
    except pyjwt.InvalidIssuerError as exc:
        raise MCPAuthError("Invalid token issuer.") from exc
    except pyjwt.InvalidSignatureError as exc:
        raise MCPAuthError("Invalid MCP token.") from exc
    except pyjwt.MissingRequiredClaimError as exc:
        raise MCPAuthError(f"JWT missing required claim: {exc.claim}") from exc
    except pyjwt.PyJWTError as exc:
        raise MCPAuthError(f"Invalid JWT: {exc}") from exc

    # Replay check runs only after the signature has been verified.
    jti = claims.get("jti")
    if not isinstance(jti, str) or not jti:
        raise MCPAuthError("JWT jti must be a non-empty string.")
    if _remember_jti(jti):
        raise MCPAuthError("Token replay detected.")

    # Normalize claim shapes: ws may be null/absent, libs default to [].
    claims["ws"] = claims.get("ws") or None
    libs = claims.get("libs") or []
    if not isinstance(libs, list) or not all(isinstance(x, str) for x in libs):
        raise MCPAuthError("JWT libs must be a list of strings.")
    claims["libs"] = libs
    # Overwrites any ``kid`` claim in the payload with the header's key id.
    claims["kid"] = kid
    return claims
|
||||||
|
|
||||||
|
|
||||||
|
# --- Middleware ------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class MCPAuthMiddleware(Middleware):
|
class MCPAuthMiddleware(Middleware):
|
||||||
"""
|
"""
|
||||||
FastMCP middleware that authenticates tool calls via Bearer tokens.
|
FastMCP middleware that authenticates tool calls via Bearer tokens.
|
||||||
@@ -59,20 +200,27 @@ class MCPAuthMiddleware(Middleware):
|
|||||||
MCP_REQUIRE_AUTH=False.
|
MCP_REQUIRE_AUTH=False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def on_call_tool(
|
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||||
self, context: MiddlewareContext, call_next
|
|
||||||
):
|
|
||||||
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
|
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
|
||||||
token_string = self._extract_token()
|
token_string = self._extract_token()
|
||||||
|
|
||||||
user = None
|
user = None
|
||||||
token = None
|
token = None
|
||||||
|
claims: dict | None = None
|
||||||
|
|
||||||
if token_string:
|
if token_string:
|
||||||
try:
|
try:
|
||||||
user, token = await sync_to_async(
|
if looks_like_jwt(token_string):
|
||||||
resolve_mcp_user, thread_sensitive=True
|
claims = await sync_to_async(
|
||||||
)(token_string)
|
resolve_mcp_jwt, thread_sensitive=True
|
||||||
|
)(token_string)
|
||||||
|
user = await sync_to_async(
|
||||||
|
_resolve_jwt_actor, thread_sensitive=True
|
||||||
|
)(claims)
|
||||||
|
else:
|
||||||
|
user, token = await sync_to_async(
|
||||||
|
resolve_mcp_user, thread_sensitive=True
|
||||||
|
)(token_string)
|
||||||
except MCPAuthError as exc:
|
except MCPAuthError as exc:
|
||||||
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
||||||
if require_auth:
|
if require_auth:
|
||||||
@@ -89,9 +237,13 @@ class MCPAuthMiddleware(Middleware):
|
|||||||
)
|
)
|
||||||
|
|
||||||
fastmcp_ctx = getattr(context, "fastmcp_context", None)
|
fastmcp_ctx = getattr(context, "fastmcp_context", None)
|
||||||
if fastmcp_ctx and user is not None:
|
if fastmcp_ctx is not None:
|
||||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
if user is not None:
|
||||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||||
|
if token is not None:
|
||||||
|
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||||
|
if claims is not None:
|
||||||
|
await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
|
||||||
|
|
||||||
return await call_next(context)
|
return await call_next(context)
|
||||||
|
|
||||||
@@ -111,3 +263,28 @@ class MCPAuthMiddleware(Middleware):
|
|||||||
msg = getattr(context, "message", None)
|
msg = getattr(context, "message", None)
|
||||||
params = getattr(msg, "params", None) if msg else None
|
params = getattr(msg, "params", None) if msg else None
|
||||||
return getattr(params, "name", None)
|
return getattr(params, "name", None)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helpers ---------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_jwt_actor(claims: dict):
    """Resolve the synthetic actor for a JWT-authenticated turn.

    Looks up the service user named by ``settings.MCP_JWT_SERVICE_USERNAME``
    (default ``daedalus-service``). JWT tokens are not tied to per-user
    accounts — claims encode all authorization.

    Args:
        claims: The validated JWT claims. Currently not consulted; the same
            service user is returned for every token.

    Returns:
        The service user instance.

    Raises:
        MCPAuthError: If the service user does not exist (chained from the
            ``DoesNotExist`` lookup error) or is not active.
    """
    from django.contrib.auth import get_user_model

    User = get_user_model()
    username = getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service")
    try:
        user = User.objects.get(username=username)
    except User.DoesNotExist as exc:
        raise MCPAuthError(
            f"JWT service user {username!r} does not exist; provision via management command."
        ) from exc
    if not user.is_active:
        raise MCPAuthError(f"JWT service user {username!r} is disabled.")
    return user
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from fastmcp.server.context import Context
|
from fastmcp.server.context import Context
|
||||||
|
|
||||||
from .auth import STATE_KEY_TOKEN, STATE_KEY_USER
|
from .auth import STATE_KEY_CLAIMS, STATE_KEY_TOKEN, STATE_KEY_USER
|
||||||
|
|
||||||
|
|
||||||
async def get_mcp_user(ctx: Context | None):
|
async def get_mcp_user(ctx: Context | None):
|
||||||
@@ -17,3 +17,10 @@ async def get_mcp_token(ctx: Context | None):
|
|||||||
if ctx is None:
|
if ctx is None:
|
||||||
return None
|
return None
|
||||||
return await ctx.get_state(STATE_KEY_TOKEN)
|
return await ctx.get_state(STATE_KEY_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_mcp_claims(ctx: Context | None) -> dict | None:
    """Fetch the JWT claims stored on the FastMCP context for this request.

    Returns None when no context is supplied; otherwise returns whatever is
    stored under ``STATE_KEY_CLAIMS``. The auth middleware only stores claims
    for JWT-authenticated calls, not for opaque-token callers.
    """
    return None if ctx is None else await ctx.get_state(STATE_KEY_CLAIMS)
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""Idempotently create the service user that JWT-authenticated MCP requests act as.
|
||||||
|
|
||||||
|
Daedalus mints per-turn JWTs whose claims encode all authorization (workspace,
|
||||||
|
allowed libraries). The Django ``user`` field on the request still needs to
|
||||||
|
point at *something* — the service user is that something. It owns no data
|
||||||
|
and does not log in via the dashboard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
    """Create the JWT service user, or reactivate/update it if it already exists."""

    help = "Idempotently create or reactivate the JWT service user (default 'daedalus-service')."

    def add_arguments(self, parser):
        """Register ``--username`` and ``--email``."""
        # Imported here so the default tracks the same setting the MCP auth
        # layer reads when it resolves the JWT actor; a deployment that
        # overrides MCP_JWT_SERVICE_USERNAME then provisions the right user
        # without having to repeat the name on the command line.
        from django.conf import settings

        parser.add_argument(
            "--username",
            default=getattr(settings, "MCP_JWT_SERVICE_USERNAME", "daedalus-service"),
        )
        parser.add_argument("--email", default="daedalus-service@local")

    def handle(self, *args, **options):
        """Create the user if missing; otherwise re-activate it and sync its email."""
        User = get_user_model()
        username = options["username"]
        email = options["email"]

        user, created = User.objects.get_or_create(
            username=username,
            defaults={"email": email, "is_active": True},
        )
        if created:
            # Set a random password the user cannot log in with via the UI.
            user.set_password(secrets.token_urlsafe(32))
            user.save(update_fields=["password"])
            self.stdout.write(self.style.SUCCESS(f"Created service user {username!r}"))
            return

        # Existing user: only write to the database when something differs.
        changed = False
        if not user.is_active:
            user.is_active = True
            changed = True
        if user.email != email:
            user.email = email
            changed = True

        if changed:
            user.save(update_fields=["is_active", "email"])
            self.stdout.write(self.style.SUCCESS(f"Updated service user {username!r}"))
        else:
            self.stdout.write(f"Service user {username!r} already provisioned")
|
||||||
71
mnemosyne/mcp_server/management/commands/seed_signing_key.py
Normal file
71
mnemosyne/mcp_server/management/commands/seed_signing_key.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Mint a new MCPSigningKey for per-turn JWTs.
|
||||||
|
|
||||||
|
The secret is generated locally, printed once, and stored hex-encoded.
|
||||||
|
Distribute the printed secret to the issuer (Daedalus) — after this
|
||||||
|
command exits, the database row is the only copy on the Mnemosyne
|
||||||
|
side, but copying the secret out of the row is allowed (admin model
|
||||||
|
exposes it). Treat it like any other shared secret.
|
||||||
|
|
||||||
|
Use ``--retire-other`` to flip every other key to ``is_active=False``
|
||||||
|
in the same transaction (typical rotation flow). Without it, the new
|
||||||
|
key joins the active set and old keys keep validating until they are
|
||||||
|
retired explicitly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from django.core.management.base import BaseCommand, CommandError
|
||||||
|
from django.db import transaction
|
||||||
|
|
||||||
|
from mcp_server.models import MCPSigningKey
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
    """Create a new ``MCPSigningKey`` row and print its secret once."""

    help = "Generate a new HMAC signing key for per-turn JWTs and print it once."

    def add_arguments(self, parser):
        """Register ``--kid`` (required), ``--note`` and ``--retire-other``."""
        parser.add_argument(
            "--kid",
            required=True,
            help="Key ID (short, stable string — e.g. 'daedalus-1').",
        )
        parser.add_argument(
            "--note",
            default="",
            help="Optional note (e.g. 'rotated 2026-05-03 after staging incident').",
        )
        parser.add_argument(
            "--retire-other",
            action="store_true",
            help=(
                "Mark every other MCPSigningKey row inactive in the same "
                "transaction. Use during rotation."
            ),
        )

    @transaction.atomic
    def handle(self, *args, **options):
        """Validate the kid, optionally retire active keys, then create and print the key.

        Raises:
            CommandError: If the kid is empty, longer than the ``kid`` column
                allows, or already in use.
        """
        kid = options["kid"]
        # An empty kid would create a key that JWT validation can never
        # select (a token without a kid is rejected), and an over-long one
        # does not fit the column — fail early with a clear message instead.
        if not kid:
            raise CommandError("kid must be a non-empty string.")
        max_kid_length = MCPSigningKey._meta.get_field("kid").max_length
        if len(kid) > max_kid_length:
            raise CommandError(f"kid must be at most {max_kid_length} characters.")
        if MCPSigningKey.objects.filter(kid=kid).exists():
            raise CommandError(f"kid {kid!r} already exists; choose a unique value.")

        # 256-bit secret, hex-encoded for portability.
        secret = secrets.token_hex(32)

        if options["retire_other"]:
            from django.utils import timezone

            # Runs before the create below, so the new key is not affected.
            MCPSigningKey.objects.filter(is_active=True).update(
                is_active=False, retired_at=timezone.now()
            )

        key = MCPSigningKey.objects.create(
            kid=kid, secret_hex=secret, is_active=True, note=options["note"]
        )

        self.stdout.write(self.style.SUCCESS("MCP signing key created"))
        self.stdout.write(f"  kid:    {key.kid}")
        self.stdout.write(f"  active: {key.is_active}")
        if options["retire_other"]:
            self.stdout.write("  (other active keys retired)")
        self.stdout.write(self.style.WARNING("  Secret (256-bit hex — share with Daedalus):"))
        self.stdout.write(f"  {secret}")
|
||||||
38
mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py
Normal file
38
mnemosyne/mcp_server/migrations/0003_mcpsigningkey.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""HMAC signing keys for per-turn JWTs minted by Daedalus.
|
||||||
|
|
||||||
|
Adds the MCPSigningKey table. Per-turn tokens (workspace + library claims,
|
||||||
|
exp <= 600s) are not stored — only the signing key, indexed by ``kid``,
|
||||||
|
so the signature can be validated and rotated cleanly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Create the ``MCPSigningKey`` table."""

    dependencies = [("mcp_server", "0002_hash_token")]

    operations = [
        migrations.CreateModel(
            name="MCPSigningKey",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
                    ),
                ),
                # Key id referenced by the JWT header; unique per key.
                ("kid", models.CharField(max_length=64, unique=True, db_index=True)),
                # Hex-encoded HMAC secret.
                ("secret_hex", models.CharField(max_length=128)),
                ("is_active", models.BooleanField(default=True)),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("retired_at", models.DateTimeField(blank=True, null=True)),
                ("note", models.TextField(blank=True)),
            ],
            options={"ordering": ["-created_at"]},
        ),
    ]
|
||||||
@@ -84,3 +84,46 @@ class MCPToken(models.Model):
|
|||||||
hash prefixed with `mcp_…`. Stable per token, never reveals plaintext.
|
hash prefixed with `mcp_…`. Stable per token, never reveals plaintext.
|
||||||
"""
|
"""
|
||||||
return f"mcp_…{self.token_hash[:8]}"
|
return f"mcp_…{self.token_hash[:8]}"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSigningKeyManager(models.Manager):
    """Query helpers for signing-key rows."""

    def active(self):
        """Return keys with ``is_active=True``, newest first.

        More than one key may be active at a time while a rotation is in
        progress.
        """
        active_keys = self.filter(is_active=True)
        return active_keys.order_by("-created_at")

    def by_kid(self, kid: str):
        """Return the first key whose ``kid`` matches, or None when there is none."""
        matches = self.filter(kid=kid)
        return matches.first()
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSigningKey(models.Model):
    """HMAC signing key for per-turn JWTs minted by Daedalus.

    Per-turn tokens carry workspace + library claims and are validated from
    their signature and claims alone; no row is stored per token. Only the
    *signing key* is persisted here, looked up by ``kid``.

    Rotation: seed a new active key, distribute its secret to Daedalus, then
    mark the old key ``is_active=False`` (see ``retire``). The JWT validation
    path (``resolve_mcp_jwt``) rejects tokens whose ``kid`` maps to an
    inactive key.
    """

    kid = models.CharField(max_length=64, unique=True, db_index=True)
    secret_hex = models.CharField(max_length=128)  # 256-bit secret = 64 hex
    is_active = models.BooleanField(default=True)
    created_at = models.DateTimeField(auto_now_add=True)
    retired_at = models.DateTimeField(null=True, blank=True)
    note = models.TextField(blank=True)

    objects = MCPSigningKeyManager()

    class Meta:
        ordering = ["-created_at"]

    def __str__(self):
        """Return ``"<kid> (active)"`` or ``"<kid> (retired)"``."""
        state = "retired" if not self.is_active else "active"
        return f"{self.kid} ({state})"

    def retire(self):
        """Mark the key inactive, stamp ``retired_at`` with the current time, and save both fields."""
        self.retired_at = timezone.now()
        self.is_active = False
        self.save(update_fields=["is_active", "retired_at"])
|
||||||
|
|||||||
@@ -5,13 +5,29 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
from fastmcp.server.context import Context
|
||||||
|
|
||||||
|
from ..context import get_mcp_claims
|
||||||
from ..metrics import record_tool_call
|
from ..metrics import record_tool_call
|
||||||
|
|
||||||
DEFAULT_LIMIT = 50
|
DEFAULT_LIMIT = 50
|
||||||
MAX_LIMIT = 200
|
MAX_LIMIT = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _scope_from_claims(claims: dict | None,
|
||||||
|
arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
|
||||||
|
"""Return (workspace_id, allowed_libraries) for a tool call.
|
||||||
|
|
||||||
|
Token claims, when present, trump tool args — that's the security
|
||||||
|
contract. Opaque-token callers (no claims) keep the legacy behavior
|
||||||
|
where the caller may pass workspace_id explicitly (typically null,
|
||||||
|
yielding global scope).
|
||||||
|
"""
|
||||||
|
if claims is not None:
|
||||||
|
return claims.get("ws"), claims.get("libs") or None
|
||||||
|
return arg_workspace_id, None
|
||||||
|
|
||||||
|
|
||||||
def _clamp(limit: int) -> int:
|
def _clamp(limit: int) -> int:
|
||||||
if limit < 1:
|
if limit < 1:
|
||||||
return 1
|
return 1
|
||||||
@@ -25,6 +41,7 @@ def register_discovery_tools(mcp):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
# System-injected; deliberately absent from the docstring.
|
# System-injected; deliberately absent from the docstring.
|
||||||
workspace_id: str | None = None,
|
workspace_id: str | None = None,
|
||||||
|
ctx: Context | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""List Mnemosyne libraries. Each library has a content-aware library_type
|
"""List Mnemosyne libraries. Each library has a content-aware library_type
|
||||||
(fiction, nonfiction, technical, music, film, art, journal, business,
|
(fiction, nonfiction, technical, music, film, art, journal, business,
|
||||||
@@ -32,9 +49,11 @@ def register_discovery_tools(mcp):
|
|||||||
name, library_type, description for each library — use the uid or
|
name, library_type, description for each library — use the uid or
|
||||||
library_type to scope a subsequent search.
|
library_type to scope a subsequent search.
|
||||||
"""
|
"""
|
||||||
|
claims = await get_mcp_claims(ctx)
|
||||||
|
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||||
with record_tool_call("list_libraries"):
|
with record_tool_call("list_libraries"):
|
||||||
return await sync_to_async(_query_libraries, thread_sensitive=True)(
|
return await sync_to_async(_query_libraries, thread_sensitive=True)(
|
||||||
_clamp(limit), max(offset, 0), workspace_id
|
_clamp(limit), max(offset, 0), ws, libs
|
||||||
)
|
)
|
||||||
|
|
||||||
@mcp.tool
|
@mcp.tool
|
||||||
@@ -44,15 +63,18 @@ def register_discovery_tools(mcp):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
# System-injected; deliberately absent from the docstring.
|
# System-injected; deliberately absent from the docstring.
|
||||||
workspace_id: str | None = None,
|
workspace_id: str | None = None,
|
||||||
|
ctx: Context | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""List collections, optionally filtered by parent library_uid.
|
"""List collections, optionally filtered by parent library_uid.
|
||||||
Collections group related items inside a library (e.g. a series of novels,
|
Collections group related items inside a library (e.g. a series of novels,
|
||||||
a multi-volume manual). Returns uid, name, description, library_uid,
|
a multi-volume manual). Returns uid, name, description, library_uid,
|
||||||
library_name. Use the uid to scope a subsequent search to one collection.
|
library_name. Use the uid to scope a subsequent search to one collection.
|
||||||
"""
|
"""
|
||||||
|
claims = await get_mcp_claims(ctx)
|
||||||
|
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||||
with record_tool_call("list_collections"):
|
with record_tool_call("list_collections"):
|
||||||
return await sync_to_async(_query_collections, thread_sensitive=True)(
|
return await sync_to_async(_query_collections, thread_sensitive=True)(
|
||||||
library_uid, _clamp(limit), max(offset, 0), workspace_id
|
library_uid, _clamp(limit), max(offset, 0), ws, libs
|
||||||
)
|
)
|
||||||
|
|
||||||
@mcp.tool
|
@mcp.tool
|
||||||
@@ -63,6 +85,7 @@ def register_discovery_tools(mcp):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
# System-injected; deliberately absent from the docstring.
|
# System-injected; deliberately absent from the docstring.
|
||||||
workspace_id: str | None = None,
|
workspace_id: str | None = None,
|
||||||
|
ctx: Context | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""List items (the indexed documents/files), optionally filtered by
|
"""List items (the indexed documents/files), optionally filtered by
|
||||||
collection_uid or library_uid. Returns uid, title, item_type, file_type,
|
collection_uid or library_uid. Returns uid, title, item_type, file_type,
|
||||||
@@ -70,20 +93,27 @@ def register_discovery_tools(mcp):
|
|||||||
document size; use embedding_status to skip items that are not yet
|
document size; use embedding_status to skip items that are not yet
|
||||||
searchable (only 'completed' items appear in search results).
|
searchable (only 'completed' items appear in search results).
|
||||||
"""
|
"""
|
||||||
|
claims = await get_mcp_claims(ctx)
|
||||||
|
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||||
with record_tool_call("list_items"):
|
with record_tool_call("list_items"):
|
||||||
return await sync_to_async(_query_items, thread_sensitive=True)(
|
return await sync_to_async(_query_items, thread_sensitive=True)(
|
||||||
collection_uid, library_uid, _clamp(limit), max(offset, 0), workspace_id
|
collection_uid, library_uid, _clamp(limit), max(offset, 0), ws, libs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Cypher predicate over a Library node bound as ``l``. It is the OR of three
# clauses, driven by the $workspace_id and $allowed_libraries query parameters:
#   1. a workspace id is given and the library belongs to that workspace;
#   2. an allow-list is given and the library's uid is in it;
#   3. neither is given and the library has no workspace.
_WORKSPACE_SCOPE = "(" + " OR ".join(
    (
        "($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id)",
        "($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries)",
        "($workspace_id IS NULL AND $allowed_libraries IS NULL "
        " AND l.workspace_id IS NULL)",
    )
) + ")"
|
||||||
|
|
||||||
|
|
||||||
def _query_libraries(
|
def _query_libraries(
|
||||||
limit: int, offset: int, workspace_id: str | None = None
|
limit: int,
|
||||||
|
offset: int,
|
||||||
|
workspace_id: str | None = None,
|
||||||
|
allowed_libraries: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
from neomodel import db
|
from neomodel import db
|
||||||
|
|
||||||
@@ -92,7 +122,11 @@ def _query_libraries(
|
|||||||
f"WHERE {_WORKSPACE_SCOPE} "
|
f"WHERE {_WORKSPACE_SCOPE} "
|
||||||
"RETURN l.uid, l.name, l.library_type, l.description "
|
"RETURN l.uid, l.name, l.library_type, l.description "
|
||||||
"ORDER BY l.name SKIP $offset LIMIT $limit",
|
"ORDER BY l.name SKIP $offset LIMIT $limit",
|
||||||
{"offset": offset, "limit": limit, "workspace_id": workspace_id},
|
{
|
||||||
|
"offset": offset, "limit": limit,
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"allowed_libraries": allowed_libraries,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"libraries": [
|
"libraries": [
|
||||||
@@ -110,11 +144,19 @@ def _query_libraries(
|
|||||||
|
|
||||||
|
|
||||||
def _query_collections(
|
def _query_collections(
|
||||||
library_uid: str | None, limit: int, offset: int,
|
library_uid: str | None,
|
||||||
|
limit: int,
|
||||||
|
offset: int,
|
||||||
workspace_id: str | None = None,
|
workspace_id: str | None = None,
|
||||||
|
allowed_libraries: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
from neomodel import db
|
from neomodel import db
|
||||||
|
|
||||||
|
base_params = {
|
||||||
|
"offset": offset, "limit": limit,
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"allowed_libraries": allowed_libraries,
|
||||||
|
}
|
||||||
if library_uid:
|
if library_uid:
|
||||||
cypher = (
|
cypher = (
|
||||||
"MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) "
|
"MATCH (l:Library {uid: $library_uid})-[:CONTAINS]->(c:Collection) "
|
||||||
@@ -122,10 +164,7 @@ def _query_collections(
|
|||||||
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
||||||
"ORDER BY c.name SKIP $offset LIMIT $limit"
|
"ORDER BY c.name SKIP $offset LIMIT $limit"
|
||||||
)
|
)
|
||||||
params = {
|
params = {**base_params, "library_uid": library_uid}
|
||||||
"library_uid": library_uid, "offset": offset, "limit": limit,
|
|
||||||
"workspace_id": workspace_id,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
cypher = (
|
cypher = (
|
||||||
"MATCH (l:Library)-[:CONTAINS]->(c:Collection) "
|
"MATCH (l:Library)-[:CONTAINS]->(c:Collection) "
|
||||||
@@ -133,10 +172,7 @@ def _query_collections(
|
|||||||
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
"RETURN c.uid, c.name, c.description, l.uid, l.name "
|
||||||
"ORDER BY l.name, c.name SKIP $offset LIMIT $limit"
|
"ORDER BY l.name, c.name SKIP $offset LIMIT $limit"
|
||||||
)
|
)
|
||||||
params = {
|
params = base_params
|
||||||
"offset": offset, "limit": limit,
|
|
||||||
"workspace_id": workspace_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, _ = db.cypher_query(cypher, params)
|
rows, _ = db.cypher_query(cypher, params)
|
||||||
return {
|
return {
|
||||||
@@ -161,6 +197,7 @@ def _query_items(
|
|||||||
limit: int,
|
limit: int,
|
||||||
offset: int,
|
offset: int,
|
||||||
workspace_id: str | None = None,
|
workspace_id: str | None = None,
|
||||||
|
allowed_libraries: list[str] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
from neomodel import db
|
from neomodel import db
|
||||||
|
|
||||||
@@ -168,6 +205,7 @@ def _query_items(
|
|||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
"offset": offset, "limit": limit,
|
"offset": offset, "limit": limit,
|
||||||
"workspace_id": workspace_id,
|
"workspace_id": workspace_id,
|
||||||
|
"allowed_libraries": allowed_libraries,
|
||||||
}
|
}
|
||||||
if collection_uid:
|
if collection_uid:
|
||||||
where.append("c.uid = $collection_uid")
|
where.append("c.uid = $collection_uid")
|
||||||
|
|||||||
@@ -10,9 +10,18 @@ from django.conf import settings
|
|||||||
from django.core.files.storage import default_storage
|
from django.core.files.storage import default_storage
|
||||||
from fastmcp.server.context import Context
|
from fastmcp.server.context import Context
|
||||||
|
|
||||||
from ..context import get_mcp_user
|
from ..context import get_mcp_claims, get_mcp_user
|
||||||
from ..metrics import record_tool_call
|
from ..metrics import record_tool_call
|
||||||
|
|
||||||
|
|
||||||
|
def _scope_from_claims(claims: dict | None,
                       arg_workspace_id: str | None) -> tuple[str | None, list[str] | None]:
    """Pick the (workspace_id, allowed_libraries) scope for a tool call.

    Token claims, when present, take precedence over the tool arguments:
    the pair is read from the ``ws`` and ``libs`` claims, with a missing or
    empty ``libs`` reported as ``None``. Without claims, the supplied
    workspace id is returned and there is no library allow-list.
    """
    if claims is None:
        # No claims: fall back to the caller-supplied workspace id.
        return arg_workspace_id, None

    ws_claim = claims.get("ws")
    libs_claim = claims.get("libs")
    # Any falsy allow-list (absent, None, empty) is normalised to None.
    return ws_claim, (libs_claim if libs_claim else None)
|
||||||
|
|
||||||
DEFAULT_SEARCH_TYPES = ["vector", "fulltext", "graph"]
|
DEFAULT_SEARCH_TYPES = ["vector", "fulltext", "graph"]
|
||||||
|
|
||||||
|
|
||||||
@@ -47,6 +56,8 @@ def register_search_tools(mcp):
|
|||||||
score, and source. Also returns matching images when include_images=True.
|
score, and source. Also returns matching images when include_images=True.
|
||||||
"""
|
"""
|
||||||
types = search_types or DEFAULT_SEARCH_TYPES
|
types = search_types or DEFAULT_SEARCH_TYPES
|
||||||
|
claims = await get_mcp_claims(ctx)
|
||||||
|
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||||
with record_tool_call("search"):
|
with record_tool_call("search"):
|
||||||
user = await get_mcp_user(ctx)
|
user = await get_mcp_user(ctx)
|
||||||
return await sync_to_async(_run_search, thread_sensitive=True)(
|
return await sync_to_async(_run_search, thread_sensitive=True)(
|
||||||
@@ -55,7 +66,8 @@ def register_search_tools(mcp):
|
|||||||
library_uid=library_uid,
|
library_uid=library_uid,
|
||||||
library_type=library_type,
|
library_type=library_type,
|
||||||
collection_uid=collection_uid,
|
collection_uid=collection_uid,
|
||||||
workspace_id=workspace_id,
|
workspace_id=ws,
|
||||||
|
allowed_libraries=libs,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
rerank=rerank,
|
rerank=rerank,
|
||||||
include_images=include_images,
|
include_images=include_images,
|
||||||
@@ -75,14 +87,17 @@ def register_search_tools(mcp):
|
|||||||
item_uid, item_title, library_type, text. Use this when the 500-character
|
item_uid, item_title, library_type, text. Use this when the 500-character
|
||||||
text_preview from `search` isn't enough.
|
text_preview from `search` isn't enough.
|
||||||
"""
|
"""
|
||||||
|
claims = await get_mcp_claims(ctx)
|
||||||
|
ws, libs = _scope_from_claims(claims, workspace_id)
|
||||||
with record_tool_call("get_chunk"):
|
with record_tool_call("get_chunk"):
|
||||||
return await sync_to_async(_load_chunk, thread_sensitive=True)(
|
return await sync_to_async(_load_chunk, thread_sensitive=True)(
|
||||||
chunk_uid, workspace_id
|
chunk_uid, ws, libs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
||||||
workspace_id, limit, rerank, include_images, search_types) -> dict[str, Any]:
|
workspace_id, allowed_libraries, limit, rerank, include_images,
|
||||||
|
search_types) -> dict[str, Any]:
|
||||||
from library.services.search import SearchRequest, SearchService
|
from library.services.search import SearchRequest, SearchService
|
||||||
|
|
||||||
req = SearchRequest(
|
req = SearchRequest(
|
||||||
@@ -91,6 +106,7 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
|||||||
library_type=library_type,
|
library_type=library_type,
|
||||||
collection_uid=collection_uid,
|
collection_uid=collection_uid,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
|
allowed_libraries=allowed_libraries,
|
||||||
search_types=search_types,
|
search_types=search_types,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50),
|
vector_top_k=getattr(settings, "SEARCH_VECTOR_TOP_K", 50),
|
||||||
@@ -112,17 +128,27 @@ def _run_search(*, user, query, library_uid, library_type, collection_uid,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _load_chunk(chunk_uid: str, workspace_id: str | None = None) -> dict[str, Any]:
|
def _load_chunk(
|
||||||
|
chunk_uid: str,
|
||||||
|
workspace_id: str | None = None,
|
||||||
|
allowed_libraries: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
from neomodel import db
|
from neomodel import db
|
||||||
|
|
||||||
rows, _ = db.cypher_query(
|
rows, _ = db.cypher_query(
|
||||||
"MATCH (l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->"
|
"MATCH (l:Library)-[:CONTAINS]->(:Collection)-[:CONTAINS]->"
|
||||||
"(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) "
|
"(i:Item)-[:HAS_CHUNK]->(c:Chunk {uid: $uid}) "
|
||||||
"WHERE ($workspace_id IS NULL AND l.workspace_id IS NULL OR "
|
"WHERE (($workspace_id IS NOT NULL AND l.workspace_id = $workspace_id) "
|
||||||
" l.workspace_id = $workspace_id) "
|
" OR ($allowed_libraries IS NOT NULL AND l.uid IN $allowed_libraries) "
|
||||||
|
" OR ($workspace_id IS NULL AND $allowed_libraries IS NULL "
|
||||||
|
" AND l.workspace_id IS NULL)) "
|
||||||
"RETURN c.uid, c.chunk_index, c.chunk_s3_key, "
|
"RETURN c.uid, c.chunk_index, c.chunk_s3_key, "
|
||||||
"i.uid, i.title, l.library_type LIMIT 1",
|
"i.uid, i.title, l.library_type LIMIT 1",
|
||||||
{"uid": chunk_uid, "workspace_id": workspace_id},
|
{
|
||||||
|
"uid": chunk_uid,
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"allowed_libraries": allowed_libraries,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if not rows:
|
if not rows:
|
||||||
raise ValueError(f"Chunk not found: {chunk_uid}")
|
raise ValueError(f"Chunk not found: {chunk_uid}")
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ dependencies = [
|
|||||||
# Phase 5: MCP Server
|
# Phase 5: MCP Server
|
||||||
"fastmcp>=2.0,<3.0",
|
"fastmcp>=2.0,<3.0",
|
||||||
"uvicorn[standard]>=0.30,<1.0",
|
"uvicorn[standard]>=0.30,<1.0",
|
||||||
|
# Phase 6: Per-turn signed JWTs from Daedalus
|
||||||
|
"PyJWT>=2.8,<3.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
Reference in New Issue
Block a user