- Rename MCPToken to UserToken across models, views, and tests - Update URL names from mcp-token-* to token-* - Add Daedalus/Pallas integration design doc (v2) - Switch docker-compose to build local mnemosyne:local image via shared build config instead of pulling from git.helu.ca
543 lines
18 KiB
Python
543 lines
18 KiB
Python
"""Tests for the MCP auth surface.

Covers all three credential types described in
``docs/DAEDALUS_PALLAS_INTEGRATION_v1.md`` §3.2:

1. Opaque :class:`~mcp_server.models.UserToken` — ``resolve_mcp_user``
   + ``MCPAuthMiddleware`` opaque branch.
2. Per-turn JWT (``iss=daedalus``, legacy) — ``resolve_mcp_jwt`` normal
   path, ``_remember_jti`` replay cache, ``claims["libs"]``-derived
   resolved_libraries.
3. Team JWT (``iss=mnemosyne typ=team``) — ``resolve_mcp_jwt`` team
   branch (bypasses replay cache), ``_libraries_for_team`` enforcing
   ``Team.active`` + ``Team.active_jti``, ``resolved_libraries`` coming
   from a mocked Neo4j lookup.

Neo4j is stubbed by patching ``neomodel.db.cypher_query`` — these tests
do not require a running Neo4j instance.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import uuid
|
|
from datetime import timedelta
|
|
from unittest import mock
|
|
|
|
import jwt as pyjwt
|
|
from django.contrib.auth import get_user_model
|
|
from django.test import TestCase
|
|
from django.utils import timezone
|
|
|
|
from mcp_server import auth as auth_module
|
|
from mcp_server.auth import (
|
|
MCPAuthError,
|
|
_libraries_for_team,
|
|
_remember_jti,
|
|
_resolve_jwt_actor,
|
|
_resolved_libraries_for_jwt,
|
|
looks_like_jwt,
|
|
resolve_mcp_jwt,
|
|
resolve_mcp_user,
|
|
)
|
|
from mcp_server.models import (
|
|
MCPSigningKey,
|
|
UserToken,
|
|
Team,
|
|
TeamWorkspaceAssignment,
|
|
)
|
|
|
|
# Resolve the configured user model once so every test below creates users
# through the same (possibly custom) model class.
User = get_user_model()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Opaque UserToken
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ResolveMCPUserTest(TestCase):
    """The pre-existing coverage of the opaque bearer path.

    ``setUp`` mints one :class:`UserToken` for an active user. The returned
    ``plaintext`` is what a client presents as its bearer token and must
    never be persisted.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username="bob", email="bob@example.com", password="pw"
        )
        self.token, self.plaintext = UserToken.objects.create_token(
            user=self.user, name="t"
        )

    def test_resolves_valid_token(self):
        user, token = resolve_mcp_user(self.plaintext)
        self.assertEqual(user.pk, self.user.pk)
        self.assertEqual(token.pk, self.token.pk)

    def test_records_usage(self):
        self.assertIsNone(self.token.last_used_at)
        resolve_mcp_user(self.plaintext)
        self.token.refresh_from_db()
        self.assertIsNotNone(self.token.last_used_at)

    def test_invalid_token_raises(self):
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user("not-a-real-token")

    def test_inactive_token_raises(self):
        self.token.is_active = False
        self.token.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_expired_token_raises(self):
        self.token.expires_at = timezone.now() - timedelta(hours=1)
        self.token.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_disabled_user_raises(self):
        self.user.is_active = False
        self.user.save()
        with self.assertRaises(MCPAuthError):
            resolve_mcp_user(self.plaintext)

    def test_plaintext_not_in_db(self):
        # Defense in depth: scan every column of every token row for the
        # plaintext, both as an exact value and embedded in a longer value.
        from django.db import connection

        plaintext = self.plaintext
        # Take the table name from the model instead of hard-coding it, so a
        # model rename or a custom ``Meta.db_table`` cannot point the scan at
        # a stale table name.
        table = connection.ops.quote_name(UserToken._meta.db_table)
        with connection.cursor() as cur:
            cur.execute(f"SELECT * FROM {table}")
            rows = cur.fetchall()

        # setUp minted a token, so an empty result means the scan looked in
        # the wrong place and would otherwise pass vacuously.
        self.assertTrue(rows, "No UserToken rows found to scan")
        for row in rows:
            for value in row:
                message = f"Plaintext token leaked into the database: {value!r}"
                self.assertNotEqual(value, plaintext, message)
                if isinstance(value, str):
                    self.assertNotIn(plaintext, value, message)
                elif isinstance(value, (bytes, bytearray, memoryview)):
                    self.assertNotIn(plaintext.encode(), bytes(value), message)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# JWT plumbing helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_signing_key() -> MCPSigningKey:
    """Create and persist an active HS256 signing key for one test.

    The ``kid`` gets a random 8-hex-digit suffix so keys created by separate
    calls are very unlikely to clash. The secret is a fixed 256-bit value
    written as 64 hex digits, which is the form the auth module decodes.
    """
    unique_suffix = uuid.uuid4().hex[:8]
    return MCPSigningKey.objects.create(
        kid="test-kid-" + unique_suffix,
        secret_hex="aa" * 32,  # 32 bytes == 256 bits == 64 hex chars
        is_active=True,
    )
|
|
|
|
|
|
def _encode(payload: dict, key: MCPSigningKey) -> str:
    """Sign *payload* as an HS256 JWT with *key*, advertising its ``kid``.

    The key's hex secret is decoded to raw bytes before signing.
    """
    secret = bytes.fromhex(key.secret_hex)
    header_fields = {"kid": key.kid}
    return pyjwt.encode(payload, secret, algorithm="HS256", headers=header_fields)
|
|
|
|
|
|
def _per_turn_payload(**overrides) -> dict:
|
|
now = int(time.time())
|
|
payload = {
|
|
"iss": "daedalus",
|
|
"sub": "user:someone",
|
|
"iat": now,
|
|
"exp": now + 300,
|
|
"jti": str(uuid.uuid4()),
|
|
"libs": ["lib-a", "lib-b"],
|
|
"ws": "ws-42",
|
|
}
|
|
payload.update(overrides)
|
|
return payload
|
|
|
|
|
|
def _team_payload(team: Team, **overrides) -> dict:
|
|
now = int(time.time())
|
|
payload = {
|
|
"iss": "mnemosyne",
|
|
"aud": "mnemosyne",
|
|
"sub": f"team:{team.id}",
|
|
"typ": "team",
|
|
"iat": now,
|
|
"exp": now + 3600,
|
|
"jti": str(team.active_jti),
|
|
}
|
|
payload.update(overrides)
|
|
return payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# looks_like_jwt
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class LooksLikeJWTTest(TestCase):
    """Shape heuristics that tell a JWT apart from an opaque bearer token."""

    def test_three_segments_with_json_header(self):
        signing_key = _make_signing_key()
        candidate = _encode(_per_turn_payload(), signing_key)
        self.assertTrue(looks_like_jwt(candidate))

    def test_two_segments_rejected(self):
        candidate = "aaa.bbb"
        self.assertFalse(looks_like_jwt(candidate))

    def test_four_segments_rejected(self):
        candidate = "aaa.bbb.ccc.ddd"
        self.assertFalse(looks_like_jwt(candidate))

    def test_garbage_header_rejected(self):
        # Three segments, but the first one is not a decodable JSON header.
        candidate = "!!!.bbb.ccc"
        self.assertFalse(looks_like_jwt(candidate))

    def test_opaque_token_rejected(self):
        # ``UserToken.create_token`` plaintext is 48-byte base64. It can hold
        # dashes, but never the two dots of a three-segment JWT.
        opaque = "CxGb3rThJ7_4jUGl0q2_fakey_fakey_fakey_fakey_fakey"
        self.assertFalse(looks_like_jwt(opaque))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Per-turn JWT branch (iss=daedalus)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ResolvePerTurnJWTTest(TestCase):
    """``resolve_mcp_jwt`` on legacy per-turn tokens (``iss=daedalus``)."""

    # A 256-bit secret that differs from the one ``_make_signing_key``
    # stores; used to forge tokens the resolver must refuse.
    FOREIGN_SECRET_HEX = "b" * 64

    def setUp(self):
        self.key = _make_signing_key()
        # The replay cache is a module-level dict. Start every test clean,
        # and wipe it again afterwards so jtis recorded here cannot leak
        # into tests that run later.
        auth_module._JTI_CACHE.clear()
        self.addCleanup(lambda: auth_module._JTI_CACHE.clear())

    def _encode_with_foreign_secret(self, kid: str) -> str:
        """Sign a default per-turn payload with ``FOREIGN_SECRET_HEX``.

        Args:
            kid: Key id to advertise in the JWT header.
        """
        return pyjwt.encode(
            _per_turn_payload(),
            bytes.fromhex(self.FOREIGN_SECRET_HEX),
            algorithm="HS256",
            headers={"kid": kid},
        )

    def test_happy_path_returns_normalized_claims(self):
        token = _encode(_per_turn_payload(), self.key)
        claims = resolve_mcp_jwt(token)
        self.assertEqual(claims["iss"], "daedalus")
        self.assertEqual(claims["libs"], ["lib-a", "lib-b"])
        self.assertEqual(claims["ws"], "ws-42")
        self.assertEqual(claims["kid"], self.key.kid)
        self.assertNotIn("typ", claims)  # per-turn tokens omit typ
        self.assertNotIn("team_id", claims)

    def test_libs_defaults_to_empty_list(self):
        payload = _per_turn_payload()
        payload.pop("libs", None)
        claims = resolve_mcp_jwt(_encode(payload, self.key))
        self.assertEqual(claims["libs"], [])

    def test_ws_defaults_to_none(self):
        payload = _per_turn_payload()
        payload.pop("ws", None)
        claims = resolve_mcp_jwt(_encode(payload, self.key))
        self.assertIsNone(claims["ws"])

    def test_invalid_libs_shape_rejected(self):
        token = _encode(_per_turn_payload(libs="not-a-list"), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_unknown_issuer_rejected(self):
        token = _encode(_per_turn_payload(iss="attacker"), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_expired_token_rejected(self):
        # Issued ten minutes ago, expired one minute ago. Both timestamps
        # derive from a single clock reading.
        now = int(time.time())
        token = _encode(_per_turn_payload(iat=now - 600, exp=now - 60), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_retired_signing_key_rejected(self):
        self.key.is_active = False
        self.key.save(update_fields=["is_active"])
        token = _encode(_per_turn_payload(), self.key)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_unknown_kid_rejected(self):
        token = self._encode_with_foreign_secret("no-such-kid")
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_bad_signature_rejected(self):
        # Re-sign with a different secret but reuse the stored kid.
        token = self._encode_with_foreign_secret(self.key.kid)
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)
|
|
|
|
|
|
class PerTurnReplayCacheTest(TestCase):
    """``_remember_jti`` flags a jti as replay iff we see it *after* its exp."""

    def setUp(self):
        # ``_JTI_CACHE`` is module-level state. Clear it before each test,
        # and again afterwards so the jtis recorded here do not leak into
        # later tests.
        auth_module._JTI_CACHE.clear()
        self.addCleanup(lambda: auth_module._JTI_CACHE.clear())

    def test_first_sighting_not_replay(self):
        self.assertFalse(_remember_jti("jti-1", time.time() + 300))

    def test_same_jti_within_validity_not_replay(self):
        # Daedalus routinely reuses a per-turn jti across multiple tool
        # calls in a single turn; that must not trip the replay gate.
        exp = time.time() + 300
        _remember_jti("jti-2", exp)
        self.assertFalse(_remember_jti("jti-2", exp))
        self.assertFalse(_remember_jti("jti-2", exp))

    def test_same_jti_after_exp_is_replay(self):
        exp = time.time() - 60  # already past
        # First sighting just records; second is the belt-and-braces trip.
        _remember_jti("jti-3", exp)
        self.assertTrue(_remember_jti("jti-3", exp))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Team JWT branch (iss=mnemosyne, typ=team)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ResolveTeamJWTTest(TestCase):
    """``resolve_mcp_jwt`` on Mnemosyne-issued team tokens (``typ=team``)."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.key = _make_signing_key()
        self.team = Team.objects.create(
            id=uuid.uuid4(),
            name="Pallas-Harper",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )

    def _team_token(self, **overrides) -> str:
        """Return a signed team JWT for ``self.team``; *overrides* replace claims."""
        return _encode(_team_payload(self.team, **overrides), self.key)

    def test_team_jwt_happy_path_populates_team_id(self):
        claims = resolve_mcp_jwt(self._team_token())
        self.assertEqual(claims["typ"], "team")
        self.assertEqual(claims["team_id"], self.team.id)
        # Team tokens must never be normalized into the per-turn shape.
        for per_turn_only in ("libs", "ws"):
            self.assertNotIn(per_turn_only, claims)

    def test_malformed_sub_rejected(self):
        token = self._team_token(sub="not-a-team-sub")
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_non_uuid_sub_rejected(self):
        token = self._team_token(sub="team:not-a-uuid")
        with self.assertRaises(MCPAuthError):
            resolve_mcp_jwt(token)

    def test_team_jwt_bypasses_replay_cache(self):
        """``typ=team`` tokens are intentionally reused; no jti pollution."""
        auth_module._JTI_CACHE.clear()
        token = self._team_token()
        # Validate the very same token twice; the second pass must succeed too.
        for _attempt in range(2):
            resolve_mcp_jwt(token)
        # And the replay cache must not have been touched.
        self.assertEqual(auth_module._JTI_CACHE, {})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _libraries_for_team — Postgres + mocked Neo4j
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class LibrariesForTeamTest(TestCase):
    """``_libraries_for_team``: Postgres team checks plus a mocked Neo4j lookup."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )
        # Stub Neo4j for every test in this class, returning an empty result
        # by default. No test can then reach a real Neo4j instance, whichever
        # code path the implementation takes. Tests that need rows override
        # ``self.cypher.return_value``.
        patcher = mock.patch(
            "neomodel.db.cypher_query", return_value=([], [])
        )
        self.cypher = patcher.start()
        self.addCleanup(patcher.stop)

    def test_unknown_team_rejected(self):
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(uuid.uuid4(), str(uuid.uuid4()))

    def test_inactive_team_rejected(self):
        self.team.active = False
        self.team.save(update_fields=["active"])
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(self.team.id, str(self.team.active_jti))

    def test_stale_jti_rejected(self):
        # Incoming jti no longer matches — rotation happened since mint.
        with self.assertRaises(MCPAuthError):
            _libraries_for_team(self.team.id, str(uuid.uuid4()))

    def test_no_workspace_assignments_returns_empty_list(self):
        # Fail-closed: a team with no assignments sees no libraries.
        libs = _libraries_for_team(self.team.id, str(self.team.active_jti))
        self.assertEqual(libs, [])

    def test_workspace_assignments_translated_to_library_uids(self):
        TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-a")
        TeamWorkspaceAssignment.objects.create(team=self.team, workspace_id="ws-b")
        self.cypher.return_value = ([["lib-1"], ["lib-2"], ["lib-3"]], ["l.uid"])

        libs = _libraries_for_team(self.team.id, str(self.team.active_jti))

        self.assertEqual(sorted(libs), ["lib-1", "lib-2", "lib-3"])
        # Assert we pass the workspace ids as the ``ws`` parameter, which
        # the Cypher query projects through to the Neo4j driver.
        args, _kwargs = self.cypher.call_args
        self.assertIn("WHERE l.workspace_id IN $ws", args[0])
        self.assertEqual(sorted(args[1]["ws"]), ["ws-a", "ws-b"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _resolved_libraries_for_jwt dispatcher
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ResolvedLibrariesForJWTDispatcherTest(TestCase):
    """``_resolved_libraries_for_jwt`` picks its library source by claim shape."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def test_per_turn_claims_use_libs(self):
        # Per-turn claims carry their library list inline.
        result = _resolved_libraries_for_jwt({"libs": ["one", "two"]})
        self.assertEqual(result, ["one", "two"])

    def test_team_claims_use_team_lookup(self):
        # Team claims go through the team lookup, with Neo4j stubbed out.
        team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )
        team_claims = {
            "typ": "team",
            "team_id": team.id,
            "jti": str(team.active_jti),
        }
        stub_result = ([["lib-team"]], ["l.uid"])
        with mock.patch("neomodel.db.cypher_query", return_value=stub_result):
            TeamWorkspaceAssignment.objects.create(team=team, workspace_id="ws-x")
            resolved = _resolved_libraries_for_jwt(team_claims)
        self.assertEqual(resolved, ["lib-team"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _resolve_jwt_actor — team JWTs resolve to the team's owner
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ResolveJWTActorTest(TestCase):
    """``_resolve_jwt_actor`` attributes team JWT claims to the team owner."""

    @classmethod
    def setUpTestData(cls):
        cls.owner = User.objects.create_user(username="alice", password="pw")

    def setUp(self):
        self.team = Team.objects.create(
            id=uuid.uuid4(),
            name="T",
            owner=self.owner,
            active=True,
            active_jti=uuid.uuid4(),
        )

    @staticmethod
    def _claims_for(team_id) -> dict:
        """Return minimal team-typed claims naming *team_id*."""
        return {"typ": "team", "team_id": team_id}

    def test_team_jwt_resolves_to_team_owner(self):
        actor = _resolve_jwt_actor(self._claims_for(self.team.id))
        self.assertEqual(actor, self.owner)

    def test_inactive_team_rejected(self):
        self.team.deactivate()
        claims = self._claims_for(self.team.id)
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(claims)

    def test_disabled_owner_rejected(self):
        self.owner.is_active = False
        self.owner.save(update_fields=["is_active"])
        claims = self._claims_for(self.team.id)
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(claims)

    def test_per_turn_jwt_rejected(self):
        # No more service-user fallback: per-turn JWTs can't be attributed
        # to a Mnemosyne user, so the resolver refuses them.
        per_turn_claims = {"libs": ["x"]}
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(per_turn_claims)

    def test_unknown_team_rejected(self):
        claims = self._claims_for(uuid.uuid4())
        with self.assertRaises(MCPAuthError):
            _resolve_jwt_actor(claims)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level import guards (regression tests)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AuthModuleImportsTest(TestCase):
    """Regression guards for module-level imports the runtime relies on.

    Some names in ``mcp_server.auth`` are only touched on runtime paths
    (FastMCP middleware, async tool dispatch, …) that this suite never
    drives end-to-end. Deleting their import therefore slips past both a
    quick ``grep`` and the rest of the suite.

    Whenever production hits a ``NameError: name 'X' is not defined`` that
    the suite missed, pin the missing import here.
    """

    def test_settings_is_importable(self):
        """Keep ``from django.conf import settings`` alive in ``mcp_server.auth``.

        ``MCPAuthMiddleware.on_call_tool`` reads ``settings.MCP_REQUIRE_AUTH``
        on every tool call, including unauthenticated ``get_health`` polls
        from Pallas. The v2 token-consolidation cleanup once removed the
        import, and every MCP client then failed with
        ``NameError: name 'settings' is not defined``.
        """
        from django.conf import settings as django_settings

        self.assertTrue(hasattr(auth_module, "settings"))
        self.assertIs(auth_module.settings, django_settings)
|