Files
mnemosyne/mnemosyne/mcp_server/tests/test_teams_jwt.py
Robert Helewka 6a4fecf488
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 50s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m16s
fix(mcp): disable audience verification in resolve_mcp_jwt
Team JWTs include `aud=mnemosyne` while per-turn JWTs omit `aud`
entirely. Since `iss` + `typ` already partition the two token
populations, explicitly skip audience verification to avoid rejecting
valid tokens.

Also expand test coverage for the MCP auth surface to exercise all
three credential types (opaque MCPToken, per-turn JWT, team JWT),
including replay cache behavior and Neo4j-backed library resolution
via mocked cypher queries.
2026-05-10 12:32:58 -04:00

112 lines
3.7 KiB
Python

"""Tests for :mod:`mcp_server.teams` — team-JWT mint.
Covers the mint path in isolation. The *validation* half of the
round-trip (``resolve_mcp_jwt`` with ``typ=team``) lives in
``test_auth.py``; this module asserts what ``mint_team_jwt`` encodes
into the token header / payload so the two halves match up.
"""
from __future__ import annotations
import time
import uuid
import jwt as pyjwt
from django.test import TestCase
from mcp_server.models import MCPSigningKey, Team
from mcp_server.teams import TeamJWTError, mint_team_jwt
def _make_key(kid: str = "k", is_active: bool = True) -> MCPSigningKey:
    """Persist and return an ``MCPSigningKey`` with a fixed test secret.

    The secret is 64 hex ``a`` characters (32 bytes once decoded), so it
    is always valid input for ``bytes.fromhex``.
    """
    fields = {
        "kid": kid,
        "secret_hex": "a" * 64,
        "is_active": is_active,
    }
    return MCPSigningKey.objects.create(**fields)
def _make_team(**overrides) -> Team:
    """Persist and return a ``Team`` built from test defaults.

    Defaults are a random ``id``, name ``"t"``, ``active=True`` and a
    random ``active_jti``. Any keyword passed in *overrides* replaces the
    matching default (e.g. ``active_jti=None`` for the missing-JTI test).
    """
    defaults = {
        "id": uuid.uuid4(),
        "name": "t",
        "active": True,
        "active_jti": uuid.uuid4(),
    }
    return Team.objects.create(**{**defaults, **overrides})
class MintTeamJWTHappyPathTest(TestCase):
    """Successful ``mint_team_jwt`` calls: header, payload and key choice."""

    def setUp(self):
        self.key = _make_key("k-1")
        self.team = _make_team()

    def test_returns_signed_jwt_string(self):
        token = mint_team_jwt(self.team)
        self.assertIsInstance(token, str)
        # Three segments, HS256 header.
        self.assertEqual(token.count("."), 2)
        header = pyjwt.get_unverified_header(token)
        self.assertEqual(header["alg"], "HS256")
        self.assertEqual(header["kid"], self.key.kid)

    def test_payload_matches_spec(self):
        before = int(time.time())
        token = mint_team_jwt(self.team)
        after = int(time.time())
        # Verify the signature *and* the audience: team JWTs are expected
        # to carry aud=mnemosyne, so skipping the aud check here would
        # only weaken the test.
        decoded = pyjwt.decode(
            token,
            bytes.fromhex(self.key.secret_hex),
            algorithms=["HS256"],
            audience="mnemosyne",
        )
        self.assertEqual(decoded["iss"], "mnemosyne")
        self.assertEqual(decoded["aud"], "mnemosyne")
        self.assertEqual(decoded["typ"], "team")
        self.assertEqual(decoded["sub"], f"team:{self.team.id}")
        self.assertEqual(decoded["jti"], str(self.team.active_jti))
        # iat is "now-ish" (inclusive bounds because the call is fast).
        self.assertGreaterEqual(decoded["iat"], before)
        self.assertLessEqual(decoded["iat"], after)
        # exp is 10 years in the future.
        ten_years = 10 * 365 * 24 * 60 * 60
        self.assertEqual(decoded["exp"], decoded["iat"] + ten_years)

    def test_picks_newest_active_signing_key(self):
        newer = _make_key("k-2")
        token = mint_team_jwt(self.team)
        header = pyjwt.get_unverified_header(token)
        self.assertEqual(header["kid"], newer.kid)

    def test_ignores_retired_keys(self):
        # A retired key created *after* self.key would be the "newest"
        # candidate (see test_picks_newest_active_signing_key) if retired
        # keys were not filtered out; mint must fall back to the older
        # active key instead.
        _make_key("k-retired", is_active=False)
        token = mint_team_jwt(self.team)
        header = pyjwt.get_unverified_header(token)
        self.assertEqual(header["kid"], self.key.kid)

    def test_raises_when_only_retired_keys_remain(self):
        # Retire the only active key; a newer retired key must still be
        # skipped — mint fails because there is no active key.
        self.key.is_active = False
        self.key.save(update_fields=["is_active"])
        _make_key("k-retired", is_active=False)
        with self.assertRaises(TeamJWTError):
            mint_team_jwt(self.team)
class MintTeamJWTFailureModesTest(TestCase):
    """``mint_team_jwt`` must raise ``TeamJWTError`` on unusable inputs."""

    def test_no_signing_key_raises(self):
        # This class has no setUp, so no signing key is created here.
        team = _make_team()
        with self.assertRaises(TeamJWTError) as raised:
            mint_team_jwt(team)
        message = str(raised.exception).lower()
        self.assertIn("signing", message)

    def test_missing_active_jti_raises(self):
        _make_key()
        jti_less_team = _make_team(active_jti=None)
        with self.assertRaises(TeamJWTError) as raised:
            mint_team_jwt(jti_less_team)
        # The message should name the step the caller skipped.
        message = str(raised.exception)
        self.assertIn("rotate_jti", message)

    def test_invalid_hex_secret_raises(self):
        # The only active key has a secret that is not valid hex; the
        # failure must surface as TeamJWTError.
        MCPSigningKey.objects.create(
            kid="broken",
            secret_hex="not-hex!!",
            is_active=True,
        )
        team = _make_team()
        with self.assertRaises(TeamJWTError):
            mint_team_jwt(team)