mcp_auth: allow jti re-use within its exp window
Daedalus mints one JWT per chat turn; a turn routinely drives several Mnemosyne tool calls (list_libraries -> search -> get_document ...) re-using that same bearer. The old _remember_jti flagged every repeat as replay, so the 2nd+Nth tool call in each turn failed with 'Token replay detected.'. Change the cache to store jti -> exp. A repeat within the token's own validity window is legitimate and allowed. A repeat *past* exp (+ the symmetric _JWT_LEEWAY_SECONDS PyJWT uses on the signature check) is a genuine replay and still rejected -- this is belt-and-braces since PyJWT's own exp check would have already caught an expired token. Also validate exp is numeric at the call site for defence in depth against future PyJWT changes to claim shapes.
This commit is contained in:
@@ -50,6 +50,15 @@ _JWT_ISS = "daedalus"
|
||||
# Bounded LRU of recently-seen jti values to discourage replay within
|
||||
# a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this
|
||||
# is best-effort and per-process.
|
||||
#
|
||||
# Daedalus mints one JWT per *chat turn*, and a single turn routinely
|
||||
# drives several tool calls (list_libraries → search → get_document …)
|
||||
# that all re-use that bearer. Treating any repeat ``jti`` as replay
|
||||
# breaks the legitimate use case, so we store ``jti -> exp`` and only
|
||||
# flag a request as replay if the ``jti`` shows up *after* its own
|
||||
# ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check
|
||||
# would have already rejected, this is belt-and-braces for clock skew
|
||||
# or a resurrected captured token.
|
||||
_JTI_CACHE_MAX = 4096
|
||||
_JTI_CACHE: "OrderedDict[str, float]" = OrderedDict()
|
||||
|
||||
@@ -109,18 +118,41 @@ def _b64url_decode(segment: str) -> bytes:
|
||||
return base64.urlsafe_b64decode(segment + pad)
|
||||
|
||||
|
||||
def _remember_jti(jti: str) -> bool:
|
||||
"""Return True if this jti has been seen before in this process."""
|
||||
def _remember_jti(jti: str, exp: float) -> bool:
|
||||
"""Record or re-check a jti; return True iff this is a genuine replay.
|
||||
|
||||
A "genuine replay" is a ``jti`` we've already recorded whose ``exp``
|
||||
has passed (+ leeway) — that token *should not* be presentable
|
||||
anymore and PyJWT's ``exp`` check would normally have rejected it;
|
||||
this is the belt-and-braces path for clock skew or a captured
|
||||
token being resurrected slightly after its natural lifetime.
|
||||
|
||||
Re-use of the same ``jti`` *within* its validity window is the
|
||||
intended case: Daedalus mints one token per chat turn and a turn
|
||||
commonly fires several Mnemosyne tool calls against that same
|
||||
bearer. That is **not** replay and we return False.
|
||||
"""
|
||||
now = time.time()
|
||||
# Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s.
|
||||
cutoff = now - 3600
|
||||
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff:
|
||||
# GC anything whose exp passed more than an hour ago — generous
|
||||
# given exp ≤ 600s, bounds the cache even for pathological jti
|
||||
# floods.
|
||||
gc_cutoff = now - 3600
|
||||
while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < gc_cutoff:
|
||||
_JTI_CACHE.popitem(last=False)
|
||||
if jti in _JTI_CACHE:
|
||||
|
||||
cached_exp = _JTI_CACHE.get(jti)
|
||||
if cached_exp is not None:
|
||||
# Refresh LRU position.
|
||||
_JTI_CACHE.move_to_end(jti)
|
||||
# Replay iff the token is already past its own expiry (with
|
||||
# the same symmetric leeway PyJWT applied on signature check).
|
||||
if now > cached_exp + _JWT_LEEWAY_SECONDS:
|
||||
return True
|
||||
return False
|
||||
|
||||
if len(_JTI_CACHE) >= _JTI_CACHE_MAX:
|
||||
_JTI_CACHE.popitem(last=False)
|
||||
_JTI_CACHE[jti] = now
|
||||
_JTI_CACHE[jti] = float(exp)
|
||||
return False
|
||||
|
||||
|
||||
@@ -175,7 +207,12 @@ def resolve_mcp_jwt(token_string: str) -> dict:
|
||||
jti = claims.get("jti")
|
||||
if not isinstance(jti, str) or not jti:
|
||||
raise MCPAuthError("JWT jti must be a non-empty string.")
|
||||
if _remember_jti(jti):
|
||||
exp = claims.get("exp")
|
||||
if not isinstance(exp, (int, float)):
|
||||
# ``require=["exp", ...]`` above guarantees presence + numeric; this
|
||||
# is defence in depth against future PyJWT changes.
|
||||
raise MCPAuthError("JWT exp must be numeric.")
|
||||
if _remember_jti(jti, float(exp)):
|
||||
raise MCPAuthError("Token replay detected.")
|
||||
|
||||
# Normalize claim shapes: ws may be null/absent, libs default to [].
|
||||
|
||||
Reference in New Issue
Block a user