diff --git a/docs/pallas.md b/docs/pallas.md index ca84b0b..6f06539 100644 --- a/docs/pallas.md +++ b/docs/pallas.md @@ -193,6 +193,8 @@ agents: | `agents..title` | no | Display name in registry. Default: `name.title()` | | `agents..description` | no | Description in registry | | `agents..depends_on` | no | List of agent names that must start and become ready before this agent | +| `agents..max_iterations` | no | Hard cap on agentic-loop turns per `send_message`. Default: `15`. fast-agent returns a partial answer once exceeded | +| `agents..loop_repeat_threshold` | no | Halt the loop after this many consecutive identical `(tool, args) → result` rounds. Default: `3`. `0` disables the guard | ### `fastagent.config.yaml` Extensions @@ -530,6 +532,39 @@ Registered on each agent's MCP server. Checks: --- +## Loop Guard + +A small model occasionally gets stuck emitting the *identical* tool call every +iteration — usually because an upstream MCP server returned a contradictory or +malformed result it keeps trying to reconcile. Left alone the loop burns LLM +turns and context until the client times out and the user sees +`empty_response`. + +`pallas.loop_guard` installs per-request `ToolRunnerHooks` (composed on top of +the assistant-stream hooks) that track a rolling signature of +`(tool, normalized_args) → result_hash`. When the same signature repeats +`loop_repeat_threshold` times consecutively (default **3**), the loop is +**halted immediately** — the runtime does *not* ask the model to troubleshoot, +because the fault is almost always upstream and self-recovery is slow, +unpredictable, and token-hungry. On halt it: + +- collapses the request's `max_iterations` to the current iteration, so + fast-agent's own `_iteration > max_iterations` check terminates the turn + after the current tool result with **no further LLM call**; +- appends an honest, user-facing explanation to the returned turn (and sets + `stop_reason = endTurn`) so the client gets a real message instead of an + empty/truncated one; +- logs the offending tool, arguments, and result at WARNING (`event=loop_halt` + in `pallas.loop_guard`) so the upstream bug can be fixed durably; and +- increments `pallas_agent_loop_aborted_total{reason="repeat"}`. + +This fires well before the `max_iterations` cap (a 3-round repeat halts within +~3 turns regardless of the configured ceiling), which is the point: the cap is +a backstop, the guard is the fast path. Set `loop_repeat_threshold: 0` on an +agent to disable it. + +--- + ## Metrics Pallas exposes Prometheus metrics for scraping and alerting. One scrape target per Pallas deployment is sufficient — all agents run as coroutines in a single process under `asyncio.gather`, so metrics are process-global. @@ -570,6 +605,7 @@ scrape_configs: | `pallas_downstream_up` | gauge | `agent`, `server` | `1` when the named downstream MCP server passed the last `get_health` probe | | `pallas_llm_provider_up` | gauge | `provider` | `1` when the active LLM provider passed its last preflight or runtime re-probe | | `pallas_agent_health_status` | gauge | `agent` | Aggregate from the last `get_health`: `1`=ok, `0.5`=degraded, `0`=error | +| `pallas_agent_loop_aborted_total` | counter | `agent`, `reason` | Agentic loops force-stopped by a runtime guard. `reason` ∈ `repeat` (identical-tool-call loop detected) | Standard process metrics (RSS, CPU, GC, open FDs) are emitted by `prometheus-client`'s default collectors on the same endpoint. @@ -616,6 +652,7 @@ pallas_llm_provider_up == 0 | Agent error rate elevated | `rate(pallas_send_message_total{outcome="error"}[10m]) > 0.1` | >10% errors over 10 min | | Latency regression | `histogram_quantile(0.95, sum by (agent, le) (rate(pallas_send_message_duration_seconds_bucket[10m]))) > 60` | p95 over 60 s | | Token burn | `sum(rate(pallas_llm_tokens_total{kind="output"}[1h])) > N` | Set N to your budget | +| Agent loop halted | `increase(pallas_agent_loop_aborted_total[15m]) > 0` | A repeated-tool-call loop was force-stopped — investigate the upstream tool/data | --- @@ -645,6 +682,7 @@ This avoids the brittle pattern of inferring capabilities from model name substr | `pallas.registry` | `registry.py` | Starlette app serving `GET /.well-known/mcp/server.json` — agent catalogue built from config | | `pallas.multimodal_server` | `multimodal_server.py` | `MultimodalAgentMCPServer` — extends `AgentMCPServer` with image support, conversation history prompts, bearer token propagation | | `pallas.health` | `health.py` | LLM provider preflight validation, downstream MCP server probing, `get_health` tool registration | +| `pallas.loop_guard` | `loop_guard.py` | Per-request `ToolRunnerHooks` that halt the agentic loop on repeated-identical tool calls | | `pallas.log` | `log.py` | JSON log configuration, third-party traceback capture, Rich-TUI-safe handler attachment | | `pallas._fastagent_patch` | `_fastagent_patch.py` | Monkey-patches fast-agent at import time: per-request bearer forwarding via `httpx.Auth`, diagnostic trace-capture wrappers around `send_request` / `session.call_tool` / `_execute_on_server` | diff --git a/pallas/log.py b/pallas/log.py index 8a8f8f0..fb67cbf 100644 --- a/pallas/log.py +++ b/pallas/log.py @@ -90,6 +90,16 @@ class _StaticFieldsFilter(logging.Filter): return True +# Standard ``LogRecord`` attributes — everything else on ``record.__dict__`` is +# an ``extra={...}`` field a caller attached and wants serialized. ``message`` +# and ``asctime`` are populated during formatting; ``taskName`` exists on 3.12+. +_STANDARD_LOGRECORD_KEYS = set(logging.makeLogRecord({}).__dict__) | { + "message", + "asctime", + "taskName", +} + + class _JSONFormatter(logging.Formatter): """Single-line JSON formatter compatible with Alloy's ``| json`` pipeline. @@ -123,6 +133,12 @@ class _JSONFormatter(logging.Formatter): "project": getattr(record, "project", _PROJECT), "component": getattr(record, "component", _COMPONENT_CTX.get()), } + # Merge caller-supplied ``extra={...}`` fields (anything on the record + # that isn't a standard LogRecord attribute or already emitted above). + for key, value in record.__dict__.items(): + if key in _STANDARD_LOGRECORD_KEYS or key in payload: + continue + payload[key] = value if record.exc_info: if not record.exc_text: record.exc_text = self.formatException(record.exc_info) @@ -131,7 +147,8 @@ class _JSONFormatter(logging.Formatter): payload["traceback"] = record.exc_text if record.stack_info: payload["stack"] = self.formatStack(record.stack_info) - return json.dumps(payload) + # default=str keeps a non-serializable extra value from crashing logging. + return json.dumps(payload, default=str) class _HealthAccessFilter(logging.Filter): diff --git a/pallas/loop_guard.py b/pallas/loop_guard.py new file mode 100644 index 0000000..381c615 --- /dev/null +++ b/pallas/loop_guard.py @@ -0,0 +1,231 @@ +"""Runaway-loop detection for the agentic tool loop. + +A small model occasionally gets stuck emitting the *identical* tool call +every iteration — typically because an upstream MCP server returned a +contradictory or malformed result the model keeps trying to reconcile. +Left alone the loop burns LLM turns and context until the client times +out and the user sees ``empty_response``. + +This guard installs ``ToolRunnerHooks`` on the request-scoped agent that +track a per-turn signature of ``(tool, normalized_args) -> result_hash``. +When the same signature repeats ``threshold`` times consecutively the +loop is halted **immediately** — we don't ask the model to troubleshoot, +because the fault is almost always upstream and self-recovery is slow, +unpredictable, and token-hungry. Instead we: + +* force fast-agent's own ``max_iterations`` termination path so the turn + ends after the current tool result with a partial answer rather than a + client timeout; +* replace the dangling turn with an honest, user-facing explanation; +* log the offending tool, arguments, and result with full detail so the + real (upstream) bug can be fixed durably; and +* increment ``pallas_agent_loop_aborted_total`` for alerting. + +The hooks compose after any already-installed hooks (e.g. the assistant +stream) via the same merge strategy ``assistant_stream`` uses. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +from typing import Any + +from fast_agent.agents.tool_runner import ToolRunnerHooks +from fast_agent.types import PromptMessageExtended +from fast_agent.types.llm_stop_reason import LlmStopReason +from mcp.types import TextContent + +from pallas import metrics as _pallas_metrics +from pallas.assistant_stream import _merge_hooks, _split_server +from pallas.progress import format_args_preview, format_result_preview + +logger = logging.getLogger("pallas.loop_guard") + +DEFAULT_THRESHOLD = 3 + + +def _normalize_args(arguments: Any) -> str: + """Deterministically serialize tool arguments for signature comparison.""" + try: + return json.dumps(arguments, sort_keys=True, default=str, ensure_ascii=False) + except (TypeError, ValueError): + return repr(arguments) + + +def _result_hash(result: Any) -> str: + """Stable hash of a tool result's content for repeat detection.""" + parts: list[str] = [] + for block in getattr(result, "content", None) or []: + text = getattr(block, "text", None) + parts.append(text if isinstance(text, str) else repr(block)) + parts.append(f"isError={bool(getattr(result, 'isError', False))}") + digest = hashlib.sha256("\x00".join(parts).encode("utf-8", "replace")) + return digest.hexdigest() + + +class LoopGuard: + """Per-request consecutive-identical-tool-call detector.""" + + def __init__( + self, *, agent_name: str, conversation_id: str | None, threshold: int + ) -> None: + self._agent_name = agent_name + self._conversation_id = conversation_id + self._threshold = threshold + # call_id -> (tool_name, normalized_args, raw_arguments), staged by + # before_tool_call and consumed by the matching after_tool_call. + self._pending: dict[str, tuple[str | None, str, Any]] = {} + self._last_signature: str | None = None + self._repeat_count = 0 + self._halted = False + + def as_before_tool_call_hook(self): + async def _hook(_runner: Any, message: PromptMessageExtended) -> None: + for call_id, call in (message.tool_calls or {}).items(): + params = getattr(call, "params", None) + name = getattr(params, "name", None) + arguments = getattr(params, "arguments", None) + self._pending[call_id] = (name, _normalize_args(arguments), arguments) + + return _hook + + def as_after_tool_call_hook(self): + async def _hook(runner: Any, message: PromptMessageExtended) -> None: + if self._halted: + return + try: + self._evaluate(runner, message) + except Exception: # never let the guard break a live turn + logger.warning( + "loop_guard evaluation failed", + exc_info=True, + extra={ + "agent": self._agent_name, + "conversation_id": self._conversation_id, + }, + ) + + return _hook + + def _evaluate(self, runner: Any, message: PromptMessageExtended) -> None: + results = message.tool_results or {} + if not results: + return + + components: list[tuple[str, str, str]] = [] + names: list[str] = [] + raw_args: list[Any] = [] + for call_id, result in results.items(): + name, args_sig, arguments = self._pending.pop(call_id, (None, "null", None)) + components.append((name or "", args_sig, _result_hash(result))) + if name: + names.append(name) + raw_args.append(arguments) + components.sort() + signature = "|".join(f"{n}:{a}:{r}" for n, a, r in components) + + if signature == self._last_signature: + self._repeat_count += 1 + else: + self._last_signature = signature + self._repeat_count = 1 + + if self._repeat_count >= self._threshold: + tool_label = ", ".join(dict.fromkeys(names)) or "unknown" + args_preview = format_args_preview(raw_args[0]) if raw_args else "" + self._halt(runner, results, tool_label, args_preview) + + def _halt( + self, + runner: Any, + results: dict[str, Any], + tool_label: str, + args_preview: str, + ) -> None: + self._halted = True + + first = next(iter(results.values()), None) + result_preview = ( + format_result_preview(list(getattr(first, "content", []) or [])) + if first is not None + else "" + ) + + # Force fast-agent's own termination check (tool_runner: `_iteration > + # max_iterations`) to fire after this tool result — no further LLM + # call, partial answer returned instead of a client timeout. + params = getattr(runner, "request_params", None) + iteration = getattr(runner, "iteration", 0) + if params is not None: + params.max_iterations = iteration + + # Replace the dangling tool-use turn with an honest final message so + # the client gets an explanation rather than an empty/truncated turn. + note = ( + f"Halted: the '{tool_label}' tool returned an identical result " + f"{self._repeat_count} times in a row, so the agent was looping " + "without making progress. This is usually an upstream data or " + "tool-server issue rather than a problem with the request. " + "Stopping early to avoid a runaway loop." + ) + last = getattr(runner, "last_message", None) + if last is not None: + if last.content is None: + last.content = [] + last.content.append(TextContent(type="text", text=note)) + last.stop_reason = LlmStopReason.END_TURN + + logger.warning( + "agentic loop halted: identical tool call repeated", + extra={ + "event": "loop_halt", + "agent": self._agent_name, + "conversation_id": self._conversation_id, + "tool": tool_label, + "server": _split_server(tool_label), + "repeat_count": self._repeat_count, + "threshold": self._threshold, + "iteration": iteration, + "arguments_preview": args_preview, + "result_preview": result_preview, + }, + ) + _pallas_metrics.record_loop_abort(self._agent_name, "repeat") + + +def install_for_request( + agent: Any, + *, + agent_name: str, + conversation_id: str | None, + threshold: int = DEFAULT_THRESHOLD, +): + """Install the loop guard on a request-scoped agent instance. + + Returns a no-arg ``restore`` callable that reinstates the agent's prior + ``tool_runner_hooks`` — call it in a ``finally``. A non-positive + ``threshold`` disables the guard (returns a no-op restore). + """ + if threshold is None or threshold < 1: + return lambda: None + + guard = LoopGuard( + agent_name=agent_name, + conversation_id=conversation_id, + threshold=threshold, + ) + + extra = ToolRunnerHooks( + before_tool_call=guard.as_before_tool_call_hook(), + after_tool_call=guard.as_after_tool_call_hook(), + ) + + previous = getattr(agent, "tool_runner_hooks", None) + agent.tool_runner_hooks = _merge_hooks(previous, extra) + + def restore() -> None: + agent.tool_runner_hooks = previous + + return restore diff --git a/pallas/metrics.py b/pallas/metrics.py index 2185e0e..eab1fb3 100644 --- a/pallas/metrics.py +++ b/pallas/metrics.py @@ -131,10 +131,22 @@ agent_health_status = Gauge( registry=REGISTRY, ) +agent_loop_aborted_total = Counter( + "pallas_agent_loop_aborted_total", + "Agentic loops force-stopped by a runtime guard", + labelnames=["agent", "reason"], # reason: repeat + registry=REGISTRY, +) + # ── Helpers ────────────────────────────────────────────────────────────────── +def record_loop_abort(agent: str, reason: str) -> None: + """Record one agentic loop aborted by a runtime guard.""" + agent_loop_aborted_total.labels(agent=agent, reason=reason).inc() + + def set_agent_info(agents: dict[str, dict]) -> None: """Record the deployment's configured agents (called once at startup).""" for name, agent in agents.items(): diff --git a/pallas/multimodal_server.py b/pallas/multimodal_server.py index 8cd5d7c..3ab6658 100644 --- a/pallas/multimodal_server.py +++ b/pallas/multimodal_server.py @@ -29,6 +29,7 @@ from fast_agent.mcp.server import AgentMCPServer from fast_agent.types import PromptMessageExtended, RequestParams from pallas.assistant_stream import install_for_request as _install_assistant_stream +from pallas.loop_guard import install_for_request as _install_loop_guard from pallas.progress import EnrichedMCPToolProgressManager from pallas import metrics as _pallas_metrics from fastmcp import Context as MCPContext @@ -229,12 +230,25 @@ class MultimodalAgentMCPServer(AgentMCPServer): # in earlier loop iterations stays trapped inside fast-agent's # ``message_history`` and the user sees a spinner that ends with # a thin wrap-up sentence. - restore_hooks = _install_assistant_stream( + restore_stream = _install_assistant_stream( agent, ctx=ctx, agent_name=agent_name, conversation_id=conversation_id, ) + # Compose the loop guard on top: it halts the agentic loop the + # moment a tool call repeats with an identical result, before + # the turn runs to the iteration cap or client timeout. + restore_guard = _install_loop_guard( + agent, + agent_name=agent_name, + conversation_id=conversation_id, + threshold=self._request_limits.get("loop_repeat_threshold", 3), + ) + + def restore_hooks() -> None: + restore_guard() + restore_stream() try: # Seed the freshly-created instance's message_history from the # caller-supplied history so the agent sees the full diff --git a/pallas/server.py b/pallas/server.py index accd7c5..2aa481a 100644 --- a/pallas/server.py +++ b/pallas/server.py @@ -65,6 +65,7 @@ def _build_agents_table(config: dict) -> dict[str, dict]: "max_iterations": agent.get("max_iterations"), "streaming_timeout": agent.get("streaming_timeout"), "turn_timeout": agent.get("turn_timeout"), + "loop_repeat_threshold": agent.get("loop_repeat_threshold"), } for name, agent in config["agents"].items() } @@ -264,7 +265,12 @@ async def _start_agent(name: str, agents: dict[str, dict]) -> None: # and the LLM sees exactly what Daedalus asks it to see. request_limits = { k: entry[k] - for k in ("max_iterations", "streaming_timeout", "turn_timeout") + for k in ( + "max_iterations", + "streaming_timeout", + "turn_timeout", + "loop_repeat_threshold", + ) if entry.get(k) is not None } diff --git a/tests/test_log.py b/tests/test_log.py new file mode 100644 index 0000000..665da44 --- /dev/null +++ b/tests/test_log.py @@ -0,0 +1,55 @@ +"""Tests for ``pallas.log._JSONFormatter``. + +The formatter must serialize caller-supplied ``extra={...}`` fields (the +loop guard and other diagnostics rely on this) while never emitting the +internal ``LogRecord`` bookkeeping attributes. +""" +from __future__ import annotations + +import json +import logging + +from pallas.log import _JSONFormatter + + +def _format(msg: str, **extra) -> dict: + record = logging.LogRecord( + name="pallas.test", + level=logging.WARNING, + pathname=__file__, + lineno=1, + msg=msg, + args=(), + exc_info=None, + ) + for key, value in extra.items(): + setattr(record, key, value) + return json.loads(_JSONFormatter().format(record)) + + +def test_extra_fields_are_serialized(): + out = _format( + "agentic loop halted", + event="loop_halt", + tool="kairos-update_task", + repeat_count=3, + result_preview="COMPLETED but 0%", + ) + assert out["message"] == "agentic loop halted" + assert out["event"] == "loop_halt" + assert out["tool"] == "kairos-update_task" + assert out["repeat_count"] == 3 + assert out["result_preview"] == "COMPLETED but 0%" + + +def test_standard_attributes_are_not_leaked(): + out = _format("plain message") + for noise in ("msg", "args", "levelno", "pathname", "lineno", "funcName"): + assert noise not in out + assert out["level"] == "WARNING" + assert out["logger"] == "pallas.test" + + +def test_non_serializable_extra_does_not_crash(): + out = _format("with object", obj=object()) + assert "obj" in out # coerced via default=str, not dropped or raised diff --git a/tests/test_loop_guard.py b/tests/test_loop_guard.py new file mode 100644 index 0000000..f75478b --- /dev/null +++ b/tests/test_loop_guard.py @@ -0,0 +1,190 @@ +"""Tests for ``pallas.loop_guard``. + +Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted +``PromptMessageExtended`` objects against a fake ToolRunner and asserts the +halt behaviour: the runner's ``max_iterations`` is collapsed to the current +iteration (so fast-agent terminates on its next check), the dangling turn is +annotated with an explanation, and the abort metric is incremented. + +No fast-agent runtime is involved — the hooks are pure async functions. +Uses ``asyncio.run`` directly to match the convention in the other test +modules (pallas has no pytest-asyncio dependency). +""" +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from typing import Any + +from fast_agent.types import PromptMessageExtended +from fast_agent.types.llm_stop_reason import LlmStopReason +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + TextContent, +) + +from pallas import metrics as _pallas_metrics +from pallas.loop_guard import LoopGuard, install_for_request + + +def _run(coro): + return asyncio.run(coro) + + +def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest: + return CallToolRequest( + method="tools/call", + params=CallToolRequestParams(name=name, arguments=arguments or {}), + ) + + +def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=text)], isError=is_error + ) + + +def _request(name: str, arguments: dict, call_id: str = "toolu_1"): + return PromptMessageExtended( + role="assistant", + content=[], + tool_calls={call_id: _tool_call(name, arguments)}, + ) + + +def _result(text: str, call_id: str = "toolu_1"): + return PromptMessageExtended( + role="user", content=[], tool_results={call_id: _tool_result(text)} + ) + + +class _FakeRunner: + def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None: + self.request_params = SimpleNamespace(max_iterations=max_iterations) + self.iteration = iteration + self.last_message = PromptMessageExtended( + role="assistant", + content=[], + stop_reason=LlmStopReason.TOOL_USE, + ) + + +async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None: + """Run one tool round through the guard's hooks.""" + before = guard.as_before_tool_call_hook() + after = guard.as_after_tool_call_hook() + await before(runner, _request(name, args)) + await after(runner, _result(result)) + + +def _abort_count(agent: str) -> float: + return ( + _pallas_metrics.agent_loop_aborted_total.labels(agent=agent, reason="repeat") + ._value.get() + ) + + +def _last_text(runner: _FakeRunner) -> str: + return "".join( + b.text for b in (runner.last_message.content or []) if hasattr(b, "text") + ) + + +def test_halts_on_third_identical_round(): + guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3) + runner = _FakeRunner(iteration=12, max_iterations=30) + before = _abort_count("shawn") + + async def go(): + for _ in range(2): + await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same") + # not yet halted after rounds 1 and 2 + assert runner.request_params.max_iterations == 30 + assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE + # third identical round trips the guard + await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same") + + _run(go()) + + # max_iterations collapsed to the current iteration -> fast-agent stops + # on its next `_iteration > max_iterations` check, no further LLM call. + assert runner.request_params.max_iterations == 12 + assert runner.last_message.stop_reason == LlmStopReason.END_TURN + assert "Halted" in _last_text(runner) + assert "kairos-update_task" in _last_text(runner) + assert _abort_count("shawn") == before + 1 + + +def test_no_halt_when_result_changes(): + guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3) + runner = _FakeRunner() + + async def go(): + for i in range(6): + await _drive( + guard, runner, "kairos-update_task", {"task_id": 494}, f"r{i}" + ) + + _run(go()) + assert runner.request_params.max_iterations == 30 + assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE + + +def test_no_halt_when_args_change(): + guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3) + runner = _FakeRunner() + + async def go(): + for i in range(6): + await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same") + + _run(go()) + assert runner.request_params.max_iterations == 30 + + +def test_threshold_respected(): + guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5) + runner = _FakeRunner() + + async def go(): + for _ in range(4): + await _drive(guard, runner, "t", {"x": 1}, "same") + + _run(go()) + # 4 identical rounds, threshold 5 -> still running + assert runner.request_params.max_iterations == 30 + + +def test_halt_fires_once(): + guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3) + runner = _FakeRunner() + before = _abort_count("a4") + + async def go(): + for _ in range(6): + await _drive(guard, runner, "t", {"x": 1}, "same") + + _run(go()) + assert _abort_count("a4") == before + 1 + + +def test_install_disabled_with_nonpositive_threshold(): + agent = SimpleNamespace(tool_runner_hooks="sentinel") + restore = install_for_request( + agent, agent_name="a", conversation_id=None, threshold=0 + ) + assert agent.tool_runner_hooks == "sentinel" # untouched + restore() # no-op, must not raise + + +def test_install_merges_and_restores(): + agent = SimpleNamespace(tool_runner_hooks=None) + restore = install_for_request( + agent, agent_name="a", conversation_id=None, threshold=3 + ) + assert agent.tool_runner_hooks is not None + assert agent.tool_runner_hooks.after_tool_call is not None + restore() + assert agent.tool_runner_hooks is None