diff --git a/mnemosyne/mcp_server/auth.py b/mnemosyne/mcp_server/auth.py index 3a5ff29..4954529 100644 --- a/mnemosyne/mcp_server/auth.py +++ b/mnemosyne/mcp_server/auth.py @@ -50,6 +50,15 @@ _JWT_ISS = "daedalus" # Bounded LRU of recently-seen jti values to discourage replay within # a single Mnemosyne process. Real defense is short ``exp`` + HMAC; this # is best-effort and per-process. +# +# Daedalus mints one JWT per *chat turn*, and a single turn routinely +# drives several tool calls (list_libraries → search → get_document …) +# that all re-use that bearer. Treating any repeat ``jti`` as replay +# breaks the legitimate use case, so we store ``jti -> exp`` and only +# flag a request as replay if the ``jti`` shows up *after* its own +# ``exp`` has passed — that's the scenario PyJWT's own ``exp`` check +# would have already rejected, this is belt-and-braces for clock skew +# or a resurrected captured token. _JTI_CACHE_MAX = 4096 _JTI_CACHE: "OrderedDict[str, float]" = OrderedDict() @@ -109,18 +118,41 @@ def _b64url_decode(segment: str) -> bytes: return base64.urlsafe_b64decode(segment + pad) -def _remember_jti(jti: str) -> bool: - """Return True if this jti has been seen before in this process.""" +def _remember_jti(jti: str, exp: float) -> bool: + """Record or re-check a jti; return True iff this is a genuine replay. + + A "genuine replay" is a ``jti`` we've already recorded whose ``exp`` + has passed (+ leeway) — that token *should not* be presentable + anymore and PyJWT's ``exp`` check would normally have rejected it; + this is the belt-and-braces path for clock skew or a captured + token being resurrected slightly after its natural lifetime. + + Re-use of the same ``jti`` *within* its validity window is the + intended case: Daedalus mints one token per chat turn and a turn + commonly fires several Mnemosyne tool calls against that same + bearer. That is **not** replay and we return False. + """ now = time.time() - # Drop any cached jti whose entry is older than 1h — generous given exp ≤ 600s. - cutoff = now - 3600 - while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < cutoff: + # GC anything whose exp passed more than an hour ago — generous + # given exp ≤ 600s, bounds the cache even for pathological jti + # floods. + gc_cutoff = now - 3600 + while _JTI_CACHE and next(iter(_JTI_CACHE.values())) < gc_cutoff: _JTI_CACHE.popitem(last=False) - if jti in _JTI_CACHE: - return True + + cached_exp = _JTI_CACHE.get(jti) + if cached_exp is not None: + # Refresh LRU position. + _JTI_CACHE.move_to_end(jti) + # Replay iff the token is already past its own expiry (with + # the same symmetric leeway PyJWT applied on signature check). + if now > cached_exp + _JWT_LEEWAY_SECONDS: + return True + return False + if len(_JTI_CACHE) >= _JTI_CACHE_MAX: _JTI_CACHE.popitem(last=False) - _JTI_CACHE[jti] = now + _JTI_CACHE[jti] = float(exp) return False @@ -175,7 +207,12 @@ def resolve_mcp_jwt(token_string: str) -> dict: jti = claims.get("jti") if not isinstance(jti, str) or not jti: raise MCPAuthError("JWT jti must be a non-empty string.") - if _remember_jti(jti): + exp = claims.get("exp") + if not isinstance(exp, (int, float)): + # ``require=["exp", ...]`` above guarantees presence + numeric; this + # is defence in depth against future PyJWT changes. + raise MCPAuthError("JWT exp must be numeric.") + if _remember_jti(jti, float(exp)): raise MCPAuthError("Token replay detected.") # Normalize claim shapes: ws may be null/absent, libs default to [].