Files
pallas/pallas/loop_guard.py
Robert Helewka ea37ab38c1 feat: add loop guard to halt repeated-identical tool call loops
Introduces `pallas.loop_guard` module that detects and halts agentic loops
where the same `(tool, args) → result` repeats consecutively, preventing
wasted LLM turns when upstream MCP servers return contradictory data.

- Add per-request `ToolRunnerHooks` tracking rolling tool-call signatures
- Halt loop after `loop_repeat_threshold` consecutive repeats (default 3)
- Collapse `max_iterations` on halt to terminate without further LLM call
- Append user-facing explanation to the turn with `stop_reason=endTurn`
- Expose `pallas_agent_loop_aborted_total{agent,reason}` counter
- Add per-agent `max_iterations` and `loop_repeat_threshold` config
- Document guard behavior, metric, and alerting query
2026-06-16 08:27:07 -04:00

232 lines
8.6 KiB
Python

"""Runaway-loop detection for the agentic tool loop.
A small model occasionally gets stuck emitting the *identical* tool call
every iteration — typically because an upstream MCP server returned a
contradictory or malformed result the model keeps trying to reconcile.
Left alone the loop burns LLM turns and context until the client times
out and the user sees ``empty_response``.
This guard installs ``ToolRunnerHooks`` on the request-scoped agent that
track a per-turn signature of ``(tool, normalized_args) -> result_hash``.
When the same signature repeats ``threshold`` times consecutively the
loop is halted **immediately** — we don't ask the model to troubleshoot,
because the fault is almost always upstream and self-recovery is slow,
unpredictable, and token-hungry. Instead we:
* force fast-agent's own ``max_iterations`` termination path so the turn
ends after the current tool result with a partial answer rather than a
client timeout;
* replace the dangling turn with an honest, user-facing explanation;
* log the offending tool, arguments, and result with full detail so the
real (upstream) bug can be fixed durably; and
* increment ``pallas_agent_loop_aborted_total`` for alerting.
The hooks compose after any already-installed hooks (e.g. the assistant
stream) via the same merge strategy ``assistant_stream`` uses.
"""
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import TextContent
from pallas import metrics as _pallas_metrics
from pallas.assistant_stream import _merge_hooks, _split_server
from pallas.progress import format_args_preview, format_result_preview
# Module-level logger, registered under the fixed name "pallas.loop_guard"
# (not ``__name__``), so log routing does not depend on the import path.
logger = logging.getLogger("pallas.loop_guard")
# Default for ``install_for_request``'s ``threshold`` parameter: the number of
# consecutive identical tool-call signatures at which ``LoopGuard`` halts.
DEFAULT_THRESHOLD = 3
def _normalize_args(arguments: Any) -> str:
    """Render tool arguments as a canonical string for signature comparison.

    The arguments are JSON-encoded with sorted keys so that two dicts with the
    same content but different insertion order compare equal. Values JSON
    cannot encode natively are passed through ``str``; non-ASCII text is kept
    verbatim. When JSON encoding fails outright (``TypeError`` or
    ``ValueError``, e.g. dict keys of mixed types that cannot be sorted), the
    ``repr`` of the arguments is used instead.
    """
    try:
        canonical = json.dumps(
            arguments,
            sort_keys=True,
            default=str,
            ensure_ascii=False,
        )
    except (TypeError, ValueError):
        canonical = repr(arguments)
    return canonical
def _result_hash(result: Any) -> str:
    """Return a SHA-256 hex digest summarising a tool result's content.

    Each entry of ``result.content`` contributes its ``text`` attribute when
    that is a string, otherwise its ``repr``. The ``isError`` flag (coerced to
    ``bool``, ``False`` when absent) is appended as a final fragment so an
    error result never hashes the same as a success with identical text. A
    result with missing or empty ``content`` hashes the flag alone.
    """

    def _render(item: Any) -> str:
        # Prefer the textual payload; fall back to repr for non-text blocks.
        body = getattr(item, "text", None)
        return body if isinstance(body, str) else repr(item)

    fragments = [_render(item) for item in (getattr(result, "content", None) or [])]
    fragments.append(f"isError={bool(getattr(result, 'isError', False))}")
    # NUL-join the fragments; "replace" keeps unencodable characters (such as
    # lone surrogates) from raising.
    payload = "\x00".join(fragments).encode("utf-8", "replace")
    return hashlib.sha256(payload).hexdigest()
class LoopGuard:
    """Per-request detector for consecutive identical tool-call rounds.

    ``before_tool_call`` stages every requested call (name plus normalized
    arguments) keyed by call id; ``after_tool_call`` pairs each result with its
    staged call and folds the whole round into one signature string. When the
    same signature is seen ``threshold`` times in a row the guard halts the
    loop: it collapses ``request_params.max_iterations`` on the runner, appends
    a user-facing note to the runner's last message, logs the offending call
    and bumps the loop-abort metric.

    Halting the loop always takes priority over diagnostics: the log previews
    are computed best-effort and only after the termination steps have run, so
    a failing preview formatter can never keep a runaway loop alive.
    """

    def __init__(
        self, *, agent_name: str, conversation_id: str | None, threshold: int
    ) -> None:
        """Create a guard for one request.

        Args:
            agent_name: Agent label used in log records and the abort metric.
            conversation_id: Conversation identifier attached to log records.
            threshold: Number of consecutive identical signatures that
                triggers the halt.
        """
        self._agent_name = agent_name
        self._conversation_id = conversation_id
        self._threshold = threshold
        # call_id -> (tool_name, normalized_args, raw_arguments), staged by
        # before_tool_call and consumed by the matching after_tool_call.
        self._pending: dict[str, tuple[str | None, str, Any]] = {}
        self._last_signature: str | None = None
        self._repeat_count = 0
        self._halted = False

    def as_before_tool_call_hook(self):
        """Return the async hook that stages each tool call of a message."""

        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            for call_id, call in (message.tool_calls or {}).items():
                params = getattr(call, "params", None)
                name = getattr(params, "name", None)
                arguments = getattr(params, "arguments", None)
                self._pending[call_id] = (name, _normalize_args(arguments), arguments)

        return _hook

    def as_after_tool_call_hook(self):
        """Return the async hook that evaluates tool results for repetition.

        The hook is a no-op once the guard has halted, and it never propagates
        an exception: evaluation failures are logged as warnings instead.
        """

        async def _hook(runner: Any, message: PromptMessageExtended) -> None:
            if self._halted:
                return
            try:
                self._evaluate(runner, message)
            except Exception:  # never let the guard break a live turn
                logger.warning(
                    "loop_guard evaluation failed",
                    exc_info=True,
                    extra={
                        "agent": self._agent_name,
                        "conversation_id": self._conversation_id,
                    },
                )

        return _hook

    @staticmethod
    def _best_effort_preview(build: Any) -> str:
        """Run the zero-argument ``build`` callable, returning "" if it raises.

        Previews only feed the diagnostic log record, so a formatter failure
        is recorded at debug level rather than being allowed to abort the halt.
        """
        try:
            return build()
        except Exception:
            logger.debug("loop_guard preview formatting failed", exc_info=True)
            return ""

    def _evaluate(self, runner: Any, message: PromptMessageExtended) -> None:
        """Update the repeat counter from ``message`` and halt at the threshold.

        Does nothing when the message carries no tool results.
        """
        results = message.tool_results or {}
        if not results:
            return

        components: list[tuple[str, str, str]] = []
        names: list[str] = []
        raw_args: list[Any] = []
        for call_id, result in results.items():
            # Results whose call was never staged still contribute, with an
            # empty name and "null" arguments.
            name, args_sig, arguments = self._pending.pop(call_id, (None, "null", None))
            components.append((name or "", args_sig, _result_hash(result)))
            if name:
                names.append(name)
                raw_args.append(arguments)

        # Sort so the signature does not depend on result ordering.
        components.sort()
        signature = "|".join(f"{n}:{a}:{r}" for n, a, r in components)

        if signature == self._last_signature:
            self._repeat_count += 1
        else:
            self._last_signature = signature
            self._repeat_count = 1

        if self._repeat_count >= self._threshold:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            tool_label = ", ".join(dict.fromkeys(names)) or "unknown"
            args_preview = (
                self._best_effort_preview(lambda: format_args_preview(raw_args[0]))
                if raw_args
                else ""
            )
            self._halt(runner, results, tool_label, args_preview)

    def _halt(
        self,
        runner: Any,
        results: dict[str, Any],
        tool_label: str,
        args_preview: str,
    ) -> None:
        """Terminate the loop, explain why to the user, then log and count it.

        Args:
            runner: Tool runner; its ``request_params``, ``iteration`` and
                ``last_message`` attributes are used when present.
            results: Tool results of the round that tripped the guard.
            tool_label: Comma-joined tool names for the note and log record.
            args_preview: Pre-formatted arguments preview for the log record.
        """
        self._halted = True

        # Termination comes first. Force fast-agent's own termination check
        # (tool_runner: `_iteration > max_iterations`) to fire after this tool
        # result: no further LLM call, partial answer returned instead of a
        # client timeout.
        # NOTE(review): request_params is mutated in place and not restored
        # here; presumably it is request-scoped. Confirm it is not shared.
        params = getattr(runner, "request_params", None)
        iteration = getattr(runner, "iteration", 0)
        if params is not None:
            params.max_iterations = iteration

        # Replace the dangling tool-use turn with an honest final message so
        # the client gets an explanation rather than an empty/truncated turn.
        note = (
            f"Halted: the '{tool_label}' tool returned an identical result "
            f"{self._repeat_count} times in a row, so the agent was looping "
            "without making progress. This is usually an upstream data or "
            "tool-server issue rather than a problem with the request. "
            "Stopping early to avoid a runaway loop."
        )
        last = getattr(runner, "last_message", None)
        if last is not None:
            if last.content is None:
                last.content = []
            last.content.append(TextContent(type="text", text=note))
            last.stop_reason = LlmStopReason.END_TURN

        # Diagnostics only from here on: preview the first result of the round.
        first = next(iter(results.values()), None)
        result_preview = (
            self._best_effort_preview(
                lambda: format_result_preview(list(getattr(first, "content", []) or []))
            )
            if first is not None
            else ""
        )
        logger.warning(
            "agentic loop halted: identical tool call repeated",
            extra={
                "event": "loop_halt",
                "agent": self._agent_name,
                "conversation_id": self._conversation_id,
                "tool": tool_label,
                "server": _split_server(tool_label),
                "repeat_count": self._repeat_count,
                "threshold": self._threshold,
                "iteration": iteration,
                "arguments_preview": args_preview,
                "result_preview": result_preview,
            },
        )
        _pallas_metrics.record_loop_abort(self._agent_name, "repeat")
def install_for_request(
    agent: Any,
    *,
    agent_name: str,
    conversation_id: str | None,
    threshold: int = DEFAULT_THRESHOLD,
):
    """Attach a fresh ``LoopGuard`` to ``agent`` for the current request.

    The guard's hooks are merged with whatever ``tool_runner_hooks`` the agent
    already carries (``None`` when it has none). The return value is a
    zero-argument ``restore`` callable that puts the previous hooks back; call
    it from a ``finally`` block. A ``threshold`` of ``None`` or below 1
    switches the guard off: the agent is left untouched and the returned
    callable does nothing.
    """
    if threshold is None or threshold < 1:
        return lambda: None

    original_hooks = getattr(agent, "tool_runner_hooks", None)
    detector = LoopGuard(
        agent_name=agent_name,
        conversation_id=conversation_id,
        threshold=threshold,
    )
    agent.tool_runner_hooks = _merge_hooks(
        original_hooks,
        ToolRunnerHooks(
            before_tool_call=detector.as_before_tool_call_hook(),
            after_tool_call=detector.as_after_tool_call_hook(),
        ),
    )

    def restore() -> None:
        # Reinstate exactly what was captured before the guard was installed.
        agent.tool_runner_hooks = original_hooks

    return restore