diff --git a/mnemosyne/mcp_server/auth.py b/mnemosyne/mcp_server/auth.py
index e532b68..63563d8 100644
--- a/mnemosyne/mcp_server/auth.py
+++ b/mnemosyne/mcp_server/auth.py
@@ -410,15 +410,22 @@ class MCPAuthMiddleware(Middleware):
 
         fastmcp_ctx = getattr(context, "fastmcp_context", None)
         if fastmcp_ctx is not None:
+            # ``Context.set_state`` is a synchronous method in FastMCP; it
+            # stores into ``self._state`` and returns ``None``. Awaiting its
+            # return value raises ``TypeError: object NoneType can't be used
+            # in 'await' expression`` which propagates through FastMCP's
+            # dispatch as an opaque string-valued ``CallToolResult`` —
+            # exactly the symptom documented in ``pallas._fastagent_patch``.
+            # Call synchronously.
             if user is not None:
-                await fastmcp_ctx.set_state(STATE_KEY_USER, user)
+                fastmcp_ctx.set_state(STATE_KEY_USER, user)
             if token is not None:
-                await fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
+                fastmcp_ctx.set_state(STATE_KEY_TOKEN, token)
             if claims is not None:
-                await fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
+                fastmcp_ctx.set_state(STATE_KEY_CLAIMS, claims)
             # Always publish resolved_libraries — None means "no auth
             # information" and the tools treat that as fail-closed.
-            await fastmcp_ctx.set_state(
+            fastmcp_ctx.set_state(
                 STATE_KEY_RESOLVED_LIBRARIES, resolved_libraries
             )
 
diff --git a/mnemosyne/mcp_server/context.py b/mnemosyne/mcp_server/context.py
index 1a0c5f7..3f888f5 100644
--- a/mnemosyne/mcp_server/context.py
+++ b/mnemosyne/mcp_server/context.py
@@ -32,23 +32,32 @@ from .auth import (
 )
 
 
+# ``Context.get_state`` is a synchronous method in FastMCP — it returns the
+# stored value (``Any``) or ``None`` if the key is absent. Awaiting the
+# returned value raises ``TypeError: object NoneType can't be used in 'await'
+# expression`` whenever the value is ``None`` (and is semantically wrong even
+# when it isn't). These helpers stay ``async def`` so call sites (and their
+# ``await`` usage) don't have to change, but they call ``get_state``
+# synchronously.
+
+
 async def get_mcp_user(ctx: Context | None):
     if ctx is None:
         return None
-    return await ctx.get_state(STATE_KEY_USER)
+    return ctx.get_state(STATE_KEY_USER)
 
 
 async def get_mcp_token(ctx: Context | None):
     if ctx is None:
         return None
-    return await ctx.get_state(STATE_KEY_TOKEN)
+    return ctx.get_state(STATE_KEY_TOKEN)
 
 
 async def get_mcp_claims(ctx: Context | None) -> dict | None:
     """Return the JWT claims dict for this request, or None for opaque-token callers."""
     if ctx is None:
         return None
-    return await ctx.get_state(STATE_KEY_CLAIMS)
+    return ctx.get_state(STATE_KEY_CLAIMS)
 
 
 async def get_mcp_resolved_libraries(ctx: Context | None) -> list[str] | None:
@@ -68,4 +77,4 @@ async def get_mcp_resolved_libraries(ctx: Context | None) -> list[str] | None:
     """
     if ctx is None:
         return None
-    return await ctx.get_state(STATE_KEY_RESOLVED_LIBRARIES)
+    return ctx.get_state(STATE_KEY_RESOLVED_LIBRARIES)
diff --git a/mnemosyne/mcp_server/tests/test_context_state.py b/mnemosyne/mcp_server/tests/test_context_state.py
new file mode 100644
index 0000000..29c9a75
--- /dev/null
+++ b/mnemosyne/mcp_server/tests/test_context_state.py
@@ -0,0 +1,225 @@
+"""Regression tests: FastMCP ``Context.set_state``/``get_state`` are sync.
+
+The ``mnemosyne__search`` tool call (and every other authenticated tool on
+Mnemosyne) was dead-on-arrival from day one because ``mcp_server.auth``
+and ``mcp_server.context`` both treated FastMCP's ``Context.set_state``
+and ``Context.get_state`` as coroutines. They are not — they are plain
+synchronous methods that return ``None`` / ``Any`` respectively.
+
+Awaiting ``set_state``'s return value raised
+``TypeError: object NoneType can't be used in 'await' expression``
+*inside* ``MCPAuthMiddleware.on_call_tool``, which FastMCP caught and
+repackaged as an opaque string-valued ``CallToolResult(isError=True)``.
+Pallas/fast-agent forwarded that string to the calling LLM as the tool
+result, which is why the symptom looked like a downstream auth failure
+when it was actually our own code awaiting a ``None``.
+
+These tests lock the invariant in so any future ``await ctx.set_state``
+/ ``await ctx.get_state`` regression fails loudly rather than silently
+breaking every tool call again.
+"""
+
+from __future__ import annotations
+
+import ast
+import asyncio
+import inspect
+from unittest import mock
+
+from django.test import SimpleTestCase
+from fastmcp.server.context import Context
+
+from mcp_server import auth as auth_module
+from mcp_server import context as context_module
+from mcp_server.auth import (
+    STATE_KEY_CLAIMS,
+    STATE_KEY_RESOLVED_LIBRARIES,
+    STATE_KEY_TOKEN,
+    STATE_KEY_USER,
+)
+
+# Every ``get_mcp_*`` helper, shared by the table-driven cases below.
+_STATE_HELPERS = (
+    context_module.get_mcp_user,
+    context_module.get_mcp_token,
+    context_module.get_mcp_claims,
+    context_module.get_mcp_resolved_libraries,
+)
+
+
+def _run(coro):
+    """Run *coro* to completion and return its result.
+
+    ``asyncio.run`` creates a fresh event loop per call, which keeps each
+    case hermetic (no shared event loop state between cases).
+    """
+    return asyncio.run(coro)
+
+
+def _awaited_method_calls(module, method_name):
+    """List line numbers where *module* awaits a ``<method_name>(...)`` call.
+
+    The module source is parsed with ``ast`` rather than grepped, so
+    comments and docstrings that merely mention the pattern are ignored
+    and the receiver expression may take any form (attribute chains,
+    calls, parenthesised expressions, line breaks).
+    """
+    tree = ast.parse(inspect.getsource(module))
+    return [
+        node.lineno
+        for node in ast.walk(tree)
+        if isinstance(node, ast.Await)
+        and isinstance(node.value, ast.Call)
+        and isinstance(node.value.func, ast.Attribute)
+        and node.value.func.attr == method_name
+    ]
+
+
+class FastMCPContextStateShapeTest(SimpleTestCase):
+    """Contract test against the real ``fastmcp.server.context.Context``.
+
+    This is the ground truth for everything else in this file — if a
+    future FastMCP release flips ``set_state`` / ``get_state`` to async,
+    this test will turn red and the production helpers should be
+    revisited rather than silently broken by the existing
+    ``return ctx.get_state(...)`` pattern.
+
+    NOTE(review): FastMCP 3.x reportedly makes the context state
+    accessors coroutines — confirm against the pinned FastMCP version
+    before upgrading.
+    """
+
+    def test_set_state_is_synchronous(self):
+        self.assertFalse(
+            inspect.iscoroutinefunction(Context.set_state),
+            "Context.set_state unexpectedly became a coroutine; "
+            "mcp_server.auth.MCPAuthMiddleware.on_call_tool must switch "
+            "back to `await fastmcp_ctx.set_state(...)` if that happens.",
+        )
+
+    def test_get_state_is_synchronous(self):
+        self.assertFalse(
+            inspect.iscoroutinefunction(Context.get_state),
+            "Context.get_state unexpectedly became a coroutine; "
+            "mcp_server.context.get_mcp_* helpers must switch back to "
+            "`await ctx.get_state(...)` if that happens.",
+        )
+
+
+class _FakeContext:
+    """Minimal stand-in for ``fastmcp.server.context.Context``.
+
+    We only need the ``get_state`` surface the helpers touch. Using a
+    hand-built fake (rather than instantiating a real ``Context``, which
+    requires a live FastMCP server + session) keeps the test hermetic.
+
+    The shape mirrors the real one deliberately: ``get_state`` is a
+    plain sync method that returns ``self._state.get(key)``. Any change
+    that accidentally reintroduces ``await ctx.get_state(...)`` in the
+    helpers will hit ``TypeError: object ... can't be used in 'await'
+    expression`` against this fake (``NoneType`` for an absent key),
+    just as it does in prod.
+    """
+
+    def __init__(self, state: dict | None = None):
+        self._state = dict(state or {})
+
+    def set_state(self, key, value):
+        self._state[key] = value
+
+    def get_state(self, key):
+        return self._state.get(key)
+
+
+class GetMCPStateHelpersTest(SimpleTestCase):
+    """Every ``get_mcp_*`` helper must call ``ctx.get_state`` synchronously.
+
+    The helpers themselves are ``async def`` for call-site ergonomics —
+    every tool already does ``await get_mcp_resolved_libraries(ctx)`` —
+    but they must not await ``get_state``'s return value.
+    """
+
+    def test_get_mcp_user_returns_stored_value(self):
+        ctx = _FakeContext({STATE_KEY_USER: mock.sentinel.user})
+        self.assertIs(
+            _run(context_module.get_mcp_user(ctx)),
+            mock.sentinel.user,
+        )
+
+    def test_get_mcp_token_returns_stored_value(self):
+        ctx = _FakeContext({STATE_KEY_TOKEN: mock.sentinel.token})
+        self.assertIs(
+            _run(context_module.get_mcp_token(ctx)),
+            mock.sentinel.token,
+        )
+
+    def test_get_mcp_claims_returns_stored_value(self):
+        claims = {"sub": "team:abc", "typ": "team"}
+        ctx = _FakeContext({STATE_KEY_CLAIMS: claims})
+        self.assertEqual(
+            _run(context_module.get_mcp_claims(ctx)),
+            claims,
+        )
+
+    def test_get_mcp_resolved_libraries_returns_stored_value(self):
+        libs = ["lib_a", "lib_b"]
+        ctx = _FakeContext({STATE_KEY_RESOLVED_LIBRARIES: libs})
+        self.assertEqual(
+            _run(context_module.get_mcp_resolved_libraries(ctx)),
+            libs,
+        )
+
+    def test_absent_key_returns_none_not_typeerror(self):
+        """The canonical regression: ``get_state`` returns ``None`` when
+        the key has never been set. If a future change re-adds ``await``
+        in front of the call, *this* specific case will raise
+        ``TypeError: object NoneType can't be used in 'await' expression``
+        — exactly the original bug.
+        """
+        ctx = _FakeContext()  # empty state
+        for helper in _STATE_HELPERS:
+            with self.subTest(helper=helper.__name__):
+                self.assertIsNone(_run(helper(ctx)))
+
+    def test_none_ctx_returns_none_without_touching_state(self):
+        """``ctx is None`` short-circuit must not call ``get_state``."""
+        for helper in _STATE_HELPERS:
+            with self.subTest(helper=helper.__name__):
+                self.assertIsNone(_run(helper(None)))
+
+
+class MiddlewareSetStateUsageTest(SimpleTestCase):
+    """``MCPAuthMiddleware.on_call_tool`` must call ``set_state`` synchronously.
+
+    Rather than stand up the whole middleware (which would pull in JWT
+    decoding, Django ORM, Neo4j, etc.), this test asserts the
+    *source-level* invariant: no ``await`` expression in
+    ``mcp_server.auth`` wraps a ``set_state(...)`` call (and likewise for
+    ``get_state(...)`` in ``mcp_server.context``). A syntax-tree check is
+    the smallest guard that would have caught the original bug — richer
+    behavioural coverage is in the integration tests that run against a
+    live Mnemosyne container.
+    """
+
+    def test_auth_module_never_awaits_set_state(self):
+        offending = _awaited_method_calls(auth_module, "set_state")
+        self.assertEqual(
+            offending,
+            [],
+            "mcp_server.auth must not await ``set_state`` — "
+            "it is a synchronous FastMCP Context method. Awaiting its "
+            "return value raises ``TypeError: object NoneType can't be "
+            "used in 'await' expression`` and breaks every authenticated "
+            "tool call silently.",
+        )
+
+    def test_context_module_never_awaits_get_state(self):
+        offending = _awaited_method_calls(context_module, "get_state")
+        self.assertEqual(
+            offending,
+            [],
+            "mcp_server.context must not await ``get_state`` — "
+            "it is a synchronous FastMCP Context method. The helpers "
+            "themselves stay ``async def`` for call-site ergonomics, "
+            "but the underlying call must be synchronous.",
+        )