Introduces the `pallas.loop_guard` module, which detects and halts agentic loops
in which the same `(tool, args) → result` signature repeats consecutively, preventing
wasted LLM turns when upstream MCP servers return contradictory data.
- Add per-request `ToolRunnerHooks` tracking rolling tool-call signatures
- Halt loop after `loop_repeat_threshold` consecutive repeats (default 3)
- Collapse `max_iterations` on halt to terminate without further LLM call
- Append user-facing explanation to the turn with `stop_reason=endTurn`
- Expose `pallas_agent_loop_aborted_total{agent,reason}` counter
- Add per-agent `max_iterations` and `loop_repeat_threshold` config
- Document guard behavior, metric, and alerting query
232 lines
8.6 KiB
Python
232 lines
8.6 KiB
Python
"""Runaway-loop detection for the agentic tool loop.
|
|
|
|
A small model occasionally gets stuck emitting the *identical* tool call
|
|
every iteration — typically because an upstream MCP server returned a
|
|
contradictory or malformed result the model keeps trying to reconcile.
|
|
Left alone the loop burns LLM turns and context until the client times
|
|
out and the user sees ``empty_response``.
|
|
|
|
This guard installs ``ToolRunnerHooks`` on the request-scoped agent that
|
|
track a per-turn signature of ``(tool, normalized_args) -> result_hash``.
|
|
When the same signature repeats ``threshold`` times consecutively the
|
|
loop is halted **immediately** — we don't ask the model to troubleshoot,
|
|
because the fault is almost always upstream and self-recovery is slow,
|
|
unpredictable, and token-hungry. Instead we:
|
|
|
|
* force fast-agent's own ``max_iterations`` termination path so the turn
|
|
ends after the current tool result with a partial answer rather than a
|
|
client timeout;
|
|
* append an honest, user-facing explanation to the dangling turn;
|
|
* log the offending tool, arguments, and result with full detail so the
|
|
real (upstream) bug can be fixed durably; and
|
|
* increment ``pallas_agent_loop_aborted_total`` for alerting.
|
|
|
|
The hooks compose after any already-installed hooks (e.g. the assistant
|
|
stream) via the same merge strategy ``assistant_stream`` uses.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from fast_agent.agents.tool_runner import ToolRunnerHooks
|
|
from fast_agent.types import PromptMessageExtended
|
|
from fast_agent.types.llm_stop_reason import LlmStopReason
|
|
from mcp.types import TextContent
|
|
|
|
from pallas import metrics as _pallas_metrics
|
|
from pallas.assistant_stream import _merge_hooks, _split_server
|
|
from pallas.progress import format_args_preview, format_result_preview
|
|
|
|
logger = logging.getLogger("pallas.loop_guard")
|
|
|
|
DEFAULT_THRESHOLD = 3
|
|
|
|
|
|
def _normalize_args(arguments: Any) -> str:
    """Render tool arguments as a canonical string for signature comparison.

    The arguments are JSON-encoded with sorted keys so that two dicts that
    differ only in key order produce the same text. Values JSON cannot
    encode natively are stringified via ``default=str``. If encoding fails
    outright the ``repr`` of the arguments is used instead.
    """
    try:
        canonical = json.dumps(
            arguments,
            ensure_ascii=False,
            default=str,
            sort_keys=True,
        )
    except (TypeError, ValueError):
        # e.g. keys of mixed types that cannot be sorted, or circular data.
        canonical = repr(arguments)
    return canonical
|
|
|
|
|
|
def _result_hash(result: Any) -> str:
    """Return a SHA-256 hex digest summarising a tool result.

    Each entry of ``result.content`` contributes its ``text`` attribute when
    that is a string, otherwise its ``repr``. The ``isError`` flag is folded
    in as a final component. A result with no ``content`` (or a falsy one)
    hashes the error flag alone.
    """

    def _fragment(item: Any) -> str:
        # Use the textual payload when there is one; otherwise fall back to
        # the item's repr.
        candidate = getattr(item, "text", None)
        if isinstance(candidate, str):
            return candidate
        return repr(item)

    fragments = [_fragment(item) for item in (getattr(result, "content", None) or [])]
    is_error = bool(getattr(result, "isError", False))
    fragments.append(f"isError={is_error}")

    # NUL-separated so adjacent fragments cannot blur into one another.
    payload = "\x00".join(fragments).encode("utf-8", "replace")
    return hashlib.sha256(payload).hexdigest()
|
|
|
|
|
|
class LoopGuard:
    """Per-request consecutive-identical-tool-call detector.

    The ``before_tool_call`` hook stages each call's name and normalized
    arguments by call id. The ``after_tool_call`` hook pairs them with the
    tool results, folds the batch into one order-independent signature and
    counts how many times in a row that signature has been seen. Once the
    count reaches ``threshold`` the guard halts the loop, exactly once.
    """

    def __init__(
        self, *, agent_name: str, conversation_id: str | None, threshold: int
    ) -> None:
        """Create a guard.

        ``agent_name`` and ``conversation_id`` are attached to log records;
        ``agent_name`` also labels the abort metric. ``threshold`` is the
        number of consecutive identical signatures that triggers the halt.
        """
        self._agent_name = agent_name
        self._conversation_id = conversation_id
        self._threshold = threshold
        # call_id -> (tool_name, normalized_args, raw_arguments), staged by
        # before_tool_call and consumed by the matching after_tool_call.
        self._pending: dict[str, tuple[str | None, str, Any]] = {}
        self._last_signature: str | None = None
        self._repeat_count = 0
        self._halted = False

    @staticmethod
    def _best_effort(formatter: Any, value: Any, fallback: Any = "<unavailable>") -> Any:
        """Call a diagnostic ``formatter``; return ``fallback`` if it raises.

        Log previews are a convenience only. A failure while building one
        must never prevent the halt itself or the abort metric.
        """
        try:
            return formatter(value)
        except Exception:
            logger.debug("loop_guard diagnostic formatting failed", exc_info=True)
            return fallback

    def as_before_tool_call_hook(self):
        """Return an async hook that stages every tool call on the message."""

        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            for call_id, call in (message.tool_calls or {}).items():
                params = getattr(call, "params", None)
                name = getattr(params, "name", None)
                arguments = getattr(params, "arguments", None)
                self._pending[call_id] = (name, _normalize_args(arguments), arguments)

        return _hook

    def as_after_tool_call_hook(self):
        """Return an async hook that evaluates tool results for repetition.

        The hook is a no-op once the guard has halted. Any exception raised
        during evaluation is logged and swallowed.
        """

        async def _hook(runner: Any, message: PromptMessageExtended) -> None:
            if self._halted:
                return
            try:
                self._evaluate(runner, message)
            except Exception:  # never let the guard break a live turn
                logger.warning(
                    "loop_guard evaluation failed",
                    exc_info=True,
                    extra={
                        "agent": self._agent_name,
                        "conversation_id": self._conversation_id,
                    },
                )

        return _hook

    def _evaluate(self, runner: Any, message: PromptMessageExtended) -> None:
        """Update the repeat counter from ``message.tool_results``.

        Halts the loop when the counter reaches the configured threshold.
        Messages without tool results are ignored and leave the counter
        untouched.
        """
        results = message.tool_results or {}
        if not results:
            return

        components: list[tuple[str, str, str]] = []
        names: list[str] = []
        raw_args: list[Any] = []
        for call_id, result in results.items():
            # A result with no staged call still contributes to the
            # signature, using placeholder name/argument values.
            name, args_sig, arguments = self._pending.pop(call_id, (None, "null", None))
            components.append((name or "", args_sig, _result_hash(result)))
            if name:
                names.append(name)
                raw_args.append(arguments)
        # Sort so the signature does not depend on result ordering.
        components.sort()
        signature = "|".join(f"{n}:{a}:{r}" for n, a, r in components)

        if signature == self._last_signature:
            self._repeat_count += 1
        else:
            self._last_signature = signature
            self._repeat_count = 1

        if self._repeat_count >= self._threshold:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            tool_label = ", ".join(dict.fromkeys(names)) or "unknown"
            # Best-effort: a preview failure must not stop the halt below.
            args_preview = (
                self._best_effort(format_args_preview, raw_args[0]) if raw_args else ""
            )
            self._halt(runner, results, tool_label, args_preview)

    def _halt(
        self,
        runner: Any,
        results: dict[str, Any],
        tool_label: str,
        args_preview: str,
    ) -> None:
        """Terminate the loop, annotate the last message, then log and count.

        The termination steps run before any diagnostic formatting so that
        a failure while building log previews cannot leave the loop running
        with the guard already marked as halted.
        """
        self._halted = True

        # Force fast-agent's own termination check (tool_runner: `_iteration >
        # max_iterations`) to fire after this tool result — no further LLM
        # call, partial answer returned instead of a client timeout.
        params = getattr(runner, "request_params", None)
        iteration = getattr(runner, "iteration", 0)
        if params is not None:
            params.max_iterations = iteration

        # Append an honest explanation to the last message and mark it as the
        # end of the turn, so the client gets an explanation rather than an
        # empty/truncated turn.
        note = (
            f"Halted: the '{tool_label}' tool returned an identical result "
            f"{self._repeat_count} times in a row, so the agent was looping "
            "without making progress. This is usually an upstream data or "
            "tool-server issue rather than a problem with the request. "
            "Stopping early to avoid a runaway loop."
        )
        last = getattr(runner, "last_message", None)
        if last is not None:
            if last.content is None:
                last.content = []
            last.content.append(TextContent(type="text", text=note))
            last.stop_reason = LlmStopReason.END_TURN

        # Diagnostics only from here on; each lookup is best-effort.
        first = next(iter(results.values()), None)
        result_preview = (
            self._best_effort(
                format_result_preview, list(getattr(first, "content", []) or [])
            )
            if first is not None
            else ""
        )
        server = self._best_effort(_split_server, tool_label, None)

        logger.warning(
            "agentic loop halted: identical tool call repeated",
            extra={
                "event": "loop_halt",
                "agent": self._agent_name,
                "conversation_id": self._conversation_id,
                "tool": tool_label,
                "server": server,
                "repeat_count": self._repeat_count,
                "threshold": self._threshold,
                "iteration": iteration,
                "arguments_preview": args_preview,
                "result_preview": result_preview,
            },
        )
        _pallas_metrics.record_loop_abort(self._agent_name, "repeat")
|
|
|
|
|
|
def install_for_request(
    agent: Any,
    *,
    agent_name: str,
    conversation_id: str | None,
    threshold: int = DEFAULT_THRESHOLD,
):
    """Attach a fresh loop guard to a request-scoped agent.

    The guard's before/after tool-call hooks are combined with whatever is
    currently in ``agent.tool_runner_hooks`` via ``_merge_hooks`` and the
    result is assigned back to the agent.

    The return value is a zero-argument callable that puts the agent's
    earlier ``tool_runner_hooks`` value back; invoke it from a ``finally``
    block. When ``threshold`` is ``None`` or below 1 the guard is disabled:
    the agent is left untouched and the returned callable does nothing.
    """
    if threshold is None or threshold < 1:

        def _noop() -> None:
            return None

        return _noop

    detector = LoopGuard(
        agent_name=agent_name,
        conversation_id=conversation_id,
        threshold=threshold,
    )
    guard_hooks = ToolRunnerHooks(
        before_tool_call=detector.as_before_tool_call_hook(),
        after_tool_call=detector.as_after_tool_call_hook(),
    )

    # Remember what was installed so the caller can undo the change.
    original_hooks = getattr(agent, "tool_runner_hooks", None)
    agent.tool_runner_hooks = _merge_hooks(original_hooks, guard_hooks)

    def restore() -> None:
        agent.tool_runner_hooks = original_hooks

    return restore
|