diff --git a/mnemosyne/mcp_server/tests/test_middleware_extract.py b/mnemosyne/mcp_server/tests/test_middleware_extract.py new file mode 100644 index 0000000..2caedfc --- /dev/null +++ b/mnemosyne/mcp_server/tests/test_middleware_extract.py @@ -0,0 +1,250 @@ +"""Unit tests for ``MCPAuthMiddleware._extract_tool_name`` / ``_extract_token``. + +Both of these helpers turned out to be load-bearing during the +Pallas↔Mnemosyne shakedown — see ``pallas/docs/pallas.md`` §"Incidents +& Lessons Learned" for the full story: + +* ``_extract_tool_name`` reads ``context.message.name`` directly. + FastMCP's ``Middleware.on_call_tool`` is typed as + ``MiddlewareContext[CallToolRequestParams]``, so ``message`` *is* the + params object — there is no nested ``.params``. The legacy + ``message.params.name`` access silently returned ``None`` and caused + the ``_PUBLIC_TOOLS`` bypass for ``get_health`` to stop working + (every call was treated as "no name known", which then short-circuited + the per-tool ACL to the fail-open branch in ``can_use_tool``). + +* ``_extract_token`` relies on ``fastmcp.server.dependencies.get_http_request`` + which raises ``RuntimeError`` outside of an active HTTP dispatch + (background tasks, pre-session init hooks). It must degrade to + ``None`` in that case so the caller can decide whether to raise + ``PermissionError``. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +from django.test import SimpleTestCase + +from mcp_server.auth import MCPAuthMiddleware + + +# --------------------------------------------------------------------------- +# _extract_tool_name +# --------------------------------------------------------------------------- + + +class ExtractToolNameTest(SimpleTestCase): + """Exercise every branch of ``MCPAuthMiddleware._extract_tool_name``. + + The helper is a pure staticmethod, so we feed it hand-built + ``MiddlewareContext`` stand-ins rather than spinning up a real + FastMCP server. + """ + + def _ctx(self, message): + """Build a minimal stand-in for ``MiddlewareContext``. + + Only ``.message`` is read by the helper, so a ``SimpleNamespace`` + with that attribute is sufficient — keeping this test decoupled + from FastMCP's evolving class shape. + """ + return SimpleNamespace(message=message) + + def test_reads_name_from_message_directly(self): + """Current (correct) shape: ``message`` IS the CallToolRequestParams.""" + # Shape matches ``mcp.types.CallToolRequestParams`` — verified via + # ``CallToolRequestParams(name='search', arguments={...}).name``. + msg = SimpleNamespace(name="search", arguments={"q": "hello"}) + self.assertEqual( + MCPAuthMiddleware._extract_tool_name(self._ctx(msg)), + "search", + ) + + def test_falls_back_to_nested_params_name(self): + """Legacy safety net for a hypothetical future FastMCP shape. + + Older code read ``message.params.name``. We keep that path as a + fallback so a version bump that reverts the shape doesn't silently + re-introduce the public-tools bypass bug. + """ + msg = SimpleNamespace( + name=None, + params=SimpleNamespace(name="get_health"), + ) + self.assertEqual( + MCPAuthMiddleware._extract_tool_name(self._ctx(msg)), + "get_health", + ) + + def test_prefers_direct_name_over_nested_params(self): + """If both are populated, the direct attribute wins. + + That's the current FastMCP reality and the shape we want to + exercise in production; the nested fallback must not shadow it. + """ + msg = SimpleNamespace( + name="search", + params=SimpleNamespace(name="wrong"), + ) + self.assertEqual( + MCPAuthMiddleware._extract_tool_name(self._ctx(msg)), + "search", + ) + + def test_returns_none_when_message_missing(self): + """A context with no ``.message`` at all (defensive).""" + ctx = SimpleNamespace() # no message attr + self.assertIsNone(MCPAuthMiddleware._extract_tool_name(ctx)) + + def test_returns_none_when_message_is_none(self): + self.assertIsNone( + MCPAuthMiddleware._extract_tool_name(self._ctx(None)) + ) + + def test_returns_none_when_no_name_anywhere(self): + """No direct ``name`` and no ``params`` — nothing to extract.""" + msg = SimpleNamespace(name=None) + self.assertIsNone( + MCPAuthMiddleware._extract_tool_name(self._ctx(msg)) + ) + + def test_returns_none_when_params_has_no_name(self): + msg = SimpleNamespace(name=None, params=SimpleNamespace()) + self.assertIsNone( + MCPAuthMiddleware._extract_tool_name(self._ctx(msg)) + ) + + def test_real_call_tool_request_params_shape(self): + """Contract test against the actual ``mcp.types.CallToolRequestParams``. + + Skips silently if the ``mcp`` package isn't importable in this + test environment — the rest of the suite uses hand-built + namespaces precisely so it stays runnable in minimal envs. + """ + try: + from mcp.types import CallToolRequestParams + except Exception: # pragma: no cover - env-dependent + self.skipTest("mcp package not installed in this environment") + + params = CallToolRequestParams(name="list_libraries", arguments={}) + self.assertEqual( + MCPAuthMiddleware._extract_tool_name(self._ctx(params)), + "list_libraries", + ) + + +# --------------------------------------------------------------------------- +# _extract_token +# --------------------------------------------------------------------------- + + +class _FakeHeaders: + """Starlette-style headers with case-insensitive ``get`` and ``keys``.""" + + def __init__(self, mapping: dict[str, str]): + # Starlette itself is case-insensitive; emulate that closely enough + # for the helper without depending on Starlette in test deps. + self._store = {k.lower(): v for k, v in mapping.items()} + + def get(self, key: str, default: str = "") -> str: + return self._store.get(key.lower(), default) + + def keys(self): + return list(self._store.keys()) + + +class _FakeRequest: + def __init__(self, headers: dict[str, str], url: str = "http://x/mcp"): + self.headers = _FakeHeaders(headers) + self.url = url + + +class ExtractTokenTest(SimpleTestCase): + """Exercise ``MCPAuthMiddleware._extract_token``. + + We patch ``mcp_server.auth.get_http_request`` directly rather than + reaching into FastMCP internals — the helper only cares about the + return value (or raised ``RuntimeError``) of that import-bound name. + """ + + def test_returns_token_when_bearer_header_present(self): + request = _FakeRequest({"Authorization": "Bearer abc.def.ghi"}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertEqual( + MCPAuthMiddleware._extract_token(), + "abc.def.ghi", + ) + + def test_accepts_lowercase_bearer_scheme(self): + """Some clients emit ``bearer`` (lowercase) — accept it.""" + request = _FakeRequest({"Authorization": "bearer xyz"}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertEqual(MCPAuthMiddleware._extract_token(), "xyz") + + def test_accepts_lowercase_header_name(self): + """HTTP/2 normalizes header names to lowercase; some proxies follow. + + Belt-and-braces lookup in ``_extract_token`` tries both forms. + """ + request = _FakeRequest({"authorization": "Bearer lower-case-hdr"}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertEqual( + MCPAuthMiddleware._extract_token(), + "lower-case-hdr", + ) + + def test_strips_surrounding_whitespace(self): + request = _FakeRequest({"Authorization": "Bearer padded-token "}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertEqual( + MCPAuthMiddleware._extract_token(), + "padded-token", + ) + + def test_returns_none_when_header_missing(self): + request = _FakeRequest({}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertIsNone(MCPAuthMiddleware._extract_token()) + + def test_returns_none_when_header_is_empty_bearer(self): + """``Authorization: Bearer `` with no value should read as missing.""" + request = _FakeRequest({"Authorization": "Bearer "}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertIsNone(MCPAuthMiddleware._extract_token()) + + def test_returns_none_when_scheme_is_not_bearer(self): + """Non-Bearer schemes (Basic, Digest, etc.) are ignored.""" + request = _FakeRequest({"Authorization": "Basic dXNlcjpwdw=="}) + with patch("mcp_server.auth.get_http_request", return_value=request): + self.assertIsNone(MCPAuthMiddleware._extract_token()) + + def test_runtime_error_from_get_http_request_degrades_to_none(self): + """Outside an HTTP dispatch, ``get_http_request`` raises. + + The helper must swallow that and return ``None`` so the caller + (typically ``on_call_tool``) can decide whether to raise + ``PermissionError`` based on ``MCP_REQUIRE_AUTH`` — rather than + letting a bare RuntimeError bubble out into the FastMCP layer + where it gets rewrapped into an opaque ``CallToolResult(isError=True)``. + """ + with patch( + "mcp_server.auth.get_http_request", + side_effect=RuntimeError("no active request"), + ): + self.assertIsNone(MCPAuthMiddleware._extract_token()) + + def test_other_exceptions_propagate(self): + """Only ``RuntimeError`` is caught — other errors must surface. + + If ``get_http_request`` ever starts raising something unexpected + we want a loud failure in logs, not a silent ``None`` that masks + the bug. + """ + with patch( + "mcp_server.auth.get_http_request", + side_effect=ValueError("something else entirely"), + ): + with self.assertRaises(ValueError): + MCPAuthMiddleware._extract_token()