diff --git a/mnemosyne/mcp_server/auth.py b/mnemosyne/mcp_server/auth.py index 4d0641e..e532b68 100644 --- a/mnemosyne/mcp_server/auth.py +++ b/mnemosyne/mcp_server/auth.py @@ -222,7 +222,14 @@ def resolve_mcp_jwt(token_string: str) -> dict: secret, algorithms=["HS256"], leeway=_JWT_LEEWAY_SECONDS, - options={"require": ["exp", "iat", "iss", "sub", "jti"]}, + options={ + "require": ["exp", "iat", "iss", "sub", "jti"], + # Team JWTs carry ``aud=mnemosyne`` for informational + # purposes; per-turn JWTs omit ``aud`` entirely. We + # don't enforce either shape because ``iss`` + ``typ`` + # already partition the two token populations. + "verify_aud": False, + }, # ``issuer=`` accepts ``str | Iterable[str]`` and raises # ``InvalidIssuerError`` if the claim is outside the set. issuer=list(_JWT_ISS_VALUES), diff --git a/mnemosyne/mcp_server/tests/test_auth.py b/mnemosyne/mcp_server/tests/test_auth.py index 4e9e6ca..f616ae5 100644 --- a/mnemosyne/mcp_server/tests/test_auth.py +++ b/mnemosyne/mcp_server/tests/test_auth.py @@ -1,18 +1,62 @@ -"""Tests for resolve_mcp_user.""" +"""Tests for the MCP auth surface. +Covers all three credential types described in +``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2: + +1. Opaque :class:`~mcp_server.models.MCPToken` — ``resolve_mcp_user`` + + ``MCPAuthMiddleware`` opaque branch. +2. Per-turn JWT (``iss=daedalus``, legacy) — ``resolve_mcp_jwt`` normal + path, ``_remember_jti`` replay cache, ``claims["libs"]``-derived + resolved_libraries. +3. Team JWT (``iss=mnemosyne typ=team``) — ``resolve_mcp_jwt`` team + branch (bypasses replay cache), ``_libraries_for_team`` enforcing + ``Team.active`` + ``Team.active_jti``, ``resolved_libraries`` coming + from a mocked Neo4j lookup. + +Neo4j is stubbed by patching ``neomodel.db.cypher_query`` — these tests +do not require a running Neo4j instance. +""" + +from __future__ import annotations + +import time +import uuid from datetime import timedelta +from unittest import mock +import jwt as pyjwt from django.contrib.auth import get_user_model from django.test import TestCase from django.utils import timezone -from mcp_server.auth import MCPAuthError, resolve_mcp_user -from mcp_server.models import MCPToken +from mcp_server import auth as auth_module +from mcp_server.auth import ( + MCPAuthError, + _libraries_for_team, + _remember_jti, + _resolved_libraries_for_jwt, + looks_like_jwt, + resolve_mcp_jwt, + resolve_mcp_user, +) +from mcp_server.models import ( + MCPSigningKey, + MCPToken, + Team, + TeamWorkspaceAssignment, +) User = get_user_model() +# --------------------------------------------------------------------------- +# Opaque MCPToken +# --------------------------------------------------------------------------- + + class ResolveMCPUserTest(TestCase): + """The pre-existing coverage of the opaque bearer path.""" + def setUp(self): self.user = User.objects.create_user( username="bob", email="bob@example.com", password="pw" @@ -68,3 +112,333 @@ class ResolveMCPUserTest(TestCase): value, plaintext, f"Plaintext token leaked into the database: {value!r}", ) + + +# --------------------------------------------------------------------------- +# JWT plumbing helpers +# --------------------------------------------------------------------------- + + +def _make_signing_key() -> MCPSigningKey: + """Fresh active HS256 key with a hex secret the auth module accepts.""" + secret_hex = "a" * 64 # 256 bits / 64 hex chars + return MCPSigningKey.objects.create( + kid=f"test-kid-{uuid.uuid4().hex[:8]}", + secret_hex=secret_hex, + is_active=True, + ) + + +def _encode(payload: dict, key: MCPSigningKey) -> str: + return pyjwt.encode( + payload, + bytes.fromhex(key.secret_hex), + algorithm="HS256", + headers={"kid": key.kid}, + ) + + +def _per_turn_payload(**overrides) -> dict: + now = int(time.time()) + payload = { + "iss": "daedalus", + "sub": "user:someone", + "iat": now, + "exp": now + 300, + "jti": str(uuid.uuid4()), + "libs": ["lib-a", "lib-b"], + "ws": "ws-42", + } + payload.update(overrides) + return payload + + +def _team_payload(team: Team, **overrides) -> dict: + now = int(time.time()) + payload = { + "iss": "mnemosyne", + "aud": "mnemosyne", + "sub": f"team:{team.id}", + "typ": "team", + "iat": now, + "exp": now + 3600, + "jti": str(team.active_jti), + } + payload.update(overrides) + return payload + + +# --------------------------------------------------------------------------- +# looks_like_jwt +# --------------------------------------------------------------------------- + + +class LooksLikeJWTTest(TestCase): + def test_three_segments_with_json_header(self): + key = _make_signing_key() + token = _encode(_per_turn_payload(), key) + self.assertTrue(looks_like_jwt(token)) + + def test_two_segments_rejected(self): + self.assertFalse(looks_like_jwt("aaa.bbb")) + + def test_four_segments_rejected(self): + self.assertFalse(looks_like_jwt("aaa.bbb.ccc.ddd")) + + def test_garbage_header_rejected(self): + self.assertFalse(looks_like_jwt("!!!.bbb.ccc")) + + def test_opaque_token_rejected(self): + # Real ``MCPToken.create_token`` plaintext is 48-byte base64, often + # contains dashes but never two dots. + self.assertFalse( + looks_like_jwt("CxGb3rThJ7_4jUGl0q2_fakey_fakey_fakey_fakey_fakey") + ) + + +# --------------------------------------------------------------------------- +# Per-turn JWT branch (iss=daedalus) +# --------------------------------------------------------------------------- + + +class ResolvePerTurnJWTTest(TestCase): + def setUp(self): + self.key = _make_signing_key() + # Make sure the replay cache is clean across tests; it is a + # module-level dict. + auth_module._JTI_CACHE.clear() + + def test_happy_path_returns_normalized_claims(self): + token = _encode(_per_turn_payload(), self.key) + claims = resolve_mcp_jwt(token) + self.assertEqual(claims["iss"], "daedalus") + self.assertEqual(claims["libs"], ["lib-a", "lib-b"]) + self.assertEqual(claims["ws"], "ws-42") + self.assertEqual(claims["kid"], self.key.kid) + self.assertNotIn("typ", claims) # per-turn tokens omit typ + self.assertNotIn("team_id", claims) + + def test_libs_defaults_to_empty_list(self): + payload = _per_turn_payload() + payload.pop("libs", None) + token = _encode(payload, self.key) + claims = resolve_mcp_jwt(token) + self.assertEqual(claims["libs"], []) + + def test_ws_defaults_to_none(self): + payload = _per_turn_payload() + payload.pop("ws", None) + token = _encode(payload, self.key) + claims = resolve_mcp_jwt(token) + self.assertIsNone(claims["ws"]) + + def test_invalid_libs_shape_rejected(self): + payload = _per_turn_payload(libs="not-a-list") + token = _encode(payload, self.key) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_unknown_issuer_rejected(self): + payload = _per_turn_payload(iss="attacker") + token = _encode(payload, self.key) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_expired_token_rejected(self): + payload = _per_turn_payload( + iat=int(time.time()) - 600, + exp=int(time.time()) - 60, + ) + token = _encode(payload, self.key) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_retired_signing_key_rejected(self): + self.key.is_active = False + self.key.save(update_fields=["is_active"]) + token = _encode(_per_turn_payload(), self.key) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_unknown_kid_rejected(self): + foreign_secret_hex = "b" * 64 + token = pyjwt.encode( + _per_turn_payload(), + bytes.fromhex(foreign_secret_hex), + algorithm="HS256", + headers={"kid": "no-such-kid"}, + ) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_bad_signature_rejected(self): + # Re-sign with a different secret but reuse the stored kid. + foreign_secret_hex = "b" * 64 + token = pyjwt.encode( + _per_turn_payload(), + bytes.fromhex(foreign_secret_hex), + algorithm="HS256", + headers={"kid": self.key.kid}, + ) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + +class PerTurnReplayCacheTest(TestCase): + """``_remember_jti`` flags a jti as replay iff we see it *after* its exp.""" + + def setUp(self): + auth_module._JTI_CACHE.clear() + + def test_first_sighting_not_replay(self): + self.assertFalse(_remember_jti("jti-1", time.time() + 300)) + + def test_same_jti_within_validity_not_replay(self): + # Daedalus routinely reuses a per-turn jti across multiple tool + # calls in a single turn; that must not trip the replay gate. + exp = time.time() + 300 + _remember_jti("jti-2", exp) + self.assertFalse(_remember_jti("jti-2", exp)) + self.assertFalse(_remember_jti("jti-2", exp)) + + def test_same_jti_after_exp_is_replay(self): + exp = time.time() - 60 # already past + # First sighting just records; second is the belt-and-braces trip. + _remember_jti("jti-3", exp) + self.assertTrue(_remember_jti("jti-3", exp)) + + +# --------------------------------------------------------------------------- +# Team JWT branch (iss=mnemosyne, typ=team) +# --------------------------------------------------------------------------- + + +class ResolveTeamJWTTest(TestCase): + def setUp(self): + self.key = _make_signing_key() + self.team = Team.objects.create( + id=uuid.uuid4(), + name="Pallas-Harper", + active=True, + active_jti=uuid.uuid4(), + ) + + def test_team_jwt_happy_path_populates_team_id(self): + token = _encode(_team_payload(self.team), self.key) + claims = resolve_mcp_jwt(token) + self.assertEqual(claims["typ"], "team") + self.assertEqual(claims["team_id"], self.team.id) + # Team tokens must never be normalized into the per-turn shape. + self.assertNotIn("libs", claims) + self.assertNotIn("ws", claims) + + def test_malformed_sub_rejected(self): + token = _encode( + _team_payload(self.team, sub="not-a-team-sub"), self.key + ) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_non_uuid_sub_rejected(self): + token = _encode( + _team_payload(self.team, sub="team:not-a-uuid"), self.key + ) + with self.assertRaises(MCPAuthError): + resolve_mcp_jwt(token) + + def test_team_jwt_bypasses_replay_cache(self): + """``typ=team`` tokens are intentionally reused; no jti pollution.""" + auth_module._JTI_CACHE.clear() + token = _encode(_team_payload(self.team), self.key) + resolve_mcp_jwt(token) + resolve_mcp_jwt(token) # a second validation must succeed + # And the replay cache must not have been touched. + self.assertEqual(auth_module._JTI_CACHE, {}) + + +# --------------------------------------------------------------------------- +# _libraries_for_team — Postgres + mocked Neo4j +# --------------------------------------------------------------------------- + + +class LibrariesForTeamTest(TestCase): + def setUp(self): + self.team = Team.objects.create( + id=uuid.uuid4(), + name="T", + active=True, + active_jti=uuid.uuid4(), + ) + + def test_unknown_team_rejected(self): + with self.assertRaises(MCPAuthError): + _libraries_for_team(uuid.uuid4(), str(uuid.uuid4())) + + def test_inactive_team_rejected(self): + self.team.active = False + self.team.save(update_fields=["active"]) + with self.assertRaises(MCPAuthError): + _libraries_for_team(self.team.id, str(self.team.active_jti)) + + def test_stale_jti_rejected(self): + # Incoming jti no longer matches — rotation happened since mint. + with self.assertRaises(MCPAuthError): + _libraries_for_team(self.team.id, str(uuid.uuid4())) + + def test_no_workspace_assignments_returns_empty_list(self): + # Fail-closed: a team with no assignments sees no libraries. + libs = _libraries_for_team(self.team.id, str(self.team.active_jti)) + self.assertEqual(libs, []) + + def test_workspace_assignments_translated_to_library_uids(self): + TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-a") + TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-b") + fake_rows = [["lib-1"], ["lib-2"], ["lib-3"]] + with mock.patch( + "neomodel.db.cypher_query", + return_value=(fake_rows, ["l.uid"]), + ) as cypher: + libs = _libraries_for_team( + self.team.id, str(self.team.active_jti) + ) + self.assertEqual(sorted(libs), ["lib-1", "lib-2", "lib-3"]) + # Assert we pass the workspace ids as the ``ws`` parameter, which + # the Cypher query projects through to the Neo4j driver. + (args, kwargs) = cypher.call_args + self.assertIn("WHERE l.workspace_id IN $ws", args[0]) + self.assertEqual(sorted(args[1]["ws"]), ["ws-a", "ws-b"]) + + +# --------------------------------------------------------------------------- +# _resolved_libraries_for_jwt dispatcher +# --------------------------------------------------------------------------- + + +class ResolvedLibrariesForJWTDispatcherTest(TestCase): + def test_per_turn_claims_use_libs(self): + claims = {"libs": ["one", "two"]} + self.assertEqual( + _resolved_libraries_for_jwt(claims), ["one", "two"] + ) + + def test_team_claims_use_team_lookup(self): + team = Team.objects.create( + id=uuid.uuid4(), + name="T", + active=True, + active_jti=uuid.uuid4(), + ) + claims = { + "typ": "team", + "team_id": team.id, + "jti": str(team.active_jti), + } + with mock.patch( + "neomodel.db.cypher_query", + return_value=([["lib-team"]], ["l.uid"]), + ): + TeamWorkspaceAssignment.objects.create( + team=team, workspace_id="ws-x" + ) + out = _resolved_libraries_for_jwt(claims) + self.assertEqual(out, ["lib-team"]) diff --git a/mnemosyne/mcp_server/tests/test_backfill_library_memberships.py b/mnemosyne/mcp_server/tests/test_backfill_library_memberships.py new file mode 100644 index 0000000..c7be10f --- /dev/null +++ b/mnemosyne/mcp_server/tests/test_backfill_library_memberships.py @@ -0,0 +1,161 @@ +"""Tests for the ``backfill_library_memberships`` management command. + +The command queries ``library.models.Library.nodes`` (Neo4j), which we +don't want to exercise in unit tests. We substitute a lightweight fake +library iterable onto ``Library.nodes.filter`` via ``unittest.mock``. +""" + +from __future__ import annotations + +from io import StringIO +from unittest import mock + +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError +from django.test import TestCase + +from mcp_server.models import LibraryMembership + +User = get_user_model() + + +class _FakeLibrary: + def __init__(self, uid: str, workspace_id=None): + self.uid = uid + self.workspace_id = workspace_id + + +class _FakeNodes: + """Stand-in for the neomodel ``Library.nodes`` node set.""" + + def __init__(self, all_libs): + self._all = list(all_libs) + + def filter(self, **kwargs): + out = list(self._all) + if "workspace_id__isnull" in kwargs: + want = kwargs["workspace_id__isnull"] + out = [l for l in out if (l.workspace_id is None) == want] + return out + + +def _run(*cli_args): + """Invoke the command with stdout/stderr buffered. Returns stdout.""" + buf = StringIO() + call_command( + "backfill_library_memberships", *cli_args, stdout=buf, stderr=StringIO() + ) + return buf.getvalue() + + +class BackfillCommandTest(TestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + username="admin", password="pw", email="a@b.c" + ) + self.patcher = mock.patch( + "library.models.Library.nodes", + new=_FakeNodes( + [ + _FakeLibrary("lib-a"), # global + _FakeLibrary("lib-b"), # global + _FakeLibrary("lib-ws", "ws-1"), # workspace-scoped + ] + ), + ) + self.patcher.start() + self.addCleanup(self.patcher.stop) + + def test_assigns_owner_to_all_global_libraries(self): + _run() + owners = LibraryMembership.objects.filter(user=self.superuser) + self.assertEqual( + sorted(owners.values_list("library_uid", flat=True)), + ["lib-a", "lib-b"], + ) + self.assertTrue( + all( + m.role == LibraryMembership.Role.OWNER + for m in owners + ) + ) + + def test_skips_workspace_scoped_libraries(self): + _run() + self.assertFalse( + LibraryMembership.objects.filter(library_uid="lib-ws").exists() + ) + + def test_idempotent(self): + _run() + count_after_first = LibraryMembership.objects.count() + _run() + self.assertEqual(LibraryMembership.objects.count(), count_after_first) + + def test_skips_libraries_with_existing_membership(self): + # Pre-seed someone else as manager on lib-a; the command should + # not overwrite or add a second row. + other = User.objects.create_user(username="other", password="pw") + LibraryMembership.objects.create( + user=other, + library_uid="lib-a", + role=LibraryMembership.Role.MANAGER, + ) + _run() + + # lib-a keeps its original manager-row; the superuser is NOT + # granted on it because any membership existing is an opt-out. + memberships_for_lib_a = LibraryMembership.objects.filter( + library_uid="lib-a" + ) + self.assertEqual(memberships_for_lib_a.count(), 1) + self.assertEqual(memberships_for_lib_a.first().user, other) + + # lib-b had none, so the superuser got it. + self.assertTrue( + LibraryMembership.objects.filter( + library_uid="lib-b", user=self.superuser + ).exists() + ) + + def test_dry_run_does_not_persist(self): + out = _run("--dry-run") + self.assertIn("would insert", out) + self.assertEqual(LibraryMembership.objects.count(), 0) + + def test_user_arg_overrides_default_superuser(self): + target = User.objects.create_user(username="librarian", password="pw") + _run("--user", "librarian") + self.assertTrue( + LibraryMembership.objects.filter(user=target).exists() + ) + self.assertFalse( + LibraryMembership.objects.filter(user=self.superuser).exists() + ) + + def test_unknown_user_raises_command_error(self): + with self.assertRaises(CommandError): + _run("--user", "nobody") + + def test_no_superuser_raises_command_error(self): + self.superuser.is_superuser = False + self.superuser.save(update_fields=["is_superuser"]) + with self.assertRaises(CommandError): + _run() + + +class BackfillCommandEmptyNeo4jTest(TestCase): + def setUp(self): + self.superuser = User.objects.create_superuser( + username="admin", password="pw", email="a@b.c" + ) + self.patcher = mock.patch( + "library.models.Library.nodes", new=_FakeNodes([]) + ) + self.patcher.start() + self.addCleanup(self.patcher.stop) + + def test_no_libraries_is_noop(self): + _run() + self.assertEqual(LibraryMembership.objects.count(), 0) diff --git a/mnemosyne/mcp_server/tests/test_models.py b/mnemosyne/mcp_server/tests/test_models.py new file mode 100644 index 0000000..dadf536 --- /dev/null +++ b/mnemosyne/mcp_server/tests/test_models.py @@ -0,0 +1,251 @@ +"""Tests for the Team / LibraryMembership / TeamWorkspaceAssignment models. + +``MCPToken``'s hash-at-rest semantics live in ``test_token.py``; this +module exercises the new Phase 2 tables introduced by +``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §4 plus the +``allowed_libraries`` JSONField attached to the existing +:class:`~mcp_server.models.MCPToken`. +""" + +from __future__ import annotations + +import uuid + +from django.contrib.auth import get_user_model +from django.db import IntegrityError, transaction +from django.test import TestCase + +from mcp_server.models import ( + LibraryMembership, + MCPSigningKey, + MCPToken, + Team, + TeamWorkspaceAssignment, +) + +User = get_user_model() + + +# --------------------------------------------------------------------------- +# LibraryMembership +# --------------------------------------------------------------------------- + + +class LibraryMembershipTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username="u", password="pw") + + def test_role_choices_exposed(self): + self.assertEqual(LibraryMembership.Role.OWNER, "owner") + self.assertEqual(LibraryMembership.Role.MANAGER, "manager") + self.assertEqual(LibraryMembership.Role.READER, "reader") + + def test_unique_per_user_and_library(self): + LibraryMembership.objects.create( + user=self.user, + library_uid="lib-1", + role=LibraryMembership.Role.OWNER, + ) + # Second membership on the same (user, library_uid) must fail + # even if the role differs — callers consolidate to the + # higher role rather than stacking rows. + with transaction.atomic(): + with self.assertRaises(IntegrityError): + LibraryMembership.objects.create( + user=self.user, + library_uid="lib-1", + role=LibraryMembership.Role.READER, + ) + + def test_same_library_different_users_allowed(self): + other = User.objects.create_user(username="u2", password="pw") + LibraryMembership.objects.create( + user=self.user, + library_uid="lib-1", + role=LibraryMembership.Role.OWNER, + ) + LibraryMembership.objects.create( + user=other, + library_uid="lib-1", + role=LibraryMembership.Role.READER, + ) + self.assertEqual(LibraryMembership.objects.count(), 2) + + def test_same_user_different_libraries_allowed(self): + LibraryMembership.objects.create( + user=self.user, + library_uid="lib-1", + role=LibraryMembership.Role.OWNER, + ) + LibraryMembership.objects.create( + user=self.user, + library_uid="lib-2", + role=LibraryMembership.Role.MANAGER, + ) + self.assertEqual( + set( + LibraryMembership.objects.filter(user=self.user) + .values_list("library_uid", flat=True) + ), + {"lib-1", "lib-2"}, + ) + + +# --------------------------------------------------------------------------- +# MCPToken.allowed_libraries +# --------------------------------------------------------------------------- + + +class MCPTokenAllowedLibrariesTest(TestCase): + def setUp(self): + self.user = User.objects.create_user(username="u", password="pw") + + def test_defaults_to_empty_list(self): + token, _ = MCPToken.objects.create_token(user=self.user, name="t") + self.assertEqual(token.allowed_libraries, []) + + def test_create_token_accepts_allowed_libraries(self): + token, _ = MCPToken.objects.create_token( + user=self.user, name="t", allowed_libraries=["lib-a", "lib-b"] + ) + self.assertEqual(token.allowed_libraries, ["lib-a", "lib-b"]) + + def test_allowed_libraries_round_trips(self): + token, _ = MCPToken.objects.create_token( + user=self.user, + name="t", + allowed_libraries=["lib-a", "lib-b", "lib-c"], + ) + token.refresh_from_db() + self.assertEqual( + token.allowed_libraries, ["lib-a", "lib-b", "lib-c"] + ) + + +# --------------------------------------------------------------------------- +# Team + rotate_jti / deactivate +# --------------------------------------------------------------------------- + + +class TeamTest(TestCase): + def test_create_with_explicit_uuid(self): + tid = uuid.uuid4() + team = Team.objects.create(id=tid, name="Harper") + self.assertEqual(team.id, tid) + self.assertTrue(team.active) + self.assertIsNone(team.active_jti) + + def test_rotate_jti_installs_fresh_uuid(self): + team = Team.objects.create(id=uuid.uuid4(), name="t") + first = team.rotate_jti() + self.assertIsInstance(first, uuid.UUID) + self.assertEqual(team.active_jti, first) + + second = team.rotate_jti() + self.assertNotEqual(first, second) + self.assertEqual(team.active_jti, second) + + def test_rotate_jti_persists(self): + team = Team.objects.create(id=uuid.uuid4(), name="t") + team.rotate_jti() + # Reload from DB and make sure the UUID was committed. + reloaded = Team.objects.get(pk=team.id) + self.assertEqual(reloaded.active_jti, team.active_jti) + + def test_deactivate_clears_active_jti(self): + team = Team.objects.create(id=uuid.uuid4(), name="t") + team.rotate_jti() + self.assertTrue(team.active) + team.deactivate() + self.assertFalse(team.active) + self.assertIsNone(team.active_jti) + # And it persisted. + reloaded = Team.objects.get(pk=team.id) + self.assertFalse(reloaded.active) + self.assertIsNone(reloaded.active_jti) + + +# --------------------------------------------------------------------------- +# TeamWorkspaceAssignment +# --------------------------------------------------------------------------- + + +class TeamWorkspaceAssignmentTest(TestCase): + def setUp(self): + self.team = Team.objects.create(id=uuid.uuid4(), name="t") + + def test_unique_team_workspace_pair(self): + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + with transaction.atomic(): + with self.assertRaises(IntegrityError): + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + + def test_same_workspace_different_teams_allowed(self): + other = Team.objects.create(id=uuid.uuid4(), name="t2") + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + TeamWorkspaceAssignment.objects.create( + team=other, workspace_id="ws-a" + ) + self.assertEqual(TeamWorkspaceAssignment.objects.count(), 2) + + def test_cascade_on_team_delete(self): + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + self.team.delete() + self.assertEqual(TeamWorkspaceAssignment.objects.count(), 0) + + def test_related_name_workspace_assignments(self): + # auth._libraries_for_team reaches through ``team.workspace_assignments``; + # make sure that attribute actually works. + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-b" + ) + ws_ids = sorted( + self.team.workspace_assignments.values_list( + "workspace_id", flat=True + ) + ) + self.assertEqual(ws_ids, ["ws-a", "ws-b"]) + + +# --------------------------------------------------------------------------- +# MCPSigningKey.objects.current() — used by mint_team_jwt +# --------------------------------------------------------------------------- + + +class MCPSigningKeyCurrentTest(TestCase): + def test_none_when_no_keys(self): + self.assertIsNone(MCPSigningKey.objects.current()) + + def test_returns_active_not_retired(self): + retired = MCPSigningKey.objects.create( + kid="old", secret_hex="a" * 64, is_active=False + ) + active = MCPSigningKey.objects.create( + kid="new", secret_hex="b" * 64, is_active=True + ) + self.assertEqual(MCPSigningKey.objects.current().pk, active.pk) + self.assertNotEqual(MCPSigningKey.objects.current().pk, retired.pk) + + def test_returns_newest_when_multiple_active(self): + older = MCPSigningKey.objects.create( + kid="older", secret_hex="a" * 64, is_active=True + ) + newer = MCPSigningKey.objects.create( + kid="newer", secret_hex="b" * 64, is_active=True + ) + self.assertEqual(MCPSigningKey.objects.current().pk, newer.pk) + # Sanity: older is still active, just older. + self.assertTrue( + MCPSigningKey.objects.filter(pk=older.pk, is_active=True).exists() + ) diff --git a/mnemosyne/mcp_server/tests/test_teams_api.py b/mnemosyne/mcp_server/tests/test_teams_api.py new file mode 100644 index 0000000..d3fbcd7 --- /dev/null +++ b/mnemosyne/mcp_server/tests/test_teams_api.py @@ -0,0 +1,359 @@ +"""Tests for the ``/mcp_server/api/teams/`` REST control plane. + +This is the Daedalus-facing surface described in §7 of +``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md``. We do NOT exercise HTTP +Basic auth here (that's part of DRF / the project's session auth +stack); instead we use :meth:`APIClient.force_authenticate` to focus +on the endpoints' own idempotence and state-transition rules. +""" + +from __future__ import annotations + +import uuid + +import jwt as pyjwt +from django.contrib.auth import get_user_model +from django.test import TestCase +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + +from mcp_server.models import ( + MCPSigningKey, + Team, + TeamWorkspaceAssignment, +) + +User = get_user_model() + + +def _seed_signing_key() -> MCPSigningKey: + return MCPSigningKey.objects.create( + kid=f"test-{uuid.uuid4().hex[:6]}", + secret_hex="a" * 64, + is_active=True, + ) + + +class _AuthenticatedAPITest(TestCase): + """Shared ``APIClient`` authenticated as the service user. + + The real deployment has Daedalus hit these endpoints as + ``daedalus-service`` over HTTP Basic, but the view decorator is a + plain ``IsAuthenticated`` — so for unit-test purposes we use + ``force_authenticate`` with any active user. + """ + + def setUp(self): + self.service_user = User.objects.create_user( + username="daedalus-service", password="pw" + ) + self.client = APIClient() + self.client.force_authenticate(user=self.service_user) + + +# --------------------------------------------------------------------------- +# POST /mcp_server/api/teams/ +# --------------------------------------------------------------------------- + + +class TeamCreateTest(_AuthenticatedAPITest): + def setUp(self): + super().setUp() + self.url = reverse("mcp-server-api:team-create") + _seed_signing_key() + + def test_requires_authentication(self): + self.client.force_authenticate(user=None) + resp = self.client.post( + self.url, {"id": str(uuid.uuid4()), "name": "t"}, format="json" + ) + self.assertIn(resp.status_code, (401, 403)) + + def test_rejects_missing_fields(self): + resp = self.client.post(self.url, {}, format="json") + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("id", resp.data) + self.assertIn("name", resp.data) + + def test_rejects_non_uuid_id(self): + resp = self.client.post( + self.url, {"id": "not-a-uuid", "name": "t"}, format="json" + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + + def test_creates_team_and_returns_jwt(self): + tid = uuid.uuid4() + resp = self.client.post( + self.url, {"id": str(tid), "name": "Harper"}, format="json" + ) + self.assertEqual(resp.status_code, status.HTTP_201_CREATED) + self.assertEqual(uuid.UUID(resp.data["id"]), tid) + self.assertEqual(resp.data["name"], "Harper") + self.assertTrue(resp.data["active"]) + self.assertIn("jwt", resp.data) + + # JWT decodes and carries the right sub + jti. + team = Team.objects.get(pk=tid) + header = pyjwt.get_unverified_header(resp.data["jwt"]) + key = MCPSigningKey.objects.by_kid(header["kid"]) + decoded = pyjwt.decode( + resp.data["jwt"], + bytes.fromhex(key.secret_hex), + algorithms=["HS256"], + options={"verify_aud": False}, + ) + self.assertEqual(decoded["sub"], f"team:{tid}") + self.assertEqual(decoded["jti"], str(team.active_jti)) + + def test_idempotent_on_same_id_returns_200_without_jwt(self): + tid = uuid.uuid4() + first = self.client.post( + self.url, {"id": str(tid), "name": "first"}, format="json" + ) + self.assertEqual(first.status_code, status.HTTP_201_CREATED) + first_jti = Team.objects.get(pk=tid).active_jti + + # Second POST with same id: idempotent hit. Must NOT rotate the + # jti (otherwise every retry storm re-issues a fresh credential). + second = self.client.post( + self.url, + {"id": str(tid), "name": "ignored-on-hit"}, + format="json", + ) + self.assertEqual(second.status_code, status.HTTP_200_OK) + self.assertNotIn("jwt", second.data) + self.assertEqual( + Team.objects.get(pk=tid).active_jti, first_jti + ) + + def test_mint_failure_returns_503(self): + # Retire all signing keys so mint_team_jwt cannot succeed. + MCPSigningKey.objects.update(is_active=False) + resp = self.client.post( + self.url, + {"id": str(uuid.uuid4()), "name": "x"}, + format="json", + ) + self.assertEqual( + resp.status_code, status.HTTP_503_SERVICE_UNAVAILABLE + ) + # The transaction.atomic() wrapper must also have rolled back the + # Team row so we don't leave a team with no usable JWT. + self.assertEqual(Team.objects.count(), 0) + + +# --------------------------------------------------------------------------- +# GET / DELETE /mcp_server/api/teams/{id}/ +# --------------------------------------------------------------------------- + + +class TeamDetailTest(_AuthenticatedAPITest): + def setUp(self): + super().setUp() + self.team = Team.objects.create( + id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4() + ) + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-a" + ) + TeamWorkspaceAssignment.objects.create( + team=self.team, workspace_id="ws-b" + ) + self.url = reverse( + "mcp-server-api:team-detail", kwargs={"team_id": self.team.id} + ) + + def test_get_returns_team_state(self): + resp = self.client.get(self.url) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(uuid.UUID(resp.data["id"]), self.team.id) + self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"]) + # Never leak the JWT via GET. + self.assertNotIn("jwt", resp.data) + + def test_get_unknown_team_returns_404(self): + url = reverse( + "mcp-server-api:team-detail", kwargs={"team_id": uuid.uuid4()} + ) + resp = self.client.get(url) + self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) + + def test_delete_soft_deletes(self): + resp = self.client.delete(self.url) + self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + + reloaded = Team.objects.get(pk=self.team.id) + self.assertFalse(reloaded.active) + self.assertIsNone(reloaded.active_jti) + # Workspace rows stay for audit — deactivation only flips flags. + self.assertEqual( + TeamWorkspaceAssignment.objects.filter(team=reloaded).count(), + 2, + ) + + def test_delete_idempotent(self): + self.client.delete(self.url) # first — 204 + resp = self.client.delete(self.url) + self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + + +# --------------------------------------------------------------------------- +# PUT /mcp_server/api/teams/{id}/workspaces/ +# --------------------------------------------------------------------------- + + +class TeamWorkspacesTest(_AuthenticatedAPITest): + def setUp(self): + super().setUp() + self.team = Team.objects.create( + id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4() + ) + self.url = reverse( + "mcp-server-api:team-workspaces", + kwargs={"team_id": self.team.id}, + ) + + def _ws_ids(self): + return sorted( + self.team.workspace_assignments.values_list( + "workspace_id", flat=True + ) + ) + + def test_unknown_team_returns_404(self): + url = reverse( + "mcp-server-api:team-workspaces", + kwargs={"team_id": uuid.uuid4()}, + ) + resp = self.client.put( + url, {"workspace_ids": ["x"]}, format="json" + ) + self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) + + def test_replace_adds_all(self): + resp = self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-b"]}, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"]) + self.assertEqual(self._ws_ids(), ["ws-a", "ws-b"]) + + def test_replace_idempotent_second_call_is_noop(self): + self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-b"]}, + format="json", + ) + existing = list(self.team.workspace_assignments.all()) + pks_before = sorted(a.pk for a in existing) + + # Second identical PUT should not re-create rows. + resp = self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-b"]}, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + pks_after = sorted( + a.pk for a in self.team.workspace_assignments.all() + ) + self.assertEqual(pks_before, pks_after) + + def test_replace_removes_dropped(self): + # Start with a, b, c + self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-b", "ws-c"]}, + format="json", + ) + # Drop b, add d. + resp = self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-c", "ws-d"]}, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(self._ws_ids(), ["ws-a", "ws-c", "ws-d"]) + + def test_replace_with_empty_set_fail_closed(self): + self.client.put( + self.url, + {"workspace_ids": ["ws-a"]}, + format="json", + ) + resp = self.client.put( + self.url, {"workspace_ids": []}, format="json" + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(self._ws_ids(), []) + + def test_duplicates_deduped(self): + resp = self.client.put( + self.url, + {"workspace_ids": ["ws-a", "ws-a", "ws-b"]}, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"]) + self.assertEqual(self._ws_ids(), ["ws-a", "ws-b"]) + + def test_empty_string_rejected(self): + resp = self.client.put( + self.url, + {"workspace_ids": ["ws-a", ""]}, + format="json", + ) + self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + + +# --------------------------------------------------------------------------- +# POST /mcp_server/api/teams/{id}/rotate/ +# --------------------------------------------------------------------------- + + +class TeamRotateTest(_AuthenticatedAPITest): + def setUp(self): + super().setUp() + _seed_signing_key() + self.team = Team.objects.create( + id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4() + ) + self.url = reverse( + "mcp-server-api:team-rotate", + kwargs={"team_id": self.team.id}, + ) + + def test_unknown_team_returns_404(self): + url = reverse( + "mcp-server-api:team-rotate", + kwargs={"team_id": uuid.uuid4()}, + ) + resp = self.client.post(url) + self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND) + + def test_rotate_returns_new_jwt_and_changes_active_jti(self): + before = self.team.active_jti + resp = self.client.post(self.url) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertIn("jwt", resp.data) + self.team.refresh_from_db() + self.assertNotEqual(self.team.active_jti, before) + + def test_rotate_inactive_team_409(self): + self.team.deactivate() + resp = self.client.post(self.url) + self.assertEqual(resp.status_code, status.HTTP_409_CONFLICT) + # And we did NOT revive the team by side effect. + self.team.refresh_from_db() + self.assertFalse(self.team.active) + self.assertIsNone(self.team.active_jti) + + def test_rotate_without_signing_key_returns_503(self): + MCPSigningKey.objects.update(is_active=False) + resp = self.client.post(self.url) + self.assertEqual( + resp.status_code, status.HTTP_503_SERVICE_UNAVAILABLE + ) diff --git a/mnemosyne/mcp_server/tests/test_teams_jwt.py b/mnemosyne/mcp_server/tests/test_teams_jwt.py new file mode 100644 index 0000000..a649a71 --- /dev/null +++ b/mnemosyne/mcp_server/tests/test_teams_jwt.py @@ -0,0 +1,111 @@ +"""Tests for :mod:`mcp_server.teams` — team-JWT mint. + +Covers the mint path in isolation. The *validation* half of the +round-trip (``resolve_mcp_jwt`` with ``typ=team``) lives in +``test_auth.py``; this module asserts what ``mint_team_jwt`` encodes +into the token header / payload so the two halves match up. +""" + +from __future__ import annotations + +import time +import uuid + +import jwt as pyjwt +from django.test import TestCase + +from mcp_server.models import MCPSigningKey, Team +from mcp_server.teams import TeamJWTError, mint_team_jwt + + +def _make_key(kid: str = "k", is_active: bool = True) -> MCPSigningKey: + return MCPSigningKey.objects.create( + kid=kid, secret_hex="a" * 64, is_active=is_active + ) + + +def _make_team(**overrides) -> Team: + data = dict( + id=uuid.uuid4(), + name="t", + active=True, + active_jti=uuid.uuid4(), + ) + data.update(overrides) + return Team.objects.create(**data) + + +class MintTeamJWTHappyPathTest(TestCase): + def setUp(self): + self.key = _make_key("k-1") + self.team = _make_team() + + def test_returns_signed_jwt_string(self): + token = mint_team_jwt(self.team) + self.assertIsInstance(token, str) + # Three segments, HS256 header. + self.assertEqual(token.count("."), 2) + header = pyjwt.get_unverified_header(token) + self.assertEqual(header["alg"], "HS256") + self.assertEqual(header["kid"], self.key.kid) + + def test_payload_matches_spec(self): + before = int(time.time()) + token = mint_team_jwt(self.team) + after = int(time.time()) + + decoded = pyjwt.decode( + token, + bytes.fromhex(self.key.secret_hex), + algorithms=["HS256"], + options={"verify_aud": False}, + ) + self.assertEqual(decoded["iss"], "mnemosyne") + self.assertEqual(decoded["aud"], "mnemosyne") + self.assertEqual(decoded["typ"], "team") + self.assertEqual(decoded["sub"], f"team:{self.team.id}") + self.assertEqual(decoded["jti"], str(self.team.active_jti)) + # iat is "now-ish" (inclusive bounds because the call is fast). + self.assertGreaterEqual(decoded["iat"], before) + self.assertLessEqual(decoded["iat"], after) + # exp is 10 years in the future. + ten_years = 10 * 365 * 24 * 60 * 60 + self.assertEqual(decoded["exp"], decoded["iat"] + ten_years) + + def test_picks_newest_active_signing_key(self): + newer = _make_key("k-2") + token = mint_team_jwt(self.team) + header = pyjwt.get_unverified_header(token) + self.assertEqual(header["kid"], newer.kid) + + def test_ignores_retired_keys(self): + # Retire the only active key; newer retired key should still + # be skipped — mint fails because there is no active key. + self.key.is_active = False + self.key.save(update_fields=["is_active"]) + with self.assertRaises(TeamJWTError): + mint_team_jwt(self.team) + + +class MintTeamJWTFailureModesTest(TestCase): + def test_no_signing_key_raises(self): + team = _make_team() + with self.assertRaises(TeamJWTError) as ctx: + mint_team_jwt(team) + self.assertIn("signing", str(ctx.exception).lower()) + + def test_missing_active_jti_raises(self): + _make_key() + team = _make_team(active_jti=None) + with self.assertRaises(TeamJWTError) as ctx: + mint_team_jwt(team) + # Should name the thing the caller forgot to do. + self.assertIn("rotate_jti", str(ctx.exception)) + + def test_invalid_hex_secret_raises(self): + MCPSigningKey.objects.create( + kid="broken", secret_hex="not-hex!!", is_active=True + ) + team = _make_team() + with self.assertRaises(TeamJWTError): + mint_team_jwt(team) diff --git a/test-postgres.sh b/test-postgres.sh new file mode 100755 index 0000000..fae454d --- /dev/null +++ b/test-postgres.sh @@ -0,0 +1,127 @@ +#!/bin/bash +# Run the Mnemosyne Django test suite against an ephemeral Postgres + Neo4j. +# +# Pattern borrowed from spelunker/test-postgres.sh. Both databases run as +# throwaway Docker containers on a private bridge network; each invocation +# gets a fresh pair. The Django test runner applies migrations to the +# empty Postgres and tears the whole network down on exit. +# +# Neo4j is included because a handful of library/tests/* modules touch +# neomodel at import time even though the assertions themselves stub the +# cypher layer; having a reachable bolt endpoint keeps startup probes +# from crashing. +# +# Usage: +# ./test-postgres.sh # run everything +# ./test-postgres.sh mcp_server # scope to one app +# ./test-postgres.sh mcp_server.tests.test_auth +# ./test-postgres.sh mcp_server -v 2 --failfast + +set -euo pipefail + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +NET="mnemosyne-test-$$" +PG_CONTAINER="mnemosyne-test-pg-$$" +NEO_CONTAINER="mnemosyne-test-neo4j-$$" +PG_USER="mnemosyne" +PG_PASS="mnemosyne" +PG_DB="mnemosyne_test" +NEO_PASS="mnemosynetestpw" # Neo4j rejects short or obvious secrets. + +IMAGE="git.helu.ca/r/mnemosyne:latest" + +# Colours (skipped when not a TTY). +if [ -t 1 ]; then + GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[1;33m'; NC='\033[0m' +else + GREEN=''; RED=''; YELLOW=''; NC='' +fi + +say() { printf "${GREEN}==> %s${NC}\n" "$*"; } +warn() { printf "${YELLOW}[!] %s${NC}\n" "$*"; } +die() { printf "${RED}[✗] %s${NC}\n" "$*" >&2; exit 1; } + +# --------------------------------------------------------------------------- +# Cleanup +# --------------------------------------------------------------------------- + +cleanup() { + say "Cleaning up test containers" + docker rm -f "$PG_CONTAINER" >/dev/null 2>&1 || true + docker rm -f "$NEO_CONTAINER" >/dev/null 2>&1 || true + docker network rm "$NET" >/dev/null 2>&1 || true +} +trap cleanup EXIT + +# --------------------------------------------------------------------------- +# Start the support services +# --------------------------------------------------------------------------- + +say "Creating Docker network $NET" +docker network create "$NET" >/dev/null + +say "Starting Postgres ($PG_CONTAINER)" +docker run -d --rm \ + --name "$PG_CONTAINER" \ + --network "$NET" \ + -e POSTGRES_USER="$PG_USER" \ + -e POSTGRES_PASSWORD="$PG_PASS" \ + -e POSTGRES_DB="$PG_DB" \ + postgres:16-alpine \ + >/dev/null + +say "Starting Neo4j ($NEO_CONTAINER)" +docker run -d --rm \ + --name "$NEO_CONTAINER" \ + --network "$NET" \ + -e NEO4J_AUTH="neo4j/$NEO_PASS" \ + -e NEO4J_PLUGINS='["apoc"]' \ + neo4j:5-community \ + >/dev/null + +# Wait for Postgres to accept connections. +say "Waiting for Postgres to become ready" +for i in $(seq 1 30); do + if docker exec "$PG_CONTAINER" pg_isready -U "$PG_USER" -d "$PG_DB" >/dev/null 2>&1; then + break + fi + sleep 1 +done +docker exec "$PG_CONTAINER" pg_isready -U "$PG_USER" -d "$PG_DB" >/dev/null \ + || die "Postgres never became ready" + +# Wait for Neo4j's HTTP endpoint — bolt comes up just after it. +say "Waiting for Neo4j to become ready" +for i in $(seq 1 60); do + if docker exec "$NEO_CONTAINER" wget -qO- http://localhost:7474/ >/dev/null 2>&1; then + break + fi + sleep 1 +done + +# --------------------------------------------------------------------------- +# Run the test command +# --------------------------------------------------------------------------- + +SOURCE_DIR="$(cd "$(dirname "$0")" && pwd)/mnemosyne" + +say "Running Django test suite" +docker run --rm \ + --network "$NET" \ + -v "$SOURCE_DIR":/app \ + -w /app \ + -e DJANGO_SETTINGS_MODULE=mnemosyne.settings \ + -e APP_DB_NAME="$PG_DB" \ + -e APP_DB_USER="$PG_USER" \ + -e APP_DB_PASSWORD="$PG_PASS" \ + -e DB_HOST="$PG_CONTAINER" \ + -e DB_PORT=5432 \ + -e NEOMODEL_NEO4J_BOLT_URL="bolt://neo4j:$NEO_PASS@$NEO_CONTAINER:7687" \ + -e SECRET_KEY=test-only-not-a-real-secret \ + -e DEBUG=0 \ + -e MCP_REQUIRE_AUTH=True \ + "$IMAGE" \ + python manage.py test "$@"