Files
mnemosyne/mnemosyne/mcp_server/tests/test_auth.py
Robert Helewka 93639188d3
Some checks failed
CVE Scan & Docker Build / build-and-push (push) Has been cancelled
CVE Scan & Docker Build / security-scan (push) Has been cancelled
Build & Deploy Docs / build-and-deploy (push) Successful in 1m10s
feat: rework auth model with UserToken and Daedalus/Pallas integration
- Rename MCPToken to UserToken across models, views, and tests
- Update URL names from mcp-token-* to token-*
- Add Daedalus/Pallas integration design doc (v2)
- Switch docker-compose to build local mnemosyne:local image via shared
  build config instead of pulling from git.helu.ca
2026-05-23 19:50:29 -04:00

543 lines
18 KiB
Python

"""Tests for the MCP auth surface.
Covers all three credential types described in
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2:
1. Opaque :class:`~mcp_server.models.UserToken` — ``resolve_mcp_user``
+ ``MCPAuthMiddleware`` opaque branch.
2. Per-turn JWT (``iss=daedalus``, legacy) — ``resolve_mcp_jwt`` normal
path, ``_remember_jti`` replay cache, ``claims["libs"]``-derived
resolved_libraries.
3. Team JWT (``iss=mnemosyne typ=team``) — ``resolve_mcp_jwt`` team
branch (bypasses replay cache), ``_libraries_for_team`` enforcing
``Team.active`` + ``Team.active_jti``, ``resolved_libraries`` coming
from a mocked Neo4j lookup.
Neo4j is stubbed by patching ``neomodel.db.cypher_query`` — these tests
do not require a running Neo4j instance.
"""
from __future__ import annotations
import time
import uuid
from datetime import timedelta
from unittest import mock
import jwt as pyjwt
from django.contrib.auth import get_user_model
from django.test import TestCase
from django.utils import timezone
from mcp_server import auth as auth_module
from mcp_server.auth import (
MCPAuthError,
_libraries_for_team,
_remember_jti,
_resolve_jwt_actor,
_resolved_libraries_for_jwt,
looks_like_jwt,
resolve_mcp_jwt,
resolve_mcp_user,
)
from mcp_server.models import (
MCPSigningKey,
UserToken,
Team,
TeamWorkspaceAssignment,
)
# Active user model, looked up once and shared by every test class below.
User = get_user_model()
# ---------------------------------------------------------------------------
# Opaque UserToken
# ---------------------------------------------------------------------------
class ResolveMCPUserTest(TestCase):
    """The pre-existing coverage of the opaque bearer path.

    Exercises :func:`resolve_mcp_user` against a freshly minted
    :class:`UserToken`: the happy path, usage bookkeeping, and each
    rejection case (unknown token, inactive token, expired token,
    disabled owner), plus a check that the plaintext is never persisted.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="bob", email="bob@example.com", password="pw"
        )
        self.token, self.plaintext = UserToken.objects.create_token(
            user=self.user, name="t"
        )

    def test_resolves_valid_token(self):
        user, token = resolve_mcp_user(self.plaintext)
        self.assertEqual(user.pk, self.user.pk)
        self.assertEqual(token.pk, self.token.pk)

    def test_records_usage(self):
        self.assertIsNone(self.token.last_used_at)
        resolve_mcp_user(self.plaintext)
        self.token.refresh_from_db()
        self.assertIsNotNone(self.token.last_used_at)

    def test_invalid_token_raises(self):
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user("not-a-real-token")

    def test_inactive_token_raises(self):
        self.token.is_active = False
        self.token.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_expired_token_raises(self):
        self.token.expires_at = timezone.now() - timedelta(hours=1)
        self.token.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_disabled_user_raises(self):
        self.user.is_active = False
        self.user.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_plaintext_not_in_db(self):
        # Defense in depth: scan every column of the token table for the
        # plaintext, either stored verbatim or embedded in a larger value.
        from django.db import connection

        # Derive the table from the model instead of hard-coding it, so a
        # model rename or an explicit ``Meta.db_table`` cannot silently
        # point this scan at the wrong (or a non-existent) table.
        table = connection.ops.quote_name(UserToken._meta.db_table)
        plaintext = self.plaintext
        with connection.cursor() as cur:
            cur.execute(f"SELECT * FROM {table}")
            rows = cur.fetchall()
        # setUp minted one token, so an empty result means we scanned the
        # wrong thing and the loop below would pass vacuously.
        self.assertTrue(rows, "Token table unexpectedly empty")
        for row in rows:
            for value in row:
                if isinstance(value, (bytes, bytearray, memoryview)):
                    value = bytes(value).decode("utf-8", errors="ignore")
                leaked = isinstance(value, str) and plaintext in value
                self.assertFalse(
                    leaked,
                    f"Plaintext token leaked into the database: {value!r}",
                )
# ---------------------------------------------------------------------------
# JWT plumbing helpers
# ---------------------------------------------------------------------------
def _make_signing_key() -> MCPSigningKey:
    """Create and return a fresh, active HS256 :class:`MCPSigningKey`.

    Each call gets its own random 256-bit secret, stored as 64 hex chars
    (the form ``_encode`` decodes with ``bytes.fromhex``), and a unique
    ``kid``. Reusing one fixed secret for every key would let a verifier
    that resolved the wrong key still accept the signature, hiding
    kid-lookup bugs from the tests.
    """
    import secrets

    return MCPSigningKey.objects.create(
        kid=f"test-kid-{uuid.uuid4().hex[:8]}",
        secret_hex=secrets.token_hex(32),  # 32 random bytes -> 64 hex chars
        is_active=True,
    )
def _encode(payload: dict, key: MCPSigningKey) -> str:
    """Sign *payload* as an HS256 JWT with *key*, stamping ``kid`` in the header."""
    secret = bytes.fromhex(key.secret_hex)
    header = {"kid": key.kid}
    return pyjwt.encode(payload, secret, algorithm="HS256", headers=header)
def _per_turn_payload(**overrides) -> dict:
now = int(time.time())
payload = {
"iss": "daedalus",
"sub": "user:someone",
"iat": now,
"exp": now + 300,
"jti": str(uuid.uuid4()),
"libs": ["lib-a", "lib-b"],
"ws": "ws-42",
}
payload.update(overrides)
return payload
def _team_payload(team: Team, **overrides) -> dict:
now = int(time.time())
payload = {
"iss": "mnemosyne",
"aud": "mnemosyne",
"sub": f"team:{team.id}",
"typ": "team",
"iat": now,
"exp": now + 3600,
"jti": str(team.active_jti),
}
payload.update(overrides)
return payload
# ---------------------------------------------------------------------------
# looks_like_jwt
# ---------------------------------------------------------------------------
class LooksLikeJWTTest(TestCase):
    """``looks_like_jwt``: accept a signed three-segment JWT, reject the rest."""

    def test_three_segments_with_json_header(self):
        key = _make_signing_key()
        token = _encode(_per_turn_payload(), key)
        self.assertTrue(looks_like_jwt(token))

    def test_two_segments_rejected(self):
        self.assertFalse(looks_like_jwt("aaa.bbb"))

    def test_four_segments_rejected(self):
        self.assertFalse(looks_like_jwt("aaa.bbb.ccc.ddd"))

    def test_garbage_header_rejected(self):
        self.assertFalse(looks_like_jwt("!!!.bbb.ccc"))

    def test_opaque_token_rejected(self):
        # An opaque bearer must never be sniffed as a JWT. Check a token
        # actually minted by ``UserToken.objects.create_token`` so that a
        # change to the real plaintext format is caught here, and keep a
        # representative literal as a fixed regression case.
        user = User.objects.create_user(username="carol", password="pw")
        _token, plaintext = UserToken.objects.create_token(user=user, name="t")
        self.assertFalse(looks_like_jwt(plaintext))
        self.assertFalse(
            looks_like_jwt("CxGb3rThJ7_4jUGl0q2_fakey_fakey_fakey_fakey_fakey")
        )
# ---------------------------------------------------------------------------
# Per-turn JWT branch (iss=daedalus)
# ---------------------------------------------------------------------------
class ResolvePerTurnJWTTest(TestCase):
    """``resolve_mcp_jwt`` on per-turn tokens (``iss=daedalus``)."""

    def setUp(self):
        self.key = _make_signing_key()
        # The replay cache is a module-level dict: start every test clean and
        # clear it again afterwards so jtis recorded here cannot leak into
        # other test classes that share the same process.
        auth_module._JTI_CACHE.clear()
        self.addCleanup(auth_module._JTI_CACHE.clear)

    @staticmethod
    def _sign_with_foreign_secret(kid):
        """Sign a valid per-turn payload with a secret not stored in the DB."""
        foreign_secret_hex = "b" * 64
        return pyjwt.encode(
            _per_turn_payload(),
            bytes.fromhex(foreign_secret_hex),
            algorithm="HS256",
            headers={"kid": kid},
        )

    def test_happy_path_returns_normalized_claims(self):
        token = _encode(_per_turn_payload(), self.key)
        claims = resolve_mcp_jwt(token)
        self.assertEqual(claims["iss"], "daedalus")
        self.assertEqual(claims["libs"], ["lib-a", "lib-b"])
        self.assertEqual(claims["ws"], "ws-42")
        self.assertEqual(claims["kid"], self.key.kid)
        self.assertNotIn("typ", claims)  # per-turn tokens omit typ
        self.assertNotIn("team_id", claims)

    def test_libs_defaults_to_empty_list(self):
        payload = _per_turn_payload()
        payload.pop("libs", None)
        token = _encode(payload, self.key)
        claims = resolve_mcp_jwt(token)
        self.assertEqual(claims["libs"], [])

    def test_ws_defaults_to_none(self):
        payload = _per_turn_payload()
        payload.pop("ws", None)
        token = _encode(payload, self.key)
        claims = resolve_mcp_jwt(token)
        self.assertIsNone(claims["ws"])

    def test_invalid_libs_shape_rejected(self):
        payload = _per_turn_payload(libs="not-a-list")
        token = _encode(payload, self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_unknown_issuer_rejected(self):
        payload = _per_turn_payload(iss="attacker")
        token = _encode(payload, self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_expired_token_rejected(self):
        # One clock read keeps iat/exp consistent with each other.
        now = int(time.time())
        payload = _per_turn_payload(iat=now - 600, exp=now - 60)
        token = _encode(payload, self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_retired_signing_key_rejected(self):
        self.key.is_active = False
        self.key.save(update_fields=["is_active"])
        token = _encode(_per_turn_payload(), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_unknown_kid_rejected(self):
        token = self._sign_with_foreign_secret("no-such-kid")
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_bad_signature_rejected(self):
        # Re-sign with a different secret but reuse the stored kid.
        token = self._sign_with_foreign_secret(self.key.kid)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)
class PerTurnReplayCacheTest(TestCase):
    """``_remember_jti`` treats a repeat jti as a replay only after its exp.

    Repeats while the token is still valid are allowed (per-turn reuse);
    a repeat sighting once ``exp`` has passed is reported as a replay.
    """

    def setUp(self):
        # ``_JTI_CACHE`` is module-level state shared by every test in the
        # process: start clean and clear it afterwards so nothing leaks.
        auth_module._JTI_CACHE.clear()
        self.addCleanup(auth_module._JTI_CACHE.clear)

    def test_first_sighting_not_replay(self):
        self.assertFalse(_remember_jti("jti-1", time.time() + 300))

    def test_same_jti_within_validity_not_replay(self):
        # Daedalus routinely reuses a per-turn jti across multiple tool
        # calls in a single turn; that must not trip the replay gate.
        exp = time.time() + 300
        _remember_jti("jti-2", exp)
        self.assertFalse(_remember_jti("jti-2", exp))
        self.assertFalse(_remember_jti("jti-2", exp))

    def test_same_jti_after_exp_is_replay(self):
        exp = time.time() - 60  # already past
        # First sighting just records; second is the belt-and-braces trip.
        _remember_jti("jti-3", exp)
        self.assertTrue(_remember_jti("jti-3", exp))
# ---------------------------------------------------------------------------
# Team JWT branch (iss=mnemosyne, typ=team)
# ---------------------------------------------------------------------------
class ResolveTeamJWTTest(TestCase):
    """``resolve_mcp_jwt`` on team tokens (``iss=mnemosyne``, ``typ=team``)."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.key = _make_signing_key()
        team_fields = {
            "id": uuid.uuid4(),
            "name": "Pallas-Harper",
            "owner": self.owner,
            "active": True,
            "active_jti": uuid.uuid4(),
        }
        self.team = Team.objects.create(**team_fields)

    def _assert_sub_rejected(self, sub):
        """Sign a team token whose ``sub`` is *sub* and expect a refusal."""
        bad_token = _encode(_team_payload(self.team, sub=sub), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(bad_token)

    def test_team_jwt_happy_path_populates_team_id(self):
        claims = resolve_mcp_jwt(_encode(_team_payload(self.team), self.key))
        self.assertEqual(claims["typ"], "team")
        self.assertEqual(claims["team_id"], self.team.id)
        # A team token is never reshaped into the per-turn claim layout.
        for per_turn_only in ("libs", "ws"):
            self.assertNotIn(per_turn_only, claims)

    def test_malformed_sub_rejected(self):
        self._assert_sub_rejected("not-a-team-sub")

    def test_non_uuid_sub_rejected(self):
        self._assert_sub_rejected("team:not-a-uuid")

    def test_team_jwt_bypasses_replay_cache(self):
        """``typ=team`` tokens are intentionally reused; no jti pollution."""
        auth_module._JTI_CACHE.clear()
        team_token = _encode(_team_payload(self.team), self.key)
        for _attempt in range(2):
            # Reuse is expected, so every validation has to succeed.
            resolve_mcp_jwt(team_token)
        # The replay cache must be left exactly as it started: empty.
        self.assertEqual(auth_module._JTI_CACHE, {})
# ---------------------------------------------------------------------------
# _libraries_for_team — Postgres + mocked Neo4j
# ---------------------------------------------------------------------------
class LibrariesForTeamTest(TestCase):
    """``_libraries_for_team``: Postgres team checks plus a stubbed Neo4j lookup."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )

    def test_unknown_team_rejected(self):
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(uuid.uuid4(), str(uuid.uuid4()))

    def test_inactive_team_rejected(self):
        self.team.active = False
        self.team.save(update_fields=["active"])
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(self.team.id, str(self.team.active_jti))

    def test_stale_jti_rejected(self):
        # Incoming jti no longer matches — rotation happened since mint.
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(self.team.id, str(uuid.uuid4()))

    def test_no_workspace_assignments_returns_empty_list(self):
        # Fail-closed: a team with no assignments sees no libraries.
        # Neo4j is stubbed so this test never touches a real database (the
        # module promises no running Neo4j is needed); the stub raises
        # because no lookup should be required when there is nothing to
        # look up.
        with mock.patch(
            "neomodel.db.cypher_query",
            side_effect=AssertionError(
                "Neo4j must not be queried for a team with no assignments"
            ),
        ):
            libs = _libraries_for_team(self.team.id, str(self.team.active_jti))
        self.assertEqual(libs, [])

    def test_workspace_assignments_translated_to_library_uids(self):
        TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-a")
        TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-b")
        fake_rows = [["lib-1"], ["lib-2"], ["lib-3"]]
        with mock.patch(
            "neomodel.db.cypher_query",
            return_value=(fake_rows, ["l.uid"]),
        ) as cypher:
            libs = _libraries_for_team(
                self.team.id, str(self.team.active_jti)
            )
        self.assertEqual(sorted(libs), ["lib-1", "lib-2", "lib-3"])
        # Assert we pass the workspace ids as the ``ws`` parameter, which
        # the Cypher query projects through to the Neo4j driver.
        args, _kwargs = cypher.call_args
        self.assertIn("WHERE l.workspace_id IN $ws", args[0])
        self.assertEqual(sorted(args[1]["ws"]), ["ws-a", "ws-b"])
# ---------------------------------------------------------------------------
# _resolved_libraries_for_jwt dispatcher
# ---------------------------------------------------------------------------
class ResolvedLibrariesForJWTDispatcherTest(TestCase):
    """``_resolved_libraries_for_jwt`` picks its source from the claim type."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def test_per_turn_claims_use_libs(self):
        # Claims without ``typ=team`` resolve to their own ``libs`` list.
        resolved = _resolved_libraries_for_jwt({"libs": ["one", "two"]})
        self.assertEqual(resolved, ["one", "two"])

    def test_team_claims_use_team_lookup(self):
        team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )
        # Fixture rows are plain Postgres writes; they don't need the stub.
        TeamWorkspaceAssignment.objects.create(team=team, workspace_id="ws-x")
        team_claims = dict(
            typ="team",
            team_id=team.id,
            jti=str(team.active_jti),
        )
        fake_neo4j_result = ([["lib-team"]], ["l.uid"])
        with mock.patch("neomodel.db.cypher_query", return_value=fake_neo4j_result):
            resolved = _resolved_libraries_for_jwt(team_claims)
        self.assertEqual(resolved, ["lib-team"])
# ---------------------------------------------------------------------------
# _resolve_jwt_actor — team JWTs resolve to the team's owner
# ---------------------------------------------------------------------------
class ResolveJWTActorTest(TestCase):
    """``_resolve_jwt_actor`` maps a team JWT onto the team's owner account."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )

    def _team_claims(self, team_id=None):
        """Minimal team claim set; targets this test's team by default."""
        if team_id is None:
            team_id = self.team.id
        return {"typ": "team", "team_id": team_id}

    def test_team_jwt_resolves_to_team_owner(self):
        actor = _resolve_jwt_actor(self._team_claims())
        self.assertEqual(actor, self.owner)

    def test_inactive_team_rejected(self):
        self.team.deactivate()
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(self._team_claims())

    def test_disabled_owner_rejected(self):
        self.owner.is_active = False
        self.owner.save(update_fields=["is_active"])
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(self._team_claims())

    def test_per_turn_jwt_rejected(self):
        # Per-turn JWTs carry no Mnemosyne user to attribute the call to and
        # there is no service-user fallback any more, so they are refused.
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor({"libs": ["x"]})

    def test_unknown_team_rejected(self):
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(self._team_claims(uuid.uuid4()))
# ---------------------------------------------------------------------------
# Module-level import guards (regression tests)
# ---------------------------------------------------------------------------
class AuthModuleImportsTest(TestCase):
    """Pin the imports the runtime depends on.

    These are regressions waiting to happen: a quick `grep` doesn't
    catch a missing import when the consuming code is only reached via
    a runtime path (FastMCP middleware, async tool dispatch, …) that
    the test suite doesn't exercise end-to-end.

    Add a check here whenever production fails with a
    ``NameError: name 'X' is not defined`` that the test suite missed.
    """

    def test_settings_is_importable(self):
        """``MCPAuthMiddleware.on_call_tool`` reads
        ``settings.MCP_REQUIRE_AUTH`` on every tool call (including
        unauthenticated ``get_health`` polls from Pallas). Removing the
        ``from django.conf import settings`` import — as happened during
        the v2 token-consolidation cleanup — surfaces as
        ``NameError: name 'settings' is not defined`` for *every* MCP
        client. Keep this import alive.
        """
        from django.conf import settings as dj_settings

        # ``auth_module`` is already imported at module level; re-importing
        # it under the same name here would only shadow that binding.
        self.assertTrue(hasattr(auth_module, "settings"))
        self.assertIs(auth_module.settings, dj_settings)