Replace plaintext token storage with SHA-256 hashes so that leaked database contents cannot be used to authenticate. The plaintext token is generated, shown once at creation time, and never persisted.

- Add a `hash_token()` helper and `MCPTokenManager.create_token()`, which returns `(instance, plaintext)`.
- Replace the `token` field with an indexed `token_hash`; look up bearer tokens by hashing the incoming value.
- Update the dashboard, management command, and admin to surface the plaintext only at creation time. Disable the admin "add" view, since it cannot reveal the plaintext.
- The migration drops the old `token` column and adds `token_hash`; pre-existing tokens are invalidated and must be reissued.
114 lines
3.7 KiB
Python
"""MCP token resolution and FastMCP middleware for bearer-token auth."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from django.conf import settings
|
|
from django.contrib.auth import get_user_model
|
|
from django.utils import timezone
|
|
from fastmcp.server.dependencies import get_http_request
|
|
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
|
|
from .metrics import mcp_auth_failures_total
|
|
from .models import MCPToken, hash_token
|
|
|
|
# Keys under which MCPAuthMiddleware stores the authenticated user and the
# matching MCPToken on the FastMCP context via ``set_state``; presumably read
# back by tool implementations (verify against callers).
STATE_KEY_USER = "mcp_user"
STATE_KEY_TOKEN = "mcp_token"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MCPAuthError(Exception):
    """Signal that a bearer token could not be turned into a valid user.

    The exception message carries the specific reason, e.g. an unknown
    token, a deactivated or expired token, or a disabled user account.
    """
|
|
|
|
|
|
def resolve_mcp_user(token_string: str):
    """Resolve a bearer token to ``(user, MCPToken)``.

    The incoming bearer is hashed with ``hash_token`` and looked up by
    ``token_hash``, so the plaintext is never stored or compared directly.

    Args:
        token_string: The raw bearer credential supplied by the client.

    Returns:
        A ``(user, token)`` tuple. ``user`` is fetched together with the token
        via ``select_related("user")``.

    Raises:
        MCPAuthError: If no token matches the hash, the token is deactivated,
            the token has expired, or the owning user account is inactive.
    """
    try:
        token = (
            MCPToken.objects
            .select_related("user")
            .get(token_hash=hash_token(token_string))
        )
    except MCPToken.DoesNotExist:
        # ``from None``: callers only need to know the bearer was not
        # recognised; the ORM lookup traceback is noise.
        raise MCPAuthError("Invalid MCP token.") from None

    if not token.is_active:
        raise MCPAuthError("Token has been deactivated.")
    # A token is expired from the instant ``expires_at`` is reached (``<=``),
    # not one tick after it.
    if token.expires_at and token.expires_at <= timezone.now():
        raise MCPAuthError("Token has expired.")
    if not token.user.is_active:
        raise MCPAuthError("User account is disabled.")

    # Only reached once every check above has passed.
    token.record_usage()
    return token.user, token
|
|
|
|
|
|
class MCPAuthMiddleware(Middleware):
    """
    FastMCP middleware that authenticates tool calls via Bearer tokens.

    Listing tools/resources is permitted unauthenticated so clients can
    discover the surface; calling a tool requires a valid token unless
    MCP_REQUIRE_AUTH=False.

    When a token is presented and resolves successfully, its tool allow-list
    (``MCPToken.can_use_tool``) is enforced. If the called tool's name cannot
    be determined, the call is denied rather than let through unchecked.
    """

    async def on_call_tool(
        self, context: MiddlewareContext, call_next
    ):
        """Authenticate the caller, enforce tool permissions, then delegate.

        On success the resolved user and token are stored on the FastMCP
        context under ``STATE_KEY_USER`` / ``STATE_KEY_TOKEN`` before calling
        ``call_next``.

        Raises:
            PermissionError: If auth is required and the token is missing or
                invalid, or if a resolved token may not call this tool (or the
                tool name cannot be determined).
        """
        require_auth = getattr(settings, "MCP_REQUIRE_AUTH", True)
        token_string = self._extract_token()

        user = None
        token = None

        if token_string:
            try:
                # resolve_mcp_user hits the ORM, so run it off the event loop.
                user, token = await sync_to_async(
                    resolve_mcp_user, thread_sensitive=True
                )(token_string)
            except MCPAuthError as exc:
                mcp_auth_failures_total.labels(reason=str(exc)).inc()
                if require_auth:
                    raise PermissionError(str(exc)) from exc
                # Auth is optional: an unusable token degrades to an
                # anonymous call (user/token stay None).
        elif require_auth:
            mcp_auth_failures_total.labels(reason="missing_token").inc()
            raise PermissionError("Authentication required. Provide a Bearer token.")

        if token is not None:
            tool_name = self._extract_tool_name(context)
            # Fail closed: without a tool name the allow-list cannot be
            # checked, so a scoped token must not be allowed through.
            if not tool_name:
                mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
                raise PermissionError(
                    "Could not determine which tool is being called."
                )
            if not token.can_use_tool(tool_name):
                mcp_auth_failures_total.labels(reason="tool_not_allowed").inc()
                raise PermissionError(
                    f"Token does not have permission to call '{tool_name}'."
                )

        fastmcp_ctx = getattr(context, "fastmcp_context", None)
        if fastmcp_ctx and user is not None:
            await fastmcp_ctx.set_state(STATE_KEY_USER, user)
            await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)

        return await call_next(context)

    @staticmethod
    def _extract_token() -> str | None:
        """Return the Bearer credential from the current HTTP request, if any.

        Returns None when there is no active HTTP request (``get_http_request``
        raises RuntimeError), when the ``Authorization`` header is absent or
        uses a different scheme, or when the credential is blank. The scheme
        is compared case-insensitively, as HTTP authentication schemes are
        case-insensitive.
        """
        try:
            request = get_http_request()
        except RuntimeError:
            return None
        auth_header = request.headers.get("Authorization", "").strip()
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _extract_tool_name(context: MiddlewareContext) -> str | None:
        """Return the name of the tool being called, or None if unavailable.

        FastMCP's middleware documentation reads the tool name in
        ``on_call_tool`` as ``context.message.name`` (the message is the call
        parameters). A full request object nests it under ``params.name``
        instead; both shapes are accepted.
        """
        msg = getattr(context, "message", None)
        if msg is None:
            return None
        name = getattr(msg, "name", None)
        if name is None:
            name = getattr(getattr(msg, "params", None), "name", None)
        return name if isinstance(name, str) else None
|