feat: add loop guard to halt repeated-identical tool call loops
Introduces `pallas.loop_guard` module that detects and halts agentic loops
where the same `(tool, args) → result` repeats consecutively, preventing
wasted LLM turns when upstream MCP servers return contradictory data.
- Add per-request `ToolRunnerHooks` tracking rolling tool-call signatures
- Halt loop after `loop_repeat_threshold` consecutive repeats (default 3)
- Collapse `max_iterations` on halt to terminate without further LLM call
- Append user-facing explanation to the turn with `stop_reason=endTurn`
- Expose `pallas_agent_loop_aborted_total{agent,reason}` counter
- Add per-agent `max_iterations` and `loop_repeat_threshold` config
- Document guard behavior, metric, and alerting query
This commit is contained in:
190
tests/test_loop_guard.py
Normal file
190
tests/test_loop_guard.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Tests for ``pallas.loop_guard``.

Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted
``PromptMessageExtended`` objects against a fake ToolRunner and asserts the
halt behaviour: the runner's ``max_iterations`` is collapsed to the current
iteration (so fast-agent terminates on its next check), the dangling turn is
annotated with an explanation, and the abort metric is incremented.

No fast-agent runtime is involved — the hooks are pure async functions.
Uses ``asyncio.run`` directly to match the convention in the other test
modules (pallas has no pytest-asyncio dependency).
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from fast_agent.types import PromptMessageExtended
|
||||
from fast_agent.types.llm_stop_reason import LlmStopReason
|
||||
from mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
CallToolResult,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
from pallas import metrics as _pallas_metrics
|
||||
from pallas.loop_guard import LoopGuard, install_for_request
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
    """Build a ``tools/call`` request for *name*.

    A missing or empty *arguments* mapping is sent as ``{}``.
    """
    params = CallToolRequestParams(name=name, arguments=arguments or {})
    return CallToolRequest(method="tools/call", params=params)
|
||||
|
||||
|
||||
def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult:
    """Build a tool result holding one text block, optionally flagged as an error."""
    body = TextContent(type="text", text=text)
    return CallToolResult(content=[body], isError=is_error)
|
||||
|
||||
|
||||
def _request(name: str, arguments: dict, call_id: str = "toolu_1"):
    """Build the assistant message requesting one tool call keyed by *call_id*."""
    calls = {call_id: _tool_call(name, arguments)}
    return PromptMessageExtended(role="assistant", content=[], tool_calls=calls)
|
||||
|
||||
|
||||
def _result(text: str, call_id: str = "toolu_1"):
    """Build the user message carrying one tool result keyed by *call_id*."""
    results = {call_id: _tool_result(text)}
    return PromptMessageExtended(role="user", content=[], tool_results=results)
|
||||
|
||||
|
||||
class _FakeRunner:
    """Minimal stand-in for the runner object handed to the guard's hooks.

    Carries just three attributes: ``request_params.max_iterations``,
    ``iteration`` and ``last_message``.  The tests assert on the first and
    the last after driving the hooks.
    """

    def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None:
        # The message starts out as a pending tool-use turn, so a
        # ``stop_reason`` that is still TOOL_USE afterwards means nothing
        # rewrote it.
        pending = PromptMessageExtended(
            role="assistant",
            content=[],
            stop_reason=LlmStopReason.TOOL_USE,
        )
        self.request_params = SimpleNamespace(max_iterations=max_iterations)
        self.iteration = iteration
        self.last_message = pending
|
||||
|
||||
|
||||
async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None:
    """Push one complete tool round (request, then result) through the guard.

    Both hooks are fetched from *guard* up front; the before-hook receives
    the assistant request for ``name``/``args`` and the after-hook receives
    the user message wrapping ``result``.
    """
    pre_hook = guard.as_before_tool_call_hook()
    post_hook = guard.as_after_tool_call_hook()

    call_message = _request(name, args)
    await pre_hook(runner, call_message)

    result_message = _result(result)
    await post_hook(runner, result_message)
|
||||
|
||||
|
||||
def _abort_count(agent: str) -> float:
    """Return the current ``reason="repeat"`` abort-counter value for *agent*."""
    counter = _pallas_metrics.agent_loop_aborted_total
    child = counter.labels(agent=agent, reason="repeat")
    # The value is read straight off the labelled child's private ``_value``
    # holder rather than through a public accessor.
    return child._value.get()
|
||||
|
||||
|
||||
def _last_text(runner: _FakeRunner) -> str:
|
||||
return "".join(
|
||||
b.text for b in (runner.last_message.content or []) if hasattr(b, "text")
|
||||
)
|
||||
|
||||
|
||||
def test_halts_on_third_identical_round():
    """Two identical rounds leave the runner alone; the third trips the guard."""
    guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3)
    runner = _FakeRunner(iteration=12, max_iterations=30)
    baseline = _abort_count("shawn")

    async def same_round():
        await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")

    async def scenario():
        await same_round()
        await same_round()
        # Rounds 1 and 2 are below the threshold: nothing may have changed.
        assert runner.request_params.max_iterations == 30
        assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
        # Round 3 reaches the threshold and trips the guard.
        await same_round()

    _run(scenario())

    # The budget is collapsed onto the current iteration, so fast-agent's
    # next `_iteration > max_iterations` check ends the loop without
    # another LLM call.
    assert runner.request_params.max_iterations == 12
    assert runner.last_message.stop_reason == LlmStopReason.END_TURN
    explanation = _last_text(runner)
    assert "Halted" in explanation
    assert "kairos-update_task" in explanation
    assert _abort_count("shawn") == baseline + 1
|
||||
|
||||
|
||||
def test_no_halt_when_result_changes():
    """Same call, different result every round: the runner stays untouched."""
    guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3)
    runner = _FakeRunner()

    async def scenario():
        round_no = 0
        while round_no < 6:
            reply = f"r{round_no}"
            await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, reply)
            round_no += 1

    _run(scenario())

    params, message = runner.request_params, runner.last_message
    assert params.max_iterations == 30
    assert message.stop_reason == LlmStopReason.TOOL_USE
|
||||
|
||||
|
||||
def test_no_halt_when_args_change():
    """Identical results must not trip the guard while the arguments differ.

    Six rounds call the same tool with a new ``task_id`` each time, so no
    two consecutive calls are identical.  The runner has to come out
    untouched: iteration budget intact and the turn still ``TOOL_USE``.
    """
    guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3)
    runner = _FakeRunner()

    async def go():
        for i in range(6):
            await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same")

    _run(go())
    assert runner.request_params.max_iterations == 30
    # A halt rewrites the stop reason as well as the budget, so pin both.
    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
|
||||
|
||||
|
||||
def test_threshold_respected():
    """Four identical rounds must not halt a guard whose threshold is five.

    Checks every halt side effect the suite knows about: the iteration
    budget, the turn's stop reason and the abort counter must all be
    unchanged.
    """
    guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5)
    runner = _FakeRunner()
    before = _abort_count("a3")

    async def go():
        for _ in range(4):
            await _drive(guard, runner, "t", {"x": 1}, "same")

    _run(go())
    # 4 identical rounds, threshold 5 -> still running
    assert runner.request_params.max_iterations == 30
    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
    assert _abort_count("a3") == before
|
||||
|
||||
|
||||
def test_halt_fires_once():
    """Six identical rounds at threshold 3 bump the abort counter exactly once."""
    guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3)
    runner = _FakeRunner()
    baseline = _abort_count("a4")

    async def scenario():
        remaining = 6
        while remaining:
            await _drive(guard, runner, "t", {"x": 1}, "same")
            remaining -= 1

    _run(scenario())

    # Rounds 4-6 keep repeating after the halt; the counter must not move again.
    assert _abort_count("a4") == baseline + 1
|
||||
|
||||
|
||||
def test_install_disabled_with_nonpositive_threshold():
    """A threshold of 0 installs nothing and still returns a callable restore."""
    # NOTE(review): only ``threshold=0`` is exercised; a negative threshold
    # is not covered by this test.
    agent = SimpleNamespace(tool_runner_hooks="sentinel")
    install_kwargs = {"agent_name": "a", "conversation_id": None, "threshold": 0}

    undo = install_for_request(agent, **install_kwargs)

    # The pre-existing hooks value must be left exactly as it was.
    assert agent.tool_runner_hooks == "sentinel"  # untouched
    undo()  # no-op, must not raise
|
||||
|
||||
|
||||
def test_install_merges_and_restores():
    """Installing sets hooks on a hook-less agent; restoring puts ``None`` back."""
    # NOTE(review): the agent starts with ``tool_runner_hooks=None``, so
    # merging with pre-existing hooks is not exercised here.
    agent = SimpleNamespace(tool_runner_hooks=None)
    install_kwargs = {"agent_name": "a", "conversation_id": None, "threshold": 3}

    undo = install_for_request(agent, **install_kwargs)

    hooks = agent.tool_runner_hooks
    assert hooks is not None
    assert hooks.after_tool_call is not None

    undo()
    assert agent.tool_runner_hooks is None
|
||||
Reference in New Issue
Block a user