feat: add loop guard to halt repeated-identical tool call loops

Introduces the `pallas.loop_guard` module, which detects and halts agentic
loops where the same `(tool, args) → result` repeats consecutively, preventing
wasted LLM turns when upstream MCP servers return contradictory data.

- Add per-request `ToolRunnerHooks` tracking rolling tool-call signatures
- Halt loop after `loop_repeat_threshold` consecutive repeats (default 3)
- Collapse `max_iterations` on halt to terminate without further LLM call
- Append user-facing explanation to the turn with `stop_reason=endTurn`
- Expose `pallas_agent_loop_aborted_total{agent,reason}` counter
- Add per-agent `max_iterations` and `loop_repeat_threshold` config
- Document guard behavior, metric, and alerting query
@@ -193,6 +193,8 @@ agents:
 | `agents.<name>.title` | no | Display name in registry. Default: `name.title()` |
 | `agents.<name>.description` | no | Description in registry |
 | `agents.<name>.depends_on` | no | List of agent names that must start and become ready before this agent |
+| `agents.<name>.max_iterations` | no | Hard cap on agentic-loop turns per `send_message`. Default: `15`. fast-agent returns a partial answer once exceeded |
+| `agents.<name>.loop_repeat_threshold` | no | Halt the loop after this many consecutive identical `(tool, args) → result` rounds. Default: `3`. `0` disables the guard |
 
 ### `fastagent.config.yaml` Extensions
 
@@ -530,6 +532,39 @@ Registered on each agent's MCP server. Checks:
 
 ---
 
+## Loop Guard
+
+A small model occasionally gets stuck emitting the *identical* tool call every
+iteration — usually because an upstream MCP server returned a contradictory or
+malformed result it keeps trying to reconcile. Left alone, the loop burns LLM
+turns and context until the client times out and the user sees
+`empty_response`.
+
+`pallas.loop_guard` installs per-request `ToolRunnerHooks` (composed on top of
+the assistant-stream hooks) that track a rolling signature of
+`(tool, normalized_args) → result_hash`. When the same signature repeats
+`loop_repeat_threshold` times consecutively (default **3**), the loop is
+**halted immediately** — the runtime does *not* ask the model to troubleshoot,
+because the fault is almost always upstream and self-recovery is slow,
+unpredictable, and token-hungry. On halt it:
+
+- collapses the request's `max_iterations` to the current iteration, so
+  fast-agent's own `_iteration > max_iterations` check terminates the turn
+  after the current tool result with **no further LLM call**;
+- appends an honest, user-facing explanation to the returned turn (and sets
+  `stop_reason = endTurn`) so the client gets a real message instead of an
+  empty/truncated one;
+- logs the offending tool, arguments, and result at WARNING (`event=loop_halt`
+  in `pallas.loop_guard`) so the upstream bug can be fixed durably; and
+- increments `pallas_agent_loop_aborted_total{reason="repeat"}`.
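
The detection step described above can be sketched in a few lines. This is a minimal illustration, not the module's code: `signature` and `RepeatDetector` are hypothetical names, and the result is hashed from plain text rather than from MCP content blocks.

```python
import hashlib
import json


def signature(tool: str, args: dict, result_text: str) -> str:
    """Build a comparable signature for one tool round (hypothetical helper)."""
    # sort_keys makes argument order irrelevant to the comparison.
    normalized = json.dumps(args, sort_keys=True, default=str)
    result_hash = hashlib.sha256(result_text.encode("utf-8")).hexdigest()
    return f"{tool}:{normalized}:{result_hash}"


class RepeatDetector:
    """Counts consecutive identical signatures (hypothetical stand-in for the guard)."""

    def __init__(self, threshold: int = 3) -> None:
        self.threshold = threshold
        self.last = None
        self.count = 0

    def observe(self, sig: str) -> bool:
        """Return True when this round should halt the loop."""
        if self.threshold < 1:  # a non-positive threshold disables the guard
            return False
        if sig == self.last:
            self.count += 1
        else:
            # Any change in tool, arguments, or result restarts the count.
            self.last, self.count = sig, 1
        return self.count >= self.threshold
```

Because the count restarts whenever the signature changes, a model that retries with different arguments, or a tool whose output changes between calls, never trips the detector.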
+
+This fires well before the `max_iterations` cap (a 3-round repeat halts on its
+third round regardless of the configured ceiling), which is the point: the cap
+is a backstop, the guard is the fast path. Set `loop_repeat_threshold: 0` on an
+agent to disable it.
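
Why collapsing the cap ends the turn without another LLM call can be seen in a toy model of the loop. This is a sketch under stated assumptions: `run_loop` and `should_halt` are hypothetical, and only the `iteration > max_iterations` comparison mirrors the check named in the text; fast-agent's real runner is more involved.

```python
from types import SimpleNamespace
from typing import Callable


def run_loop(params: SimpleNamespace, should_halt: Callable[[int], bool]) -> int:
    """Toy agentic loop; returns how many LLM calls were made."""
    llm_calls = 0
    iteration = 0
    while True:
        iteration += 1
        if iteration > params.max_iterations:  # the runner's own cap check
            break
        llm_calls += 1  # the model emits a tool call, the tool runs
        # The after-tool hook runs here; on halt it collapses the cap so the
        # next pass through the check above ends the turn.
        if should_halt(iteration):
            params.max_iterations = iteration
    return llm_calls


# A guard that trips on round 3 stops after 3 calls even though the cap is 15.
halted = run_loop(SimpleNamespace(max_iterations=15), lambda i: i >= 3)
# With no guard, the loop runs all the way to the cap.
capped = run_loop(SimpleNamespace(max_iterations=15), lambda i: False)
```

The guard never interrupts the runner directly; it only lowers the number the runner already checks, which is why the cap remains the backstop when the guard is disabled.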
+
+---
+
 ## Metrics
 
 Pallas exposes Prometheus metrics for scraping and alerting. One scrape target per Pallas deployment is sufficient — all agents run as coroutines in a single process under `asyncio.gather`, so metrics are process-global.
@@ -570,6 +605,7 @@ scrape_configs:
 | `pallas_downstream_up` | gauge | `agent`, `server` | `1` when the named downstream MCP server passed the last `get_health` probe |
 | `pallas_llm_provider_up` | gauge | `provider` | `1` when the active LLM provider passed its last preflight or runtime re-probe |
 | `pallas_agent_health_status` | gauge | `agent` | Aggregate from the last `get_health`: `1`=ok, `0.5`=degraded, `0`=error |
+| `pallas_agent_loop_aborted_total` | counter | `agent`, `reason` | Agentic loops force-stopped by a runtime guard. `reason` ∈ `repeat` (identical-tool-call loop detected) |
 
 Standard process metrics (RSS, CPU, GC, open FDs) are emitted by `prometheus-client`'s default collectors on the same endpoint.
 
@@ -616,6 +652,7 @@ pallas_llm_provider_up == 0
 | Agent error rate elevated | `rate(pallas_send_message_total{outcome="error"}[10m]) > 0.1` | >10% errors over 10 min |
 | Latency regression | `histogram_quantile(0.95, sum by (agent, le) (rate(pallas_send_message_duration_seconds_bucket[10m]))) > 60` | p95 over 60 s |
 | Token burn | `sum(rate(pallas_llm_tokens_total{kind="output"}[1h])) > N` | Set N to your budget |
+| Agent loop halted | `increase(pallas_agent_loop_aborted_total[15m]) > 0` | A repeated-tool-call loop was force-stopped — investigate the upstream tool/data |
 
 ---
 
@@ -645,6 +682,7 @@ This avoids the brittle pattern of inferring capabilities from model name substr
 | `pallas.registry` | `registry.py` | Starlette app serving `GET /.well-known/mcp/server.json` — agent catalogue built from config |
 | `pallas.multimodal_server` | `multimodal_server.py` | `MultimodalAgentMCPServer` — extends `AgentMCPServer` with image support, conversation history prompts, bearer token propagation |
 | `pallas.health` | `health.py` | LLM provider preflight validation, downstream MCP server probing, `get_health` tool registration |
+| `pallas.loop_guard` | `loop_guard.py` | Per-request `ToolRunnerHooks` that halt the agentic loop on repeated-identical tool calls |
 | `pallas.log` | `log.py` | JSON log configuration, third-party traceback capture, Rich-TUI-safe handler attachment |
 | `pallas._fastagent_patch` | `_fastagent_patch.py` | Monkey-patches fast-agent at import time: per-request bearer forwarding via `httpx.Auth`, diagnostic trace-capture wrappers around `send_request` / `session.call_tool` / `_execute_on_server` |
 
@@ -90,6 +90,16 @@ class _StaticFieldsFilter(logging.Filter):
         return True
 
 
+# Standard ``LogRecord`` attributes — everything else on ``record.__dict__`` is
+# an ``extra={...}`` field a caller attached and wants serialized. ``message``
+# and ``asctime`` are populated during formatting; ``taskName`` exists on 3.12+.
+_STANDARD_LOGRECORD_KEYS = set(logging.makeLogRecord({}).__dict__) | {
+    "message",
+    "asctime",
+    "taskName",
+}
+
+
 class _JSONFormatter(logging.Formatter):
     """Single-line JSON formatter compatible with Alloy's ``| json`` pipeline.
 
@@ -123,6 +133,12 @@ class _JSONFormatter(logging.Formatter):
             "project": getattr(record, "project", _PROJECT),
             "component": getattr(record, "component", _COMPONENT_CTX.get()),
         }
+        # Merge caller-supplied ``extra={...}`` fields (anything on the record
+        # that isn't a standard LogRecord attribute or already emitted above).
+        for key, value in record.__dict__.items():
+            if key in _STANDARD_LOGRECORD_KEYS or key in payload:
+                continue
+            payload[key] = value
         if record.exc_info:
             if not record.exc_text:
                 record.exc_text = self.formatException(record.exc_info)
@@ -131,7 +147,8 @@ class _JSONFormatter(logging.Formatter):
             payload["traceback"] = record.exc_text
         if record.stack_info:
             payload["stack"] = self.formatStack(record.stack_info)
-        return json.dumps(payload)
+        # default=str keeps a non-serializable extra value from crashing logging.
+        return json.dumps(payload, default=str)
 
 
 class _HealthAccessFilter(logging.Filter):
231
pallas/loop_guard.py
Normal file
@@ -0,0 +1,231 @@
+"""Runaway-loop detection for the agentic tool loop.
+
+A small model occasionally gets stuck emitting the *identical* tool call
+every iteration — typically because an upstream MCP server returned a
+contradictory or malformed result the model keeps trying to reconcile.
+Left alone the loop burns LLM turns and context until the client times
+out and the user sees ``empty_response``.
+
+This guard installs ``ToolRunnerHooks`` on the request-scoped agent that
+track a per-turn signature of ``(tool, normalized_args) -> result_hash``.
+When the same signature repeats ``threshold`` times consecutively the
+loop is halted **immediately** — we don't ask the model to troubleshoot,
+because the fault is almost always upstream and self-recovery is slow,
+unpredictable, and token-hungry. Instead we:
+
+* force fast-agent's own ``max_iterations`` termination path so the turn
+  ends after the current tool result with a partial answer rather than a
+  client timeout;
+* append an honest, user-facing explanation to the dangling turn;
+* log the offending tool, arguments, and result with full detail so the
+  real (upstream) bug can be fixed durably; and
+* increment ``pallas_agent_loop_aborted_total`` for alerting.
+
+The hooks compose after any already-installed hooks (e.g. the assistant
+stream) via the same merge strategy ``assistant_stream`` uses.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+from typing import Any
+
+from fast_agent.agents.tool_runner import ToolRunnerHooks
+from fast_agent.types import PromptMessageExtended
+from fast_agent.types.llm_stop_reason import LlmStopReason
+from mcp.types import TextContent
+
+from pallas import metrics as _pallas_metrics
+from pallas.assistant_stream import _merge_hooks, _split_server
+from pallas.progress import format_args_preview, format_result_preview
+
+logger = logging.getLogger("pallas.loop_guard")
+
+DEFAULT_THRESHOLD = 3
+
+
+def _normalize_args(arguments: Any) -> str:
+    """Deterministically serialize tool arguments for signature comparison."""
+    try:
+        return json.dumps(arguments, sort_keys=True, default=str, ensure_ascii=False)
+    except (TypeError, ValueError):
+        return repr(arguments)
+
+
+def _result_hash(result: Any) -> str:
+    """Stable hash of a tool result's content for repeat detection."""
+    parts: list[str] = []
+    for block in getattr(result, "content", None) or []:
+        text = getattr(block, "text", None)
+        parts.append(text if isinstance(text, str) else repr(block))
+    parts.append(f"isError={bool(getattr(result, 'isError', False))}")
+    digest = hashlib.sha256("\x00".join(parts).encode("utf-8", "replace"))
+    return digest.hexdigest()
+
+
+class LoopGuard:
+    """Per-request consecutive-identical-tool-call detector."""
+
+    def __init__(
+        self, *, agent_name: str, conversation_id: str | None, threshold: int
+    ) -> None:
+        self._agent_name = agent_name
+        self._conversation_id = conversation_id
+        self._threshold = threshold
+        # call_id -> (tool_name, normalized_args, raw_arguments), staged by
+        # before_tool_call and consumed by the matching after_tool_call.
+        self._pending: dict[str, tuple[str | None, str, Any]] = {}
+        self._last_signature: str | None = None
+        self._repeat_count = 0
+        self._halted = False
+
+    def as_before_tool_call_hook(self):
+        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
+            for call_id, call in (message.tool_calls or {}).items():
+                params = getattr(call, "params", None)
+                name = getattr(params, "name", None)
+                arguments = getattr(params, "arguments", None)
+                self._pending[call_id] = (name, _normalize_args(arguments), arguments)
+
+        return _hook
+
+    def as_after_tool_call_hook(self):
+        async def _hook(runner: Any, message: PromptMessageExtended) -> None:
+            if self._halted:
+                return
+            try:
+                self._evaluate(runner, message)
+            except Exception:  # never let the guard break a live turn
+                logger.warning(
+                    "loop_guard evaluation failed",
+                    exc_info=True,
+                    extra={
+                        "agent": self._agent_name,
+                        "conversation_id": self._conversation_id,
+                    },
+                )
+
+        return _hook
+
+    def _evaluate(self, runner: Any, message: PromptMessageExtended) -> None:
+        results = message.tool_results or {}
+        if not results:
+            return
+
+        components: list[tuple[str, str, str]] = []
+        names: list[str] = []
+        raw_args: list[Any] = []
+        for call_id, result in results.items():
+            name, args_sig, arguments = self._pending.pop(call_id, (None, "null", None))
+            components.append((name or "", args_sig, _result_hash(result)))
+            if name:
+                names.append(name)
+                raw_args.append(arguments)
+        components.sort()
+        signature = "|".join(f"{n}:{a}:{r}" for n, a, r in components)
+
+        if signature == self._last_signature:
+            self._repeat_count += 1
+        else:
+            self._last_signature = signature
+            self._repeat_count = 1
+
+        if self._repeat_count >= self._threshold:
+            tool_label = ", ".join(dict.fromkeys(names)) or "unknown"
+            args_preview = format_args_preview(raw_args[0]) if raw_args else ""
+            self._halt(runner, results, tool_label, args_preview)
+
+    def _halt(
+        self,
+        runner: Any,
+        results: dict[str, Any],
+        tool_label: str,
+        args_preview: str,
+    ) -> None:
+        self._halted = True
+
+        first = next(iter(results.values()), None)
+        result_preview = (
+            format_result_preview(list(getattr(first, "content", []) or []))
+            if first is not None
+            else ""
+        )
+
+        # Force fast-agent's own termination check (tool_runner: `_iteration >
+        # max_iterations`) to fire after this tool result — no further LLM
+        # call, partial answer returned instead of a client timeout.
+        params = getattr(runner, "request_params", None)
+        iteration = getattr(runner, "iteration", 0)
+        if params is not None:
+            params.max_iterations = iteration
+
+        # Append an honest final message to the dangling tool-use turn so
+        # the client gets an explanation rather than an empty/truncated turn.
+        note = (
+            f"Halted: the '{tool_label}' tool returned an identical result "
+            f"{self._repeat_count} times in a row, so the agent was looping "
+            "without making progress. This is usually an upstream data or "
+            "tool-server issue rather than a problem with the request. "
+            "Stopping early to avoid a runaway loop."
+        )
+        last = getattr(runner, "last_message", None)
+        if last is not None:
+            if last.content is None:
+                last.content = []
+            last.content.append(TextContent(type="text", text=note))
+            last.stop_reason = LlmStopReason.END_TURN
+
+        logger.warning(
+            "agentic loop halted: identical tool call repeated",
+            extra={
+                "event": "loop_halt",
+                "agent": self._agent_name,
+                "conversation_id": self._conversation_id,
+                "tool": tool_label,
+                "server": _split_server(tool_label),
+                "repeat_count": self._repeat_count,
+                "threshold": self._threshold,
+                "iteration": iteration,
+                "arguments_preview": args_preview,
+                "result_preview": result_preview,
+            },
+        )
+        _pallas_metrics.record_loop_abort(self._agent_name, "repeat")
+
+
+def install_for_request(
+    agent: Any,
+    *,
+    agent_name: str,
+    conversation_id: str | None,
+    threshold: int = DEFAULT_THRESHOLD,
+):
+    """Install the loop guard on a request-scoped agent instance.
+
+    Returns a no-arg ``restore`` callable that reinstates the agent's prior
+    ``tool_runner_hooks`` — call it in a ``finally``. A non-positive
+    ``threshold`` disables the guard (returns a no-op restore).
+    """
+    if threshold is None or threshold < 1:
+        return lambda: None
+
+    guard = LoopGuard(
+        agent_name=agent_name,
+        conversation_id=conversation_id,
+        threshold=threshold,
+    )
+
+    extra = ToolRunnerHooks(
+        before_tool_call=guard.as_before_tool_call_hook(),
+        after_tool_call=guard.as_after_tool_call_hook(),
+    )
+
+    previous = getattr(agent, "tool_runner_hooks", None)
+    agent.tool_runner_hooks = _merge_hooks(previous, extra)
+
+    def restore() -> None:
+        agent.tool_runner_hooks = previous
+
+    return restore
@@ -131,10 +131,22 @@ agent_health_status = Gauge(
     registry=REGISTRY,
 )
 
+agent_loop_aborted_total = Counter(
+    "pallas_agent_loop_aborted_total",
+    "Agentic loops force-stopped by a runtime guard",
+    labelnames=["agent", "reason"],  # reason: repeat
+    registry=REGISTRY,
+)
+
 
 # ── Helpers ──────────────────────────────────────────────────────────────────
 
 
+def record_loop_abort(agent: str, reason: str) -> None:
+    """Record one agentic loop aborted by a runtime guard."""
+    agent_loop_aborted_total.labels(agent=agent, reason=reason).inc()
+
+
 def set_agent_info(agents: dict[str, dict]) -> None:
     """Record the deployment's configured agents (called once at startup)."""
     for name, agent in agents.items():
@@ -29,6 +29,7 @@ from fast_agent.mcp.server import AgentMCPServer
 from fast_agent.types import PromptMessageExtended, RequestParams
 
 from pallas.assistant_stream import install_for_request as _install_assistant_stream
+from pallas.loop_guard import install_for_request as _install_loop_guard
 from pallas.progress import EnrichedMCPToolProgressManager
 from pallas import metrics as _pallas_metrics
 from fastmcp import Context as MCPContext
||||||
@@ -229,12 +230,25 @@ class MultimodalAgentMCPServer(AgentMCPServer):
         # in earlier loop iterations stays trapped inside fast-agent's
         # ``message_history`` and the user sees a spinner that ends with
         # a thin wrap-up sentence.
-        restore_hooks = _install_assistant_stream(
+        restore_stream = _install_assistant_stream(
             agent,
             ctx=ctx,
             agent_name=agent_name,
             conversation_id=conversation_id,
         )
+        # Compose the loop guard on top: it halts the agentic loop the
+        # moment a tool call repeats with an identical result, before
+        # the turn runs to the iteration cap or client timeout.
+        restore_guard = _install_loop_guard(
+            agent,
+            agent_name=agent_name,
+            conversation_id=conversation_id,
+            threshold=self._request_limits.get("loop_repeat_threshold", 3),
+        )
+
+        def restore_hooks() -> None:
+            restore_guard()
+            restore_stream()
         try:
             # Seed the freshly-created instance's message_history from the
             # caller-supplied history so the agent sees the full
@@ -65,6 +65,7 @@ def _build_agents_table(config: dict) -> dict[str, dict]:
             "max_iterations": agent.get("max_iterations"),
             "streaming_timeout": agent.get("streaming_timeout"),
             "turn_timeout": agent.get("turn_timeout"),
+            "loop_repeat_threshold": agent.get("loop_repeat_threshold"),
         }
         for name, agent in config["agents"].items()
     }
@@ -264,7 +265,12 @@ async def _start_agent(name: str, agents: dict[str, dict]) -> None:
     # and the LLM sees exactly what Daedalus asks it to see.
     request_limits = {
         k: entry[k]
-        for k in ("max_iterations", "streaming_timeout", "turn_timeout")
+        for k in (
+            "max_iterations",
+            "streaming_timeout",
+            "turn_timeout",
+            "loop_repeat_threshold",
+        )
         if entry.get(k) is not None
     }
 
55
tests/test_log.py
Normal file
@@ -0,0 +1,55 @@
+"""Tests for ``pallas.log._JSONFormatter``.
+
+The formatter must serialize caller-supplied ``extra={...}`` fields (the
+loop guard and other diagnostics rely on this) while never emitting the
+internal ``LogRecord`` bookkeeping attributes.
+"""
+from __future__ import annotations
+
+import json
+import logging
+
+from pallas.log import _JSONFormatter
+
+
+def _format(msg: str, **extra) -> dict:
+    record = logging.LogRecord(
+        name="pallas.test",
+        level=logging.WARNING,
+        pathname=__file__,
+        lineno=1,
+        msg=msg,
+        args=(),
+        exc_info=None,
+    )
+    for key, value in extra.items():
+        setattr(record, key, value)
+    return json.loads(_JSONFormatter().format(record))
+
+
+def test_extra_fields_are_serialized():
+    out = _format(
+        "agentic loop halted",
+        event="loop_halt",
+        tool="kairos-update_task",
+        repeat_count=3,
+        result_preview="COMPLETED but 0%",
+    )
+    assert out["message"] == "agentic loop halted"
+    assert out["event"] == "loop_halt"
+    assert out["tool"] == "kairos-update_task"
+    assert out["repeat_count"] == 3
+    assert out["result_preview"] == "COMPLETED but 0%"
+
+
+def test_standard_attributes_are_not_leaked():
+    out = _format("plain message")
+    for noise in ("msg", "args", "levelno", "pathname", "lineno", "funcName"):
+        assert noise not in out
+    assert out["level"] == "WARNING"
+    assert out["logger"] == "pallas.test"
+
+
+def test_non_serializable_extra_does_not_crash():
+    out = _format("with object", obj=object())
+    assert "obj" in out  # coerced via default=str, not dropped or raised
190
tests/test_loop_guard.py
Normal file
@@ -0,0 +1,190 @@
+"""Tests for ``pallas.loop_guard``.
+
+Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted
+``PromptMessageExtended`` objects against a fake ToolRunner and asserts the
+halt behaviour: the runner's ``max_iterations`` is collapsed to the current
+iteration (so fast-agent terminates on its next check), the dangling turn is
+annotated with an explanation, and the abort metric is incremented.
+
+No fast-agent runtime is involved — the hooks are pure async functions.
+Uses ``asyncio.run`` directly to match the convention in the other test
+modules (pallas has no pytest-asyncio dependency).
+"""
+from __future__ import annotations
+
+import asyncio
+from types import SimpleNamespace
+from typing import Any
+
+from fast_agent.types import PromptMessageExtended
+from fast_agent.types.llm_stop_reason import LlmStopReason
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+    TextContent,
+)
+
+from pallas import metrics as _pallas_metrics
+from pallas.loop_guard import LoopGuard, install_for_request
+
+
+def _run(coro):
+    return asyncio.run(coro)
+
+
+def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
+    return CallToolRequest(
+        method="tools/call",
+        params=CallToolRequestParams(name=name, arguments=arguments or {}),
+    )
+
+
+def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult:
+    return CallToolResult(
+        content=[TextContent(type="text", text=text)], isError=is_error
+    )
+
+
+def _request(name: str, arguments: dict, call_id: str = "toolu_1"):
+    return PromptMessageExtended(
+        role="assistant",
+        content=[],
+        tool_calls={call_id: _tool_call(name, arguments)},
+    )
+
+
+def _result(text: str, call_id: str = "toolu_1"):
+    return PromptMessageExtended(
+        role="user", content=[], tool_results={call_id: _tool_result(text)}
+    )
+
+
+class _FakeRunner:
+    def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None:
+        self.request_params = SimpleNamespace(max_iterations=max_iterations)
+        self.iteration = iteration
+        self.last_message = PromptMessageExtended(
+            role="assistant",
+            content=[],
+            stop_reason=LlmStopReason.TOOL_USE,
+        )
+
+
+async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None:
+    """Run one tool round through the guard's hooks."""
+    before = guard.as_before_tool_call_hook()
+    after = guard.as_after_tool_call_hook()
+    await before(runner, _request(name, args))
+    await after(runner, _result(result))
+
+
+def _abort_count(agent: str) -> float:
+    return (
+        _pallas_metrics.agent_loop_aborted_total.labels(agent=agent, reason="repeat")
+        ._value.get()
+    )
+
+
+def _last_text(runner: _FakeRunner) -> str:
+    return "".join(
+        b.text for b in (runner.last_message.content or []) if hasattr(b, "text")
+    )
+
+
+def test_halts_on_third_identical_round():
+    guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3)
+    runner = _FakeRunner(iteration=12, max_iterations=30)
+    before = _abort_count("shawn")
+
+    async def go():
+        for _ in range(2):
+            await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")
+        # not yet halted after rounds 1 and 2
+        assert runner.request_params.max_iterations == 30
+        assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
+        # third identical round trips the guard
+        await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")
+
+    _run(go())
+
+    # max_iterations collapsed to the current iteration -> fast-agent stops
+    # on its next `_iteration > max_iterations` check, no further LLM call.
+    assert runner.request_params.max_iterations == 12
+    assert runner.last_message.stop_reason == LlmStopReason.END_TURN
+    assert "Halted" in _last_text(runner)
+    assert "kairos-update_task" in _last_text(runner)
+    assert _abort_count("shawn") == before + 1
+
+
+def test_no_halt_when_result_changes():
+    guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3)
+    runner = _FakeRunner()
+
+    async def go():
+        for i in range(6):
+            await _drive(
+                guard, runner, "kairos-update_task", {"task_id": 494}, f"r{i}"
+            )
+
+    _run(go())
+    assert runner.request_params.max_iterations == 30
+    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
+
+
+def test_no_halt_when_args_change():
+    guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3)
+    runner = _FakeRunner()
+
+    async def go():
+        for i in range(6):
+            await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same")
+
+    _run(go())
+    assert runner.request_params.max_iterations == 30
+
+
+def test_threshold_respected():
+    guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5)
+    runner = _FakeRunner()
+
+    async def go():
+        for _ in range(4):
+            await _drive(guard, runner, "t", {"x": 1}, "same")
+
+    _run(go())
+    # 4 identical rounds, threshold 5 -> still running
+    assert runner.request_params.max_iterations == 30
+
+
+def test_halt_fires_once():
+    guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3)
+    runner = _FakeRunner()
+    before = _abort_count("a4")
+
+    async def go():
+        for _ in range(6):
+            await _drive(guard, runner, "t", {"x": 1}, "same")
+
+    _run(go())
+    assert _abort_count("a4") == before + 1
+
+
+def test_install_disabled_with_nonpositive_threshold():
+    agent = SimpleNamespace(tool_runner_hooks="sentinel")
+    restore = install_for_request(
+        agent, agent_name="a", conversation_id=None, threshold=0
+    )
+    assert agent.tool_runner_hooks == "sentinel"  # untouched
+    restore()  # no-op, must not raise
+
+
+def test_install_merges_and_restores():
+    agent = SimpleNamespace(tool_runner_hooks=None)
+    restore = install_for_request(
+        agent, agent_name="a", conversation_id=None, threshold=3
+    )
+    assert agent.tool_runner_hooks is not None
+    assert agent.tool_runner_hooks.after_tool_call is not None
+    restore()
+    assert agent.tool_runner_hooks is None