Introduces `pallas.loop_guard` module that detects and halts agentic loops
where the same `(tool, args) → result` repeats consecutively, preventing
wasted LLM turns when upstream MCP servers return contradictory data.
- Add per-request `ToolRunnerHooks` tracking rolling tool-call signatures
- Halt loop after `loop_repeat_threshold` consecutive repeats (default 3)
- Collapse `max_iterations` on halt so the loop terminates without a further LLM call
- Append a user-facing explanation to the halted turn and set its `stop_reason=endTurn`
- Expose `pallas_agent_loop_aborted_total{agent,reason}` counter
- Add per-agent `max_iterations` and `loop_repeat_threshold` config
- Document guard behavior, metric, and alerting query
191 lines · 6.1 KiB · Python
"""Tests for ``pallas.loop_guard``.
|
|
|
|
Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted
|
|
``PromptMessageExtended`` objects against a fake ToolRunner and asserts the
|
|
halt behaviour: the runner's ``max_iterations`` is collapsed to the current
|
|
iteration (so fast-agent terminates on its next check), the dangling turn is
|
|
annotated with an explanation, and the abort metric is incremented.
|
|
|
|
No fast-agent runtime is involved — the hooks are pure async functions.
|
|
Uses ``asyncio.run`` directly to match the convention in the other test
|
|
modules (pallas has no pytest-asyncio dependency).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
|
|
from fast_agent.types import PromptMessageExtended
|
|
from fast_agent.types.llm_stop_reason import LlmStopReason
|
|
from mcp.types import (
|
|
CallToolRequest,
|
|
CallToolRequestParams,
|
|
CallToolResult,
|
|
TextContent,
|
|
)
|
|
|
|
from pallas import metrics as _pallas_metrics
|
|
from pallas.loop_guard import LoopGuard, install_for_request
|
|
|
|
|
|
def _run(coro):
|
|
return asyncio.run(coro)
|
|
|
|
|
|
def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
    """Build an MCP ``tools/call`` request for tool *name*.

    A missing (``None``) or empty *arguments* is sent as a fresh ``{}``.
    """
    params = CallToolRequestParams(name=name, arguments=arguments or {})
    return CallToolRequest(method="tools/call", params=params)
|
|
|
|
|
|
def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult:
    """Wrap *text* as the single text block of an MCP ``CallToolResult``.

    *is_error* is forwarded verbatim as the result's ``isError`` flag.
    """
    blocks = [TextContent(type="text", text=text)]
    return CallToolResult(content=blocks, isError=is_error)
|
|
|
|
|
|
def _request(name: str, arguments: dict, call_id: str = "toolu_1"):
    """Assistant turn holding exactly one tool call, keyed by *call_id*."""
    call = _tool_call(name, arguments)
    return PromptMessageExtended(
        role="assistant", content=[], tool_calls={call_id: call}
    )
|
|
|
|
|
|
def _result(text: str, call_id: str = "toolu_1"):
    """User turn carrying the (non-error) tool result *text* for *call_id*."""
    results = {call_id: _tool_result(text)}
    return PromptMessageExtended(role="user", content=[], tool_results=results)
|
|
|
|
|
|
class _FakeRunner:
    """Minimal stand-in for a fast-agent ToolRunner.

    Carries only the attributes these tests inspect:

    - ``request_params.max_iterations``: the iteration budget, which the guard
      collapses on halt;
    - ``iteration``: the current iteration;
    - ``last_message``: an empty assistant turn whose stop reason starts as
      ``TOOL_USE``.
    """

    def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None:
        pending_turn = PromptMessageExtended(
            role="assistant", content=[], stop_reason=LlmStopReason.TOOL_USE
        )
        self.request_params = SimpleNamespace(max_iterations=max_iterations)
        self.iteration = iteration
        self.last_message = pending_turn
|
|
|
|
|
|
async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None:
    """Push one tool round through *guard*.

    The before-hook sees the request for ``name(args)``, then the after-hook
    sees the tool result *result*.
    """
    before_hook, after_hook = (
        guard.as_before_tool_call_hook(),
        guard.as_after_tool_call_hook(),
    )
    await before_hook(runner, _request(name, args))
    await after_hook(runner, _result(result))
|
|
|
|
|
|
def _abort_count(agent: str) -> float:
    """Current value of the ``reason="repeat"`` abort counter for *agent*."""
    # NOTE(review): ``_value`` is a private attribute of the labeled child.
    # It looks like prometheus_client's Counter API; confirm if the metrics
    # backend ever changes.
    labeled = _pallas_metrics.agent_loop_aborted_total.labels(
        agent=agent, reason="repeat"
    )
    return labeled._value.get()
|
|
|
|
|
|
def _last_text(runner: _FakeRunner) -> str:
|
|
return "".join(
|
|
b.text for b in (runner.last_message.content or []) if hasattr(b, "text")
|
|
)
|
|
|
|
|
|
def test_halts_on_third_identical_round():
    """A threshold-3 guard halts the runner on the third identical round.

    Rounds 1 and 2 leave the runner untouched. Round 3 does all of:

    - collapses ``max_iterations`` to the current iteration;
    - flips the last turn to ``END_TURN`` with an explanation containing
      "Halted" and the tool name;
    - bumps the abort counter exactly once.
    """
    guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3)
    runner = _FakeRunner(iteration=12, max_iterations=30)
    baseline = _abort_count("shawn")

    def one_round():
        # Fresh args dict each round; same tool and same result every time.
        return _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")

    async def scenario():
        await one_round()
        await one_round()
        # not yet halted after rounds 1 and 2
        assert runner.request_params.max_iterations == 30
        assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
        # third identical round trips the guard
        await one_round()

    _run(scenario())

    # max_iterations collapsed to the current iteration -> fast-agent stops
    # on its next `_iteration > max_iterations` check, no further LLM call.
    assert runner.request_params.max_iterations == 12
    assert runner.last_message.stop_reason == LlmStopReason.END_TURN
    explanation = _last_text(runner)
    assert "Halted" in explanation
    assert "kairos-update_task" in explanation
    assert _abort_count("shawn") == baseline + 1
|
|
|
|
|
|
def test_no_halt_when_result_changes():
    """Same tool and args but a different result each round never halts."""
    guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3)
    runner = _FakeRunner()
    varying_results = [f"r{i}" for i in range(6)]

    async def scenario():
        for outcome in varying_results:
            await _drive(
                guard, runner, "kairos-update_task", {"task_id": 494}, outcome
            )

    _run(scenario())

    # Untouched budget and stop reason: the guard never fired.
    assert runner.request_params.max_iterations == 30
    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
|
|
|
|
|
|
def test_no_halt_when_args_change():
    """Same tool and result but different args each round never halts.

    Checks all three halt side effects, so a guard that only partially fires
    is also caught:

    - the iteration budget is unchanged;
    - the last turn's stop reason is unchanged;
    - the abort counter is unchanged.
    """
    guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3)
    runner = _FakeRunner()
    baseline = _abort_count("a2")

    async def go():
        for i in range(6):
            await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same")

    _run(go())
    assert runner.request_params.max_iterations == 30
    # Parity with test_no_halt_when_result_changes: the turn was not rewritten.
    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
    assert _abort_count("a2") == baseline
|
|
|
|
|
|
def test_threshold_respected():
    """With threshold=5, four identical rounds keep running and the fifth halts.

    Checking both sides pins the exact trip point. Checking only the
    "still running" side would let an off-by-one guard, or one that never
    fires for thresholds above 3, pass.
    """
    guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5)
    runner = _FakeRunner()  # default iteration=5, max_iterations=30

    async def go():
        for _ in range(4):
            await _drive(guard, runner, "t", {"x": 1}, "same")
        # 4 identical rounds, threshold 5 -> still running
        assert runner.request_params.max_iterations == 30
        assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
        # 5th identical round reaches the threshold and trips the guard
        await _drive(guard, runner, "t", {"x": 1}, "same")

    # Single event loop for all rounds, matching test_halts_on_third_identical_round.
    _run(go())

    # Budget collapsed to the runner's current iteration (5), turn ended.
    assert runner.request_params.max_iterations == 5
    assert runner.last_message.stop_reason == LlmStopReason.END_TURN
|
|
|
|
|
|
def test_halt_fires_once():
    """Repeating well past the threshold counts the abort exactly once."""
    guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3)
    runner = _FakeRunner()
    baseline = _abort_count("a4")
    rounds = 6  # twice the threshold

    async def scenario():
        for _round in range(rounds):
            await _drive(guard, runner, "t", {"x": 1}, "same")

    _run(scenario())
    assert _abort_count("a4") == baseline + 1
|
|
|
|
|
|
def test_install_disabled_with_nonpositive_threshold():
    """A threshold <= 0 disables the guard.

    Both zero and negative thresholds are exercised, as the test name
    promises. In each case the agent's hooks are left untouched, and the
    returned ``restore`` is a no-op that must not raise or clobber them.
    """
    for threshold in (0, -1):
        agent = SimpleNamespace(tool_runner_hooks="sentinel")
        restore = install_for_request(
            agent, agent_name="a", conversation_id=None, threshold=threshold
        )
        assert agent.tool_runner_hooks == "sentinel", threshold  # untouched
        restore()  # no-op, must not raise
        assert agent.tool_runner_hooks == "sentinel", threshold  # still untouched
|
|
|
|
|
|
def test_install_merges_and_restores():
    """Installing on a hook-less agent adds hooks; ``restore()`` puts back None."""
    agent = SimpleNamespace(tool_runner_hooks=None)
    restore = install_for_request(
        agent, agent_name="a", conversation_id=None, threshold=3
    )

    installed = agent.tool_runner_hooks
    assert installed is not None
    assert installed.after_tool_call is not None

    restore()
    assert agent.tool_runner_hooks is None
|