"""Tests for ``pallas.loop_guard``. Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted ``PromptMessageExtended`` objects against a fake ToolRunner and asserts the halt behaviour: the runner's ``max_iterations`` is collapsed to the current iteration (so fast-agent terminates on its next check), the dangling turn is annotated with an explanation, and the abort metric is incremented. No fast-agent runtime is involved — the hooks are pure async functions. Uses ``asyncio.run`` directly to match the convention in the other test modules (pallas has no pytest-asyncio dependency). """ from __future__ import annotations import asyncio from types import SimpleNamespace from typing import Any from fast_agent.types import PromptMessageExtended from fast_agent.types.llm_stop_reason import LlmStopReason from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, TextContent, ) from pallas import metrics as _pallas_metrics from pallas.loop_guard import LoopGuard, install_for_request def _run(coro): return asyncio.run(coro) def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest: return CallToolRequest( method="tools/call", params=CallToolRequestParams(name=name, arguments=arguments or {}), ) def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult: return CallToolResult( content=[TextContent(type="text", text=text)], isError=is_error ) def _request(name: str, arguments: dict, call_id: str = "toolu_1"): return PromptMessageExtended( role="assistant", content=[], tool_calls={call_id: _tool_call(name, arguments)}, ) def _result(text: str, call_id: str = "toolu_1"): return PromptMessageExtended( role="user", content=[], tool_results={call_id: _tool_result(text)} ) class _FakeRunner: def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None: self.request_params = SimpleNamespace(max_iterations=max_iterations) self.iteration = iteration self.last_message = PromptMessageExtended( role="assistant", content=[], stop_reason=LlmStopReason.TOOL_USE, ) async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None: """Run one tool round through the guard's hooks.""" before = guard.as_before_tool_call_hook() after = guard.as_after_tool_call_hook() await before(runner, _request(name, args)) await after(runner, _result(result)) def _abort_count(agent: str) -> float: return ( _pallas_metrics.agent_loop_aborted_total.labels(agent=agent, reason="repeat") ._value.get() ) def _last_text(runner: _FakeRunner) -> str: return "".join( b.text for b in (runner.last_message.content or []) if hasattr(b, "text") ) def test_halts_on_third_identical_round(): guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3) runner = _FakeRunner(iteration=12, max_iterations=30) before = _abort_count("shawn") async def go(): for _ in range(2): await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same") # not yet halted after rounds 1 and 2 assert runner.request_params.max_iterations == 30 assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE # third identical round trips the guard await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same") _run(go()) # max_iterations collapsed to the current iteration -> fast-agent stops # on its next `_iteration > max_iterations` check, no further LLM call. assert runner.request_params.max_iterations == 12 assert runner.last_message.stop_reason == LlmStopReason.END_TURN assert "Halted" in _last_text(runner) assert "kairos-update_task" in _last_text(runner) assert _abort_count("shawn") == before + 1 def test_no_halt_when_result_changes(): guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3) runner = _FakeRunner() async def go(): for i in range(6): await _drive( guard, runner, "kairos-update_task", {"task_id": 494}, f"r{i}" ) _run(go()) assert runner.request_params.max_iterations == 30 assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE def test_no_halt_when_args_change(): guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3) runner = _FakeRunner() async def go(): for i in range(6): await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same") _run(go()) assert runner.request_params.max_iterations == 30 def test_threshold_respected(): guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5) runner = _FakeRunner() async def go(): for _ in range(4): await _drive(guard, runner, "t", {"x": 1}, "same") _run(go()) # 4 identical rounds, threshold 5 -> still running assert runner.request_params.max_iterations == 30 def test_halt_fires_once(): guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3) runner = _FakeRunner() before = _abort_count("a4") async def go(): for _ in range(6): await _drive(guard, runner, "t", {"x": 1}, "same") _run(go()) assert _abort_count("a4") == before + 1 def test_install_disabled_with_nonpositive_threshold(): agent = SimpleNamespace(tool_runner_hooks="sentinel") restore = install_for_request( agent, agent_name="a", conversation_id=None, threshold=0 ) assert agent.tool_runner_hooks == "sentinel" # untouched restore() # no-op, must not raise def test_install_merges_and_restores(): agent = SimpleNamespace(tool_runner_hooks=None) restore = install_for_request( agent, agent_name="a", conversation_id=None, threshold=3 ) assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks.after_tool_call is not None restore() assert agent.tool_runner_hooks is None