fix(mcp): disable audience verification in resolve_mcp_jwt
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 50s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m16s

Team JWTs include `aud=mnemosyne` while per-turn JWTs omit `aud`
entirely. Since `iss` + `typ` already partition the two token
populations, explicitly skip audience verification to avoid rejecting
valid tokens.

Also expand test coverage for the MCP auth surface to exercise all
three credential types (opaque MCPToken, per-turn JWT, team JWT),
including replay cache behavior and Neo4j-backed library resolution
via mocked cypher queries.
This commit is contained in:
2026-05-10 12:32:58 -04:00
parent 16fb7ff4dc
commit 6a4fecf488
7 changed files with 1394 additions and 4 deletions

View File

@@ -222,7 +222,14 @@ def resolve_mcp_jwt(token_string: str) -> dict:
secret,
algorithms=["HS256"],
leeway=_JWT_LEEWAY_SECONDS,
options={"require": ["exp", "iat", "iss", "sub", "jti"]},
options={
"require": ["exp", "iat", "iss", "sub", "jti"],
# Team JWTs carry ``aud=mnemosyne`` for informational
# purposes; per-turn JWTs omit ``aud`` entirely. We
# don't enforce either shape because ``iss`` + ``typ``
# already partition the two token populations.
"verify_aud": False,
},
# ``issuer=`` accepts ``str | Iterable[str]`` and raises
# ``InvalidIssuerError`` if the claim is outside the set.
issuer=list(_JWT_ISS_VALUES),

View File

@@ -1,18 +1,62 @@
"""Tests for resolve_mcp_user."""
"""Tests for the MCP auth surface.
Covers all three credential types described in
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2:
1. Opaque :class:`~mcp_server.models.MCPToken` — ``resolve_mcp_user``
+ ``MCPAuthMiddleware`` opaque branch.
2. Per-turn JWT (``iss=daedalus``, legacy) — ``resolve_mcp_jwt`` normal
path, ``_remember_jti`` replay cache, ``claims["libs"]``-derived
resolved_libraries.
3. Team JWT (``iss=mnemosyne typ=team``) — ``resolve_mcp_jwt`` team
branch (bypasses replay cache), ``_libraries_for_team`` enforcing
``Team.active`` + ``Team.active_jti``, ``resolved_libraries`` coming
from a mocked Neo4j lookup.
Neo4j is stubbed by patching ``neomodel.db.cypher_query`` — these tests
do not require a running Neo4j instance.
"""
from __future__ import annotations
import time
import uuid
from datetime import timedelta
from unittest import mock
import jwt as pyjwt
from django.contrib.auth import get_user_model
from django.test import TestCase
from django.utils import timezone
from mcp_server.auth import MCPAuthError, resolve_mcp_user
from mcp_server.models import MCPToken
from mcp_server import auth as auth_module
from mcp_server.auth import (
MCPAuthError,
_libraries_for_team,
_remember_jti,
_resolved_libraries_for_jwt,
looks_like_jwt,
resolve_mcp_jwt,
resolve_mcp_user,
)
from mcp_server.models import (
MCPSigningKey,
MCPToken,
Team,
TeamWorkspaceAssignment,
)
User = get_user_model()
# ---------------------------------------------------------------------------
# Opaque MCPToken
# ---------------------------------------------------------------------------
class ResolveMCPUserTest(TestCase):
"""The pre-existing coverage of the opaque bearer path."""
def setUp(self):
self.user = User.objects.create_user(
username="bob", email="bob@example.com", password="pw"
@@ -68,3 +112,333 @@ class ResolveMCPUserTest(TestCase):
value, plaintext,
f"Plaintext token leaked into the database: {value!r}",
)
# ---------------------------------------------------------------------------
# JWT plumbing helpers
# ---------------------------------------------------------------------------
def _make_signing_key() -> MCPSigningKey:
"""Fresh active HS256 key with a hex secret the auth module accepts."""
secret_hex = "a" * 64 # 256 bits / 64 hex chars
return MCPSigningKey.objects.create(
kid=f"test-kid-{uuid.uuid4().hex[:8]}",
secret_hex=secret_hex,
is_active=True,
)
def _encode(payload: dict, key: MCPSigningKey) -> str:
return pyjwt.encode(
payload,
bytes.fromhex(key.secret_hex),
algorithm="HS256",
headers={"kid": key.kid},
)
def _per_turn_payload(**overrides) -> dict:
now = int(time.time())
payload = {
"iss": "daedalus",
"sub": "user:someone",
"iat": now,
"exp": now + 300,
"jti": str(uuid.uuid4()),
"libs": ["lib-a", "lib-b"],
"ws": "ws-42",
}
payload.update(overrides)
return payload
def _team_payload(team: Team, **overrides) -> dict:
now = int(time.time())
payload = {
"iss": "mnemosyne",
"aud": "mnemosyne",
"sub": f"team:{team.id}",
"typ": "team",
"iat": now,
"exp": now + 3600,
"jti": str(team.active_jti),
}
payload.update(overrides)
return payload
# ---------------------------------------------------------------------------
# looks_like_jwt
# ---------------------------------------------------------------------------
class LooksLikeJWTTest(TestCase):
def test_three_segments_with_json_header(self):
key = _make_signing_key()
token = _encode(_per_turn_payload(), key)
self.assertTrue(looks_like_jwt(token))
def test_two_segments_rejected(self):
self.assertFalse(looks_like_jwt("aaa.bbb"))
def test_four_segments_rejected(self):
self.assertFalse(looks_like_jwt("aaa.bbb.ccc.ddd"))
def test_garbage_header_rejected(self):
self.assertFalse(looks_like_jwt("!!!.bbb.ccc"))
def test_opaque_token_rejected(self):
# Real ``MCPToken.create_token`` plaintext is 48-byte base64, often
# contains dashes but never two dots.
self.assertFalse(
looks_like_jwt("CxGb3rThJ7_4jUGl0q2_fakey_fakey_fakey_fakey_fakey")
)
# ---------------------------------------------------------------------------
# Per-turn JWT branch (iss=daedalus)
# ---------------------------------------------------------------------------
class ResolvePerTurnJWTTest(TestCase):
def setUp(self):
self.key = _make_signing_key()
# Make sure the replay cache is clean across tests; it is a
# module-level dict.
auth_module._JTI_CACHE.clear()
def test_happy_path_returns_normalized_claims(self):
token = _encode(_per_turn_payload(), self.key)
claims = resolve_mcp_jwt(token)
self.assertEqual(claims["iss"], "daedalus")
self.assertEqual(claims["libs"], ["lib-a", "lib-b"])
self.assertEqual(claims["ws"], "ws-42")
self.assertEqual(claims["kid"], self.key.kid)
self.assertNotIn("typ", claims) # per-turn tokens omit typ
self.assertNotIn("team_id", claims)
def test_libs_defaults_to_empty_list(self):
payload = _per_turn_payload()
payload.pop("libs", None)
token = _encode(payload, self.key)
claims = resolve_mcp_jwt(token)
self.assertEqual(claims["libs"], [])
def test_ws_defaults_to_none(self):
payload = _per_turn_payload()
payload.pop("ws", None)
token = _encode(payload, self.key)
claims = resolve_mcp_jwt(token)
self.assertIsNone(claims["ws"])
def test_invalid_libs_shape_rejected(self):
payload = _per_turn_payload(libs="not-a-list")
token = _encode(payload, self.key)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_unknown_issuer_rejected(self):
payload = _per_turn_payload(iss="attacker")
token = _encode(payload, self.key)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_expired_token_rejected(self):
payload = _per_turn_payload(
iat=int(time.time()) - 600,
exp=int(time.time()) - 60,
)
token = _encode(payload, self.key)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_retired_signing_key_rejected(self):
self.key.is_active = False
self.key.save(update_fields=["is_active"])
token = _encode(_per_turn_payload(), self.key)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_unknown_kid_rejected(self):
foreign_secret_hex = "b" * 64
token = pyjwt.encode(
_per_turn_payload(),
bytes.fromhex(foreign_secret_hex),
algorithm="HS256",
headers={"kid": "no-such-kid"},
)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_bad_signature_rejected(self):
# Re-sign with a different secret but reuse the stored kid.
foreign_secret_hex = "b" * 64
token = pyjwt.encode(
_per_turn_payload(),
bytes.fromhex(foreign_secret_hex),
algorithm="HS256",
headers={"kid": self.key.kid},
)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
class PerTurnReplayCacheTest(TestCase):
"""``_remember_jti`` flags a jti as replay iff we see it *after* its exp."""
def setUp(self):
auth_module._JTI_CACHE.clear()
def test_first_sighting_not_replay(self):
self.assertFalse(_remember_jti("jti-1", time.time() + 300))
def test_same_jti_within_validity_not_replay(self):
# Daedalus routinely reuses a per-turn jti across multiple tool
# calls in a single turn; that must not trip the replay gate.
exp = time.time() + 300
_remember_jti("jti-2", exp)
self.assertFalse(_remember_jti("jti-2", exp))
self.assertFalse(_remember_jti("jti-2", exp))
def test_same_jti_after_exp_is_replay(self):
exp = time.time() - 60 # already past
# First sighting just records; second is the belt-and-braces trip.
_remember_jti("jti-3", exp)
self.assertTrue(_remember_jti("jti-3", exp))
# ---------------------------------------------------------------------------
# Team JWT branch (iss=mnemosyne, typ=team)
# ---------------------------------------------------------------------------
class ResolveTeamJWTTest(TestCase):
def setUp(self):
self.key = _make_signing_key()
self.team = Team.objects.create(
id=uuid.uuid4(),
name="Pallas-Harper",
active=True,
active_jti=uuid.uuid4(),
)
def test_team_jwt_happy_path_populates_team_id(self):
token = _encode(_team_payload(self.team), self.key)
claims = resolve_mcp_jwt(token)
self.assertEqual(claims["typ"], "team")
self.assertEqual(claims["team_id"], self.team.id)
# Team tokens must never be normalized into the per-turn shape.
self.assertNotIn("libs", claims)
self.assertNotIn("ws", claims)
def test_malformed_sub_rejected(self):
token = _encode(
_team_payload(self.team, sub="not-a-team-sub"), self.key
)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_non_uuid_sub_rejected(self):
token = _encode(
_team_payload(self.team, sub="team:not-a-uuid"), self.key
)
with self.assertRaises(MCPAuthError):
resolve_mcp_jwt(token)
def test_team_jwt_bypasses_replay_cache(self):
"""``typ=team`` tokens are intentionally reused; no jti pollution."""
auth_module._JTI_CACHE.clear()
token = _encode(_team_payload(self.team), self.key)
resolve_mcp_jwt(token)
resolve_mcp_jwt(token) # a second validation must succeed
# And the replay cache must not have been touched.
self.assertEqual(auth_module._JTI_CACHE, {})
# ---------------------------------------------------------------------------
# _libraries_for_team — Postgres + mocked Neo4j
# ---------------------------------------------------------------------------
class LibrariesForTeamTest(TestCase):
def setUp(self):
self.team = Team.objects.create(
id=uuid.uuid4(),
name="T",
active=True,
active_jti=uuid.uuid4(),
)
def test_unknown_team_rejected(self):
with self.assertRaises(MCPAuthError):
_libraries_for_team(uuid.uuid4(), str(uuid.uuid4()))
def test_inactive_team_rejected(self):
self.team.active = False
self.team.save(update_fields=["active"])
with self.assertRaises(MCPAuthError):
_libraries_for_team(self.team.id, str(self.team.active_jti))
def test_stale_jti_rejected(self):
# Incoming jti no longer matches — rotation happened since mint.
with self.assertRaises(MCPAuthError):
_libraries_for_team(self.team.id, str(uuid.uuid4()))
def test_no_workspace_assignments_returns_empty_list(self):
# Fail-closed: a team with no assignments sees no libraries.
libs = _libraries_for_team(self.team.id, str(self.team.active_jti))
self.assertEqual(libs, [])
def test_workspace_assignments_translated_to_library_uids(self):
TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-a")
TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-b")
fake_rows = [["lib-1"], ["lib-2"], ["lib-3"]]
with mock.patch(
"neomodel.db.cypher_query",
return_value=(fake_rows, ["l.uid"]),
) as cypher:
libs = _libraries_for_team(
self.team.id, str(self.team.active_jti)
)
self.assertEqual(sorted(libs), ["lib-1", "lib-2", "lib-3"])
# Assert we pass the workspace ids as the ``ws`` parameter, which
# the Cypher query projects through to the Neo4j driver.
(args, kwargs) = cypher.call_args
self.assertIn("WHERE l.workspace_id IN $ws", args[0])
self.assertEqual(sorted(args[1]["ws"]), ["ws-a", "ws-b"])
# ---------------------------------------------------------------------------
# _resolved_libraries_for_jwt dispatcher
# ---------------------------------------------------------------------------
class ResolvedLibrariesForJWTDispatcherTest(TestCase):
def test_per_turn_claims_use_libs(self):
claims = {"libs": ["one", "two"]}
self.assertEqual(
_resolved_libraries_for_jwt(claims), ["one", "two"]
)
def test_team_claims_use_team_lookup(self):
team = Team.objects.create(
id=uuid.uuid4(),
name="T",
active=True,
active_jti=uuid.uuid4(),
)
claims = {
"typ": "team",
"team_id": team.id,
"jti": str(team.active_jti),
}
with mock.patch(
"neomodel.db.cypher_query",
return_value=([["lib-team"]], ["l.uid"]),
):
TeamWorkspaceAssignment.objects.create(
team=team, workspace_id="ws-x"
)
out = _resolved_libraries_for_jwt(claims)
self.assertEqual(out, ["lib-team"])

View File

@@ -0,0 +1,161 @@
"""Tests for the ``backfill_library_memberships`` management command.
The command queries ``library.models.Library.nodes`` (Neo4j), which we
don't want to exercise in unit tests. We substitute a lightweight fake
library iterable onto ``Library.nodes.filter`` via ``unittest.mock``.
"""
from __future__ import annotations
from io import StringIO
from unittest import mock
from django.contrib.auth import get_user_model
from django.core.management import call_command
from django.core.management.base import CommandError
from django.test import TestCase
from mcp_server.models import LibraryMembership
User = get_user_model()
class _FakeLibrary:
def __init__(self, uid: str, workspace_id=None):
self.uid = uid
self.workspace_id = workspace_id
class _FakeNodes:
"""Stand-in for the neomodel ``Library.nodes`` node set."""
def __init__(self, all_libs):
self._all = list(all_libs)
def filter(self, **kwargs):
out = list(self._all)
if "workspace_id__isnull" in kwargs:
want = kwargs["workspace_id__isnull"]
out = [l for l in out if (l.workspace_id is None) == want]
return out
def _run(*cli_args):
"""Invoke the command with stdout/stderr buffered. Returns stdout."""
buf = StringIO()
call_command(
"backfill_library_memberships", *cli_args, stdout=buf, stderr=StringIO()
)
return buf.getvalue()
class BackfillCommandTest(TestCase):
def setUp(self):
self.superuser = User.objects.create_superuser(
username="admin", password="pw", email="a@b.c"
)
self.patcher = mock.patch(
"library.models.Library.nodes",
new=_FakeNodes(
[
_FakeLibrary("lib-a"), # global
_FakeLibrary("lib-b"), # global
_FakeLibrary("lib-ws", "ws-1"), # workspace-scoped
]
),
)
self.patcher.start()
self.addCleanup(self.patcher.stop)
def test_assigns_owner_to_all_global_libraries(self):
_run()
owners = LibraryMembership.objects.filter(user=self.superuser)
self.assertEqual(
sorted(owners.values_list("library_uid", flat=True)),
["lib-a", "lib-b"],
)
self.assertTrue(
all(
m.role == LibraryMembership.Role.OWNER
for m in owners
)
)
def test_skips_workspace_scoped_libraries(self):
_run()
self.assertFalse(
LibraryMembership.objects.filter(library_uid="lib-ws").exists()
)
def test_idempotent(self):
_run()
count_after_first = LibraryMembership.objects.count()
_run()
self.assertEqual(LibraryMembership.objects.count(), count_after_first)
def test_skips_libraries_with_existing_membership(self):
# Pre-seed someone else as manager on lib-a; the command should
# not overwrite or add a second row.
other = User.objects.create_user(username="other", password="pw")
LibraryMembership.objects.create(
user=other,
library_uid="lib-a",
role=LibraryMembership.Role.MANAGER,
)
_run()
# lib-a keeps its original manager-row; the superuser is NOT
# granted on it because any membership existing is an opt-out.
memberships_for_lib_a = LibraryMembership.objects.filter(
library_uid="lib-a"
)
self.assertEqual(memberships_for_lib_a.count(), 1)
self.assertEqual(memberships_for_lib_a.first().user, other)
# lib-b had none, so the superuser got it.
self.assertTrue(
LibraryMembership.objects.filter(
library_uid="lib-b", user=self.superuser
).exists()
)
def test_dry_run_does_not_persist(self):
out = _run("--dry-run")
self.assertIn("would insert", out)
self.assertEqual(LibraryMembership.objects.count(), 0)
def test_user_arg_overrides_default_superuser(self):
target = User.objects.create_user(username="librarian", password="pw")
_run("--user", "librarian")
self.assertTrue(
LibraryMembership.objects.filter(user=target).exists()
)
self.assertFalse(
LibraryMembership.objects.filter(user=self.superuser).exists()
)
def test_unknown_user_raises_command_error(self):
with self.assertRaises(CommandError):
_run("--user", "nobody")
def test_no_superuser_raises_command_error(self):
self.superuser.is_superuser = False
self.superuser.save(update_fields=["is_superuser"])
with self.assertRaises(CommandError):
_run()
class BackfillCommandEmptyNeo4jTest(TestCase):
def setUp(self):
self.superuser = User.objects.create_superuser(
username="admin", password="pw", email="a@b.c"
)
self.patcher = mock.patch(
"library.models.Library.nodes", new=_FakeNodes([])
)
self.patcher.start()
self.addCleanup(self.patcher.stop)
def test_no_libraries_is_noop(self):
_run()
self.assertEqual(LibraryMembership.objects.count(), 0)

View File

@@ -0,0 +1,251 @@
"""Tests for the Team / LibraryMembership / TeamWorkspaceAssignment models.
``MCPToken``'s hash-at-rest semantics live in ``test_token.py``; this
module exercises the new Phase 2 tables introduced by
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §4 plus the
``allowed_libraries`` JSONField attached to the existing
:class:`~mcp_server.models.MCPToken`.
"""
from __future__ import annotations
import uuid
from django.contrib.auth import get_user_model
from django.db import IntegrityError, transaction
from django.test import TestCase
from mcp_server.models import (
LibraryMembership,
MCPSigningKey,
MCPToken,
Team,
TeamWorkspaceAssignment,
)
User = get_user_model()
# ---------------------------------------------------------------------------
# LibraryMembership
# ---------------------------------------------------------------------------
class LibraryMembershipTest(TestCase):
def setUp(self):
self.user = User.objects.create_user(username="u", password="pw")
def test_role_choices_exposed(self):
self.assertEqual(LibraryMembership.Role.OWNER, "owner")
self.assertEqual(LibraryMembership.Role.MANAGER, "manager")
self.assertEqual(LibraryMembership.Role.READER, "reader")
def test_unique_per_user_and_library(self):
LibraryMembership.objects.create(
user=self.user,
library_uid="lib-1",
role=LibraryMembership.Role.OWNER,
)
# Second membership on the same (user, library_uid) must fail
# even if the role differs — callers consolidate to the
# higher role rather than stacking rows.
with transaction.atomic():
with self.assertRaises(IntegrityError):
LibraryMembership.objects.create(
user=self.user,
library_uid="lib-1",
role=LibraryMembership.Role.READER,
)
def test_same_library_different_users_allowed(self):
other = User.objects.create_user(username="u2", password="pw")
LibraryMembership.objects.create(
user=self.user,
library_uid="lib-1",
role=LibraryMembership.Role.OWNER,
)
LibraryMembership.objects.create(
user=other,
library_uid="lib-1",
role=LibraryMembership.Role.READER,
)
self.assertEqual(LibraryMembership.objects.count(), 2)
def test_same_user_different_libraries_allowed(self):
LibraryMembership.objects.create(
user=self.user,
library_uid="lib-1",
role=LibraryMembership.Role.OWNER,
)
LibraryMembership.objects.create(
user=self.user,
library_uid="lib-2",
role=LibraryMembership.Role.MANAGER,
)
self.assertEqual(
set(
LibraryMembership.objects.filter(user=self.user)
.values_list("library_uid", flat=True)
),
{"lib-1", "lib-2"},
)
# ---------------------------------------------------------------------------
# MCPToken.allowed_libraries
# ---------------------------------------------------------------------------
class MCPTokenAllowedLibrariesTest(TestCase):
def setUp(self):
self.user = User.objects.create_user(username="u", password="pw")
def test_defaults_to_empty_list(self):
token, _ = MCPToken.objects.create_token(user=self.user, name="t")
self.assertEqual(token.allowed_libraries, [])
def test_create_token_accepts_allowed_libraries(self):
token, _ = MCPToken.objects.create_token(
user=self.user, name="t", allowed_libraries=["lib-a", "lib-b"]
)
self.assertEqual(token.allowed_libraries, ["lib-a", "lib-b"])
def test_allowed_libraries_round_trips(self):
token, _ = MCPToken.objects.create_token(
user=self.user,
name="t",
allowed_libraries=["lib-a", "lib-b", "lib-c"],
)
token.refresh_from_db()
self.assertEqual(
token.allowed_libraries, ["lib-a", "lib-b", "lib-c"]
)
# ---------------------------------------------------------------------------
# Team + rotate_jti / deactivate
# ---------------------------------------------------------------------------
class TeamTest(TestCase):
def test_create_with_explicit_uuid(self):
tid = uuid.uuid4()
team = Team.objects.create(id=tid, name="Harper")
self.assertEqual(team.id, tid)
self.assertTrue(team.active)
self.assertIsNone(team.active_jti)
def test_rotate_jti_installs_fresh_uuid(self):
team = Team.objects.create(id=uuid.uuid4(), name="t")
first = team.rotate_jti()
self.assertIsInstance(first, uuid.UUID)
self.assertEqual(team.active_jti, first)
second = team.rotate_jti()
self.assertNotEqual(first, second)
self.assertEqual(team.active_jti, second)
def test_rotate_jti_persists(self):
team = Team.objects.create(id=uuid.uuid4(), name="t")
team.rotate_jti()
# Reload from DB and make sure the UUID was committed.
reloaded = Team.objects.get(pk=team.id)
self.assertEqual(reloaded.active_jti, team.active_jti)
def test_deactivate_clears_active_jti(self):
team = Team.objects.create(id=uuid.uuid4(), name="t")
team.rotate_jti()
self.assertTrue(team.active)
team.deactivate()
self.assertFalse(team.active)
self.assertIsNone(team.active_jti)
# And it persisted.
reloaded = Team.objects.get(pk=team.id)
self.assertFalse(reloaded.active)
self.assertIsNone(reloaded.active_jti)
# ---------------------------------------------------------------------------
# TeamWorkspaceAssignment
# ---------------------------------------------------------------------------
class TeamWorkspaceAssignmentTest(TestCase):
def setUp(self):
self.team = Team.objects.create(id=uuid.uuid4(), name="t")
def test_unique_team_workspace_pair(self):
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
with transaction.atomic():
with self.assertRaises(IntegrityError):
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
def test_same_workspace_different_teams_allowed(self):
other = Team.objects.create(id=uuid.uuid4(), name="t2")
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
TeamWorkspaceAssignment.objects.create(
team=other, workspace_id="ws-a"
)
self.assertEqual(TeamWorkspaceAssignment.objects.count(), 2)
def test_cascade_on_team_delete(self):
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
self.team.delete()
self.assertEqual(TeamWorkspaceAssignment.objects.count(), 0)
def test_related_name_workspace_assignments(self):
# auth._libraries_for_team reaches through ``team.workspace_assignments``;
# make sure that attribute actually works.
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-b"
)
ws_ids = sorted(
self.team.workspace_assignments.values_list(
"workspace_id", flat=True
)
)
self.assertEqual(ws_ids, ["ws-a", "ws-b"])
# ---------------------------------------------------------------------------
# MCPSigningKey.objects.current() — used by mint_team_jwt
# ---------------------------------------------------------------------------
class MCPSigningKeyCurrentTest(TestCase):
def test_none_when_no_keys(self):
self.assertIsNone(MCPSigningKey.objects.current())
def test_returns_active_not_retired(self):
retired = MCPSigningKey.objects.create(
kid="old", secret_hex="a" * 64, is_active=False
)
active = MCPSigningKey.objects.create(
kid="new", secret_hex="b" * 64, is_active=True
)
self.assertEqual(MCPSigningKey.objects.current().pk, active.pk)
self.assertNotEqual(MCPSigningKey.objects.current().pk, retired.pk)
def test_returns_newest_when_multiple_active(self):
older = MCPSigningKey.objects.create(
kid="older", secret_hex="a" * 64, is_active=True
)
newer = MCPSigningKey.objects.create(
kid="newer", secret_hex="b" * 64, is_active=True
)
self.assertEqual(MCPSigningKey.objects.current().pk, newer.pk)
# Sanity: older is still active, just older.
self.assertTrue(
MCPSigningKey.objects.filter(pk=older.pk, is_active=True).exists()
)

View File

@@ -0,0 +1,359 @@
"""Tests for the ``/mcp_server/api/teams/`` REST control plane.
This is the Daedalus-facing surface described in §7 of
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md``. We do NOT exercise HTTP
Basic auth here (that's part of DRF / the project's session auth
stack); instead we use :meth:`APIClient.force_authenticate` to focus
on the endpoints' own idempotence and state-transition rules.
"""
from __future__ import annotations
import uuid
import jwt as pyjwt
from django.contrib.auth import get_user_model
from django.test import TestCase
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient
from mcp_server.models import (
MCPSigningKey,
Team,
TeamWorkspaceAssignment,
)
User = get_user_model()
def _seed_signing_key() -> MCPSigningKey:
return MCPSigningKey.objects.create(
kid=f"test-{uuid.uuid4().hex[:6]}",
secret_hex="a" * 64,
is_active=True,
)
class _AuthenticatedAPITest(TestCase):
"""Shared ``APIClient`` authenticated as the service user.
The real deployment has Daedalus hit these endpoints as
``daedalus-service`` over HTTP Basic, but the view decorator is a
plain ``IsAuthenticated`` — so for unit-test purposes we use
``force_authenticate`` with any active user.
"""
def setUp(self):
self.service_user = User.objects.create_user(
username="daedalus-service", password="pw"
)
self.client = APIClient()
self.client.force_authenticate(user=self.service_user)
# ---------------------------------------------------------------------------
# POST /mcp_server/api/teams/
# ---------------------------------------------------------------------------
class TeamCreateTest(_AuthenticatedAPITest):
def setUp(self):
super().setUp()
self.url = reverse("mcp-server-api:team-create")
_seed_signing_key()
def test_requires_authentication(self):
self.client.force_authenticate(user=None)
resp = self.client.post(
self.url, {"id": str(uuid.uuid4()), "name": "t"}, format="json"
)
self.assertIn(resp.status_code, (401, 403))
def test_rejects_missing_fields(self):
resp = self.client.post(self.url, {}, format="json")
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("id", resp.data)
self.assertIn("name", resp.data)
def test_rejects_non_uuid_id(self):
resp = self.client.post(
self.url, {"id": "not-a-uuid", "name": "t"}, format="json"
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
def test_creates_team_and_returns_jwt(self):
tid = uuid.uuid4()
resp = self.client.post(
self.url, {"id": str(tid), "name": "Harper"}, format="json"
)
self.assertEqual(resp.status_code, status.HTTP_201_CREATED)
self.assertEqual(uuid.UUID(resp.data["id"]), tid)
self.assertEqual(resp.data["name"], "Harper")
self.assertTrue(resp.data["active"])
self.assertIn("jwt", resp.data)
# JWT decodes and carries the right sub + jti.
team = Team.objects.get(pk=tid)
header = pyjwt.get_unverified_header(resp.data["jwt"])
key = MCPSigningKey.objects.by_kid(header["kid"])
decoded = pyjwt.decode(
resp.data["jwt"],
bytes.fromhex(key.secret_hex),
algorithms=["HS256"],
options={"verify_aud": False},
)
self.assertEqual(decoded["sub"], f"team:{tid}")
self.assertEqual(decoded["jti"], str(team.active_jti))
def test_idempotent_on_same_id_returns_200_without_jwt(self):
tid = uuid.uuid4()
first = self.client.post(
self.url, {"id": str(tid), "name": "first"}, format="json"
)
self.assertEqual(first.status_code, status.HTTP_201_CREATED)
first_jti = Team.objects.get(pk=tid).active_jti
# Second POST with same id: idempotent hit. Must NOT rotate the
# jti (otherwise every retry storm re-issues a fresh credential).
second = self.client.post(
self.url,
{"id": str(tid), "name": "ignored-on-hit"},
format="json",
)
self.assertEqual(second.status_code, status.HTTP_200_OK)
self.assertNotIn("jwt", second.data)
self.assertEqual(
Team.objects.get(pk=tid).active_jti, first_jti
)
def test_mint_failure_returns_503(self):
# Retire all signing keys so mint_team_jwt cannot succeed.
MCPSigningKey.objects.update(is_active=False)
resp = self.client.post(
self.url,
{"id": str(uuid.uuid4()), "name": "x"},
format="json",
)
self.assertEqual(
resp.status_code, status.HTTP_503_SERVICE_UNAVAILABLE
)
# The transaction.atomic() wrapper must also have rolled back the
# Team row so we don't leave a team with no usable JWT.
self.assertEqual(Team.objects.count(), 0)
# ---------------------------------------------------------------------------
# GET / DELETE /mcp_server/api/teams/{id}/
# ---------------------------------------------------------------------------
class TeamDetailTest(_AuthenticatedAPITest):
def setUp(self):
super().setUp()
self.team = Team.objects.create(
id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4()
)
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-a"
)
TeamWorkspaceAssignment.objects.create(
team=self.team, workspace_id="ws-b"
)
self.url = reverse(
"mcp-server-api:team-detail", kwargs={"team_id": self.team.id}
)
def test_get_returns_team_state(self):
resp = self.client.get(self.url)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(uuid.UUID(resp.data["id"]), self.team.id)
self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"])
# Never leak the JWT via GET.
self.assertNotIn("jwt", resp.data)
def test_get_unknown_team_returns_404(self):
url = reverse(
"mcp-server-api:team-detail", kwargs={"team_id": uuid.uuid4()}
)
resp = self.client.get(url)
self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
def test_delete_soft_deletes(self):
resp = self.client.delete(self.url)
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
reloaded = Team.objects.get(pk=self.team.id)
self.assertFalse(reloaded.active)
self.assertIsNone(reloaded.active_jti)
# Workspace rows stay for audit — deactivation only flips flags.
self.assertEqual(
TeamWorkspaceAssignment.objects.filter(team=reloaded).count(),
2,
)
def test_delete_idempotent(self):
self.client.delete(self.url) # first — 204
resp = self.client.delete(self.url)
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
# ---------------------------------------------------------------------------
# PUT /mcp_server/api/teams/{id}/workspaces/
# ---------------------------------------------------------------------------
class TeamWorkspacesTest(_AuthenticatedAPITest):
def setUp(self):
super().setUp()
self.team = Team.objects.create(
id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4()
)
self.url = reverse(
"mcp-server-api:team-workspaces",
kwargs={"team_id": self.team.id},
)
def _ws_ids(self):
return sorted(
self.team.workspace_assignments.values_list(
"workspace_id", flat=True
)
)
def test_unknown_team_returns_404(self):
url = reverse(
"mcp-server-api:team-workspaces",
kwargs={"team_id": uuid.uuid4()},
)
resp = self.client.put(
url, {"workspace_ids": ["x"]}, format="json"
)
self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
def test_replace_adds_all(self):
resp = self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-b"]},
format="json",
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"])
self.assertEqual(self._ws_ids(), ["ws-a", "ws-b"])
def test_replace_idempotent_second_call_is_noop(self):
self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-b"]},
format="json",
)
existing = list(self.team.workspace_assignments.all())
pks_before = sorted(a.pk for a in existing)
# Second identical PUT should not re-create rows.
resp = self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-b"]},
format="json",
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
pks_after = sorted(
a.pk for a in self.team.workspace_assignments.all()
)
self.assertEqual(pks_before, pks_after)
def test_replace_removes_dropped(self):
# Start with a, b, c
self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-b", "ws-c"]},
format="json",
)
# Drop b, add d.
resp = self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-c", "ws-d"]},
format="json",
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(self._ws_ids(), ["ws-a", "ws-c", "ws-d"])
def test_replace_with_empty_set_fail_closed(self):
self.client.put(
self.url,
{"workspace_ids": ["ws-a"]},
format="json",
)
resp = self.client.put(
self.url, {"workspace_ids": []}, format="json"
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(self._ws_ids(), [])
def test_duplicates_deduped(self):
resp = self.client.put(
self.url,
{"workspace_ids": ["ws-a", "ws-a", "ws-b"]},
format="json",
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(sorted(resp.data["workspace_ids"]), ["ws-a", "ws-b"])
self.assertEqual(self._ws_ids(), ["ws-a", "ws-b"])
def test_empty_string_rejected(self):
resp = self.client.put(
self.url,
{"workspace_ids": ["ws-a", ""]},
format="json",
)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
# ---------------------------------------------------------------------------
# POST /mcp_server/api/teams/{id}/rotate/
# ---------------------------------------------------------------------------
class TeamRotateTest(_AuthenticatedAPITest):
def setUp(self):
super().setUp()
_seed_signing_key()
self.team = Team.objects.create(
id=uuid.uuid4(), name="t", active=True, active_jti=uuid.uuid4()
)
self.url = reverse(
"mcp-server-api:team-rotate",
kwargs={"team_id": self.team.id},
)
def test_unknown_team_returns_404(self):
url = reverse(
"mcp-server-api:team-rotate",
kwargs={"team_id": uuid.uuid4()},
)
resp = self.client.post(url)
self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)
def test_rotate_returns_new_jwt_and_changes_active_jti(self):
before = self.team.active_jti
resp = self.client.post(self.url)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIn("jwt", resp.data)
self.team.refresh_from_db()
self.assertNotEqual(self.team.active_jti, before)
def test_rotate_inactive_team_409(self):
self.team.deactivate()
resp = self.client.post(self.url)
self.assertEqual(resp.status_code, status.HTTP_409_CONFLICT)
# And we did NOT revive the team by side effect.
self.team.refresh_from_db()
self.assertFalse(self.team.active)
self.assertIsNone(self.team.active_jti)
def test_rotate_without_signing_key_returns_503(self):
MCPSigningKey.objects.update(is_active=False)
resp = self.client.post(self.url)
self.assertEqual(
resp.status_code, status.HTTP_503_SERVICE_UNAVAILABLE
)

View File

@@ -0,0 +1,111 @@
"""Tests for :mod:`mcp_server.teams` — team-JWT mint.
Covers the mint path in isolation. The *validation* half of the
round-trip (``resolve_mcp_jwt`` with ``typ=team``) lives in
``test_auth.py``; this module asserts what ``mint_team_jwt`` encodes
into the token header / payload so the two halves match up.
"""
from __future__ import annotations
import time
import uuid
import jwt as pyjwt
from django.test import TestCase
from mcp_server.models import MCPSigningKey, Team
from mcp_server.teams import TeamJWTError, mint_team_jwt
def _make_key(kid: str = "k", is_active: bool = True) -> MCPSigningKey:
return MCPSigningKey.objects.create(
kid=kid, secret_hex="a" * 64, is_active=is_active
)
def _make_team(**overrides) -> Team:
data = dict(
id=uuid.uuid4(),
name="t",
active=True,
active_jti=uuid.uuid4(),
)
data.update(overrides)
return Team.objects.create(**data)
class MintTeamJWTHappyPathTest(TestCase):
def setUp(self):
self.key = _make_key("k-1")
self.team = _make_team()
def test_returns_signed_jwt_string(self):
token = mint_team_jwt(self.team)
self.assertIsInstance(token, str)
# Three segments, HS256 header.
self.assertEqual(token.count("."), 2)
header = pyjwt.get_unverified_header(token)
self.assertEqual(header["alg"], "HS256")
self.assertEqual(header["kid"], self.key.kid)
def test_payload_matches_spec(self):
before = int(time.time())
token = mint_team_jwt(self.team)
after = int(time.time())
decoded = pyjwt.decode(
token,
bytes.fromhex(self.key.secret_hex),
algorithms=["HS256"],
options={"verify_aud": False},
)
self.assertEqual(decoded["iss"], "mnemosyne")
self.assertEqual(decoded["aud"], "mnemosyne")
self.assertEqual(decoded["typ"], "team")
self.assertEqual(decoded["sub"], f"team:{self.team.id}")
self.assertEqual(decoded["jti"], str(self.team.active_jti))
# iat is "now-ish" (inclusive bounds because the call is fast).
self.assertGreaterEqual(decoded["iat"], before)
self.assertLessEqual(decoded["iat"], after)
# exp is 10 years in the future.
ten_years = 10 * 365 * 24 * 60 * 60
self.assertEqual(decoded["exp"], decoded["iat"] + ten_years)
def test_picks_newest_active_signing_key(self):
newer = _make_key("k-2")
token = mint_team_jwt(self.team)
header = pyjwt.get_unverified_header(token)
self.assertEqual(header["kid"], newer.kid)
def test_ignores_retired_keys(self):
# Retire the only active key; newer retired key should still
# be skipped — mint fails because there is no active key.
self.key.is_active = False
self.key.save(update_fields=["is_active"])
with self.assertRaises(TeamJWTError):
mint_team_jwt(self.team)
class MintTeamJWTFailureModesTest(TestCase):
def test_no_signing_key_raises(self):
team = _make_team()
with self.assertRaises(TeamJWTError) as ctx:
mint_team_jwt(team)
self.assertIn("signing", str(ctx.exception).lower())
def test_missing_active_jti_raises(self):
_make_key()
team = _make_team(active_jti=None)
with self.assertRaises(TeamJWTError) as ctx:
mint_team_jwt(team)
# Should name the thing the caller forgot to do.
self.assertIn("rotate_jti", str(ctx.exception))
def test_invalid_hex_secret_raises(self):
MCPSigningKey.objects.create(
kid="broken", secret_hex="not-hex!!", is_active=True
)
team = _make_team()
with self.assertRaises(TeamJWTError):
mint_team_jwt(team)