feat: replace server-side RAG with MCP retrieval primitives
- Remove Phase 4 RAG pipeline in favor of retrieval-only architecture - Add FastMCP server exposing search, get_chunk, list_libraries tools - Mount MCP endpoints (streamable HTTP + SSE) via Starlette in ASGI config - Update README to clarify Mnemosyne is a retrieval engine, not RAG - Let calling LLMs drive synthesis and iterative retrieval themselves
This commit is contained in:
105
mnemosyne/mcp_server/auth.py
Normal file
105
mnemosyne/mcp_server/auth.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from fastmcp.server.dependencies import get_http_request
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
|
||||
from .metrics import mcp_auth_failures_total
|
||||
from .models import MCPToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STATE_KEY_USER = "mcp_user"
|
||||
STATE_KEY_TOKEN = "mcp_token"
|
||||
|
||||
|
||||
class MCPAuthError(Exception):
|
||||
"""Raised when a bearer token cannot be resolved to a valid user."""
|
||||
|
||||
|
||||
def resolve_mcp_user(token_string: str):
|
||||
"""Resolve a bearer token to (user, MCPToken). Raises MCPAuthError on any failure."""
|
||||
try:
|
||||
token = MCPToken.objects.select_related("user").get(token=token_string)
|
||||
except MCPToken.DoesNotExist:
|
||||
raise MCPAuthError("Invalid MCP token.")
|
||||
|
||||
if not token.is_active:
|
||||
raise MCPAuthError("Token has been deactivated.")
|
||||
if token.expires_at and token.expires_at < timezone.now():
|
||||
raise MCPAuthError("Token has expired.")
|
||||
if not token.user.is_active:
|
||||
raise MCPAuthError("User account is disabled.")
|
||||
|
||||
token.record_usage()
|
||||
return token.user, token
|
||||
|
||||
|
||||
class MCPAuthMiddleware(Middleware):
|
||||
"""
|
||||
FastMCP middleware that authenticates tool calls via Bearer tokens.
|
||||
|
||||
Listing tools/resources is permitted unauthenticated so clients can
|
||||
discover the surface; calling a tool requires a valid token unless
|
||||
MCP_REQUIRE_AUTH=False.
|
||||
"""
|
||||
|
||||
async def on_call_tool(
|
||||
self, context: MiddlewareContext, call_next
|
||||
):
|
||||
require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
|
||||
token_string = self._extract_token()
|
||||
|
||||
user = None
|
||||
token = None
|
||||
|
||||
if token_string:
|
||||
try:
|
||||
user, token = await sync_to_async(
|
||||
resolve_mcp_user, thread_sensitive=True
|
||||
)(token_string)
|
||||
except MCPAuthError as exc:
|
||||
mcp_auth_failures_total.labels(reason=str(exc)).inc()
|
||||
if require_auth:
|
||||
raise PermissionError(str(exc))
|
||||
elif require_auth:
|
||||
mcp_auth_failures_total.labels(reason="missing_token").inc()
|
||||
raise PermissionError("Authentication required. Provide a Bearer token.")
|
||||
|
||||
tool_name = self._extract_tool_name(context)
|
||||
if token and tool_name and not token.can_use_tool(tool_name):
|
||||
mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
|
||||
raise PermissionError(
|
||||
f"Token does not have permission to call '{tool_name}'."
|
||||
)
|
||||
|
||||
fastmcp_ctx = getattr(context, "fastmcp_context", None)
|
||||
if fastmcp_ctx and user is not None:
|
||||
await fastmcp_ctx.set_state(STATE_KEY_USER, user)
|
||||
await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
|
||||
|
||||
return await call_next(context)
|
||||
|
||||
@staticmethod
|
||||
def _extract_token() -> str | None:
|
||||
try:
|
||||
request = get_http_request()
|
||||
except RuntimeError:
|
||||
return None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:].strip() or None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_name(context: MiddlewareContext) -> str | None:
|
||||
msg = getattr(context, "message", None)
|
||||
params = getattr(msg, "params", None) if msg else None
|
||||
return getattr(params, "name", None)
|
||||
Reference in New Issue
Block a user