Team JWTs include `aud=mnemosyne` while per-turn JWTs omit `aud` entirely. Since `iss` + `typ` already partition the two token populations, explicitly skip audience verification to avoid rejecting valid tokens. Also expand test coverage for the MCP auth surface to exercise all three credential types (opaque MCPToken, per-turn JWT, team JWT), including replay cache behavior and Neo4j-backed library resolution via mocked cypher queries.
112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
"""Tests for :mod:`mcp_server.teams` — team-JWT mint.
|
|
|
|
Covers the mint path in isolation. The *validation* half of the
|
|
round-trip (``resolve_mcp_jwt`` with ``typ=team``) lives in
|
|
``test_auth.py``; this module asserts what ``mint_team_jwt`` encodes
|
|
into the token header / payload so the two halves match up.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
import uuid
|
|
|
|
import jwt as pyjwt
|
|
from django.test import TestCase
|
|
|
|
from mcp_server.models import MCPSigningKey, Team
|
|
from mcp_server.teams import TeamJWTError, mint_team_jwt
|
|
|
|
|
|
def _make_key(kid: str = "k", is_active: bool = True) -> MCPSigningKey:
    """Persist and return an ``MCPSigningKey`` for use as a test fixture.

    The secret is a fixed 64-character hex string (``"a" * 64``) so tests
    can rebuild the signing bytes with ``bytes.fromhex``.

    Args:
        kid: Key identifier stored on the row.
        is_active: Value stored in the row's ``is_active`` field.
    """
    key_fields = {
        "kid": kid,
        "secret_hex": "a" * 64,
        "is_active": is_active,
    }
    return MCPSigningKey.objects.create(**key_fields)
|
|
|
|
|
|
def _make_team(**overrides) -> Team:
    """Persist and return a ``Team`` built from defaults plus *overrides*.

    Defaults are a fresh random ``id``, the name ``"t"``, ``active=True``
    and a fresh random ``active_jti``. Any keyword argument replaces the
    default of the same name or adds an extra field to the create call.
    """
    defaults = {
        "id": uuid.uuid4(),
        "name": "t",
        "active": True,
        "active_jti": uuid.uuid4(),
    }
    # Later keys win, so caller-supplied values take precedence.
    return Team.objects.create(**{**defaults, **overrides})
|
|
|
|
|
|
class MintTeamJWTHappyPathTest(TestCase):
    """Exercise ``mint_team_jwt`` against a usable key and team fixture.

    ``setUp`` creates one active signing key (``k-1``) and one team that
    has an ``active_jti``; each test mints against that pair.
    """

    def setUp(self):
        self.key = _make_key("k-1")
        self.team = _make_team()

    def test_returns_signed_jwt_string(self):
        """The mint result is a three-segment string with an HS256 header."""
        token = mint_team_jwt(self.team)
        self.assertIsInstance(token, str)
        # Three segments, HS256 header.
        self.assertEqual(token.count("."), 2)
        header = pyjwt.get_unverified_header(token)
        self.assertEqual(header["alg"], "HS256")
        self.assertEqual(header["kid"], self.key.kid)

    def test_payload_matches_spec(self):
        """Every claim in the decoded payload carries the expected value."""
        before = int(time.time())
        token = mint_team_jwt(self.team)
        after = int(time.time())

        decoded = pyjwt.decode(
            token,
            bytes.fromhex(self.key.secret_hex),
            algorithms=["HS256"],
            # No expected audience is passed to decode(), so audience
            # verification is switched off here; the ``aud`` claim is
            # asserted explicitly below instead.
            options={"verify_aud": False},
        )
        self.assertEqual(decoded["iss"], "mnemosyne")
        self.assertEqual(decoded["aud"], "mnemosyne")
        self.assertEqual(decoded["typ"], "team")
        self.assertEqual(decoded["sub"], f"team:{self.team.id}")
        self.assertEqual(decoded["jti"], str(self.team.active_jti))
        # iat is "now-ish" (inclusive bounds because the call is fast).
        self.assertGreaterEqual(decoded["iat"], before)
        self.assertLessEqual(decoded["iat"], after)
        # exp is 10 years in the future.
        ten_years = 10 * 365 * 24 * 60 * 60
        self.assertEqual(decoded["exp"], decoded["iat"] + ten_years)

    def test_picks_newest_active_signing_key(self):
        """A second active key created after the fixture key is the signer."""
        newer = _make_key("k-2")
        token = mint_team_jwt(self.team)
        header = pyjwt.get_unverified_header(token)
        self.assertEqual(header["kid"], newer.kid)

    def test_ignores_retired_keys(self):
        """Minting fails when every key, including the newest, is retired."""
        # Retire the only active key, then add a newer key that is already
        # retired. With no active key left, mint must fail rather than fall
        # back to a retired key.
        self.key.is_active = False
        self.key.save(update_fields=["is_active"])
        _make_key("k-retired", is_active=False)
        with self.assertRaises(TeamJWTError):
            mint_team_jwt(self.team)
|
|
|
|
|
|
class MintTeamJWTFailureModesTest(TestCase):
    """Check that ``mint_team_jwt`` raises ``TeamJWTError`` on bad inputs.

    Unlike the happy-path class there is no shared fixture: each test
    builds exactly the rows its failure scenario needs.
    """

    def test_no_signing_key_raises(self):
        """With no signing key rows at all, the error mentions "signing"."""
        orphan_team = _make_team()
        with self.assertRaises(TeamJWTError) as caught:
            mint_team_jwt(orphan_team)
        message = str(caught.exception).lower()
        self.assertIn("signing", message)

    def test_missing_active_jti_raises(self):
        """A team whose ``active_jti`` is ``None`` cannot be minted for."""
        _make_key()
        jti_less_team = _make_team(active_jti=None)
        with self.assertRaises(TeamJWTError) as caught:
            mint_team_jwt(jti_less_team)
        # Should name the thing the caller forgot to do.
        message = str(caught.exception)
        self.assertIn("rotate_jti", message)

    def test_invalid_hex_secret_raises(self):
        """An active key whose secret is not valid hex makes minting fail."""
        MCPSigningKey.objects.create(
            kid="broken",
            secret_hex="not-hex!!",
            is_active=True,
        )
        # The exception content is not inspected, so the callable form of
        # assertRaises is enough here.
        self.assertRaises(TeamJWTError, mint_team_jwt, _make_team())
|