mcp_auth: allow jti re-use within its exp window
All checks were successful
CVE Scan & Docker Build / security-scan (push) Successful in 1m6s
CVE Scan & Docker Build / build-and-push (push) Successful in 2m27s

Daedalus mints one JWT per chat turn; a turn routinely drives several
Mnemosyne tool calls (list_libraries -> search -> get_document ...)
re-using that same bearer. The old _remember_jti flagged every repeat
as replay, so the 2nd+Nth tool call in each turn failed with
'Token replay detected.'.

Change the cache to store jti -> exp. A repeat within the token's own
validity window is legitimate and allowed. A repeat *past* exp (+ the
symmetric _JWT_LEEWAY_SECONDS PyJWT uses on the signature check) is
a genuine replay and still rejected -- this is belt-and-braces since
PyJWT's own exp check would have already caught an expired token.

Also validate exp is numeric at the call site for defence in depth
against future PyJWT changes to claim shapes.
This commit is contained in:
2026-05-05 22:03:36 -04:00
parent 8b2e2068e0
commit 15d70c2cf9

View File

@@ -50,6 +50,15 @@ _JWT_ISS = "daedalus"
# Bounded LRU of recently-seen jti values to discourage replay within
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
# is best-effort and per-process.
#
# Daedalus mints one JWT per *chat turn*, and a single turn routinely
# drives several tool calls (list_libraries → search → get_document …)
# that all re-use that bearer. Treating any repeat ``jti`` as replay
# breaks the legitimate use case, so we store ``jti -> exp`` and only
# flag a request as replay if the ``jti`` shows up *after* its own
# ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check
# would have already rejected, this is belt-and-braces for clock skew
# or a resurrected captured token.
_JTI_CACHE_MAX = 4096
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
@@ -109,18 +118,41 @@ def _b64url_decode(segment: str) -> bytes:
return base64.urlsafe_b64decode(segment + pad)
def _remember_jti(jti: str) -> bool:
"""Return True if this jti has been seen before in this process."""
def _remember_jti(jti: str, exp: float) -> bool:
"""Record or re-check a jti; return True iff this is a genuine replay.
A "genuine replay" is a ``jti`` we've already recorded whose ``exp``
has passed (+ leeway) — that token *should not* be presentable
anymore and PyJWT's ``exp`` check would normally have rejected it;
this is the belt-and-braces path for clock skew or a captured
token being resurrected slightly after its natural lifetime.
Re-use of the same ``jti`` *within* its validity window is the
intended case: Daedalus mints one token per chat turn and a turn
commonly fires several Mnemosyne tool calls against that same
bearer. That is **not** replay and we return False.
"""
now = time.time()
# Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s.
cutoff = now - 3600
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff:
# GC anything whose exp passed more than an hour ago — generous
# given exp ≤ 600s, bounds the cache even for pathological jti
# floods.
gc_cutoff = now - 3600
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < gc_cutoff:
_JTI_CACHE.popitem(last=False)
if jti in _JTI_CACHE:
cached_exp = _JTI_CACHE.get(jti)
if cached_exp is not None:
# Refresh LRU position.
_JTI_CACHE.move_to_end(jti)
# Replay iff the token is already past its own expiry (with
# the same symmetric leeway PyJWT applied on signature check).
if now > cached_exp + _JWT_LEEWAY_SECONDS:
return True
return False
if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
_JTI_CACHE.popitem(last=False)
_JTI_CACHE[jti] = now
_JTI_CACHE[jti] = float(exp)
return False
@@ -175,7 +207,12 @@ def resolve_mcp_jwt(token_string: str) -> dict:
jti = claims.get("jti")
if not isinstance(jti, str) or not jti:
raise MCPAuthError("JWT jti must be a non-empty string.")
if _remember_jti(jti):
exp = claims.get("exp")
if not isinstance(exp, (int, float)):
# ``require=["exp", ...]`` above guarantees presence + numeric; this
# is defence in depth against future PyJWT changes.
raise MCPAuthError("JWT exp must be numeric.")
if _remember_jti(jti, float(exp)):
raise MCPAuthError("Token replay detected.")
# Normalize claim shapes: ws may be null/absent, libs default to [].