Compare commits: 440f7fb60c ... main (7 commits)

| SHA1 |
|---|
| ea37ab38c1 |
| e29669304b |
| febd4b4062 |
| 0b0a8f37a4 |
| d387650bf2 |
| 63ced9ba2b |
| 8a5046fef0 |

@@ -193,6 +193,8 @@ agents:
| `agents.<name>.title` | no | Display name in registry. Default: `name.title()` |
| `agents.<name>.description` | no | Description in registry |
| `agents.<name>.depends_on` | no | List of agent names that must start and become ready before this agent |
| `agents.<name>.max_iterations` | no | Hard cap on agentic-loop turns per `send_message`. Default: `15`. fast-agent returns a partial answer once exceeded |
| `agents.<name>.loop_repeat_threshold` | no | Halt the loop after this many consecutive identical `(tool, args) → result` rounds. Default: `3`. `0` disables the guard |

### `fastagent.config.yaml` Extensions

@@ -530,6 +532,39 @@ Registered on each agent's MCP server. Checks:

---

## Loop Guard

A small model occasionally gets stuck emitting the *identical* tool call every
iteration, usually because an upstream MCP server returned a contradictory or
malformed result it keeps trying to reconcile. Left alone, the loop burns LLM
turns and context until the client times out and the user sees
`empty_response`.

`pallas.loop_guard` installs per-request `ToolRunnerHooks` (composed on top of
the assistant-stream hooks) that track a rolling signature of
`(tool, normalized_args) → result_hash`. When the same signature repeats
`loop_repeat_threshold` times consecutively (default **3**), the loop is
**halted immediately**. The runtime does *not* ask the model to troubleshoot,
because the fault is almost always upstream and self-recovery is slow,
unpredictable, and token-hungry. On halt it:

- collapses the request's `max_iterations` to the current iteration, so
  fast-agent's own `_iteration > max_iterations` check terminates the turn
  after the current tool result with **no further LLM call**;
- appends an honest, user-facing explanation to the returned turn (and sets
  `stop_reason = endTurn`) so the client gets a real message instead of an
  empty/truncated one;
- logs the offending tool, arguments, and result at WARNING (`event=loop_halt`
  in `pallas.loop_guard`) so the upstream bug can be fixed durably; and
- increments `pallas_agent_loop_aborted_total{reason="repeat"}`.

This fires well before the `max_iterations` cap (a 3-round repeat halts within
~3 turns regardless of the configured ceiling), which is the point: the cap is
a backstop, the guard is the fast path. Set `loop_repeat_threshold: 0` on an
agent to disable it.
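
The signature tracking described in this section can be sketched as follows. This is a minimal, standalone illustration and not the actual `pallas.loop_guard` code: the names `round_signature` and `RepeatGuard`, the JSON-based argument normalisation, and the SHA-256 hashing are all assumptions made for the example.

```python
from __future__ import annotations

import hashlib
import json
from typing import Any


def round_signature(tool: str, args: dict[str, Any], result: str) -> str:
    """Hash one (tool, normalized_args) -> result round into a stable string."""
    # sort_keys makes argument order irrelevant; default=str tolerates odd values.
    normalized_args = json.dumps(args, sort_keys=True, default=str)
    result_hash = hashlib.sha256(result.encode("utf-8")).hexdigest()
    payload = f"{tool}\n{normalized_args}\n{result_hash}"
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


class RepeatGuard:
    """Counts consecutive identical rounds (hypothetical stand-in for the hooks)."""

    def __init__(self, threshold: int = 3) -> None:
        self.threshold = threshold  # 0 disables the guard
        self._last: str | None = None
        self._streak = 0

    def observe(self, tool: str, args: dict[str, Any], result: str) -> bool:
        """Record one round; return True when the loop should be halted."""
        if self.threshold <= 0:
            return False
        signature = round_signature(tool, args, result)
        if signature == self._last:
            self._streak += 1
        else:
            # A different call or result resets the streak.
            self._last = signature
            self._streak = 1
        return self._streak >= self.threshold


guard = RepeatGuard(threshold=3)
halts = [
    guard.observe("lookup", {"id": 7, "q": "x"}, "error: conflict")
    for _ in range(3)
]
print(halts)  # [False, False, True]
```

The sketch covers only detection. The halt actions listed above (collapsing `max_iterations`, annotating the returned turn, logging, and incrementing the counter) are not modelled here.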

---

## Metrics

Pallas exposes Prometheus metrics for scraping and alerting. One scrape target per Pallas deployment is sufficient: all agents run as coroutines in a single process under `asyncio.gather`, so metrics are process-global.

@@ -570,6 +605,7 @@ scrape_configs:
| `pallas_downstream_up` | gauge | `agent`, `server` | `1` when the named downstream MCP server passed the last `get_health` probe |
| `pallas_llm_provider_up` | gauge | `provider` | `1` when the active LLM provider passed its last preflight or runtime re-probe |
| `pallas_agent_health_status` | gauge | `agent` | Aggregate from the last `get_health`: `1`=ok, `0.5`=degraded, `0`=error |
| `pallas_agent_loop_aborted_total` | counter | `agent`, `reason` | Agentic loops force-stopped by a runtime guard. `reason` ∈ `repeat` (identical-tool-call loop detected) |

Standard process metrics (RSS, CPU, GC, open FDs) are emitted by `prometheus-client`'s default collectors on the same endpoint.

@@ -616,6 +652,7 @@ pallas_llm_provider_up == 0
| Agent error rate elevated | `rate(pallas_send_message_total{outcome="error"}[10m]) > 0.1` | >10% errors over 10 min |
| Latency regression | `histogram_quantile(0.95, sum by (agent, le) (rate(pallas_send_message_duration_seconds_bucket[10m]))) > 60` | p95 over 60 s |
| Token burn | `sum(rate(pallas_llm_tokens_total{kind="output"}[1h])) > N` | Set N to your budget |
| Agent loop halted | `increase(pallas_agent_loop_aborted_total[15m]) > 0` | A repeated-tool-call loop was force-stopped; investigate the upstream tool/data |

---

@@ -645,6 +682,7 @@ This avoids the brittle pattern of inferring capabilities from model name substr
| `pallas.registry` | `registry.py` | Starlette app serving `GET /.well-known/mcp/server.json`; agent catalogue built from config |
| `pallas.multimodal_server` | `multimodal_server.py` | `MultimodalAgentMCPServer`, which extends `AgentMCPServer` with image support, conversation history prompts, bearer token propagation |
| `pallas.health` | `health.py` | LLM provider preflight validation, downstream MCP server probing, `get_health` tool registration |
| `pallas.loop_guard` | `loop_guard.py` | Per-request `ToolRunnerHooks` that halt the agentic loop on repeated-identical tool calls |
| `pallas.log` | `log.py` | JSON log configuration, third-party traceback capture, Rich-TUI-safe handler attachment |
| `pallas._fastagent_patch` | `_fastagent_patch.py` | Monkey-patches fast-agent at import time: per-request bearer forwarding via `httpx.Auth`, diagnostic trace-capture wrappers around `send_request` / `session.call_tool` / `_execute_on_server` |

@@ -1,28 +1,56 @@
|
|||||||
"""fast-agent runtime patches — traceback capture on three opaque catch-sites.
|
"""fast-agent runtime patches — traceback capture + sub-agent activity relay.
|
||||||
|
|
||||||
fast-agent's transport layer catches every downstream-transport exception at
|
This module wires two unrelated extensions onto fast-agent at import time:
|
||||||
several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``), and
|
|
||||||
re-raises. By the time the exception surfaces to the MCP tool result, the
|
|
||||||
traceback has been flattened to a bare string — the canonical symptom being
|
|
||||||
``"object NoneType can't be used in 'await' expression"`` with no stack
|
|
||||||
attached. This module wraps three of those catch-sites so Pallas emits
|
|
||||||
``logger.exception(...)`` with the full frame before fast-agent's swallowing
|
|
||||||
``except`` runs. Behaviour is otherwise unchanged: every wrapper re-raises
|
|
||||||
the exception it caught.
|
|
||||||
|
|
||||||
The three wrapped entry points are:
|
A. **Traceback capture** on three opaque catch-sites.
|
||||||
|
|
||||||
1. ``MCPAgentClientSession.send_request`` — the lowest-level send call;
|
fast-agent's transport layer catches every downstream-transport exception
|
||||||
2. ``MCPAgentClientSession.call_tool`` — the session-side wrapper around
|
at several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``),
|
||||||
|
and re-raises. By the time the exception surfaces to the MCP tool
|
||||||
|
result, the traceback has been flattened to a bare string — the
|
||||||
|
canonical symptom being ``"object NoneType can't be used in 'await'
|
||||||
|
expression"`` with no stack attached. We wrap three of those
|
||||||
|
catch-sites so Pallas emits ``logger.exception(...)`` with the full
|
||||||
|
frame before fast-agent's swallowing ``except`` runs. Behaviour is
|
||||||
|
otherwise unchanged: every wrapper re-raises the exception it caught.
|
||||||
|
|
||||||
|
The three wrapped entry points are:
|
||||||
|
|
||||||
|
1. ``MCPAgentClientSession.send_request`` — the lowest-level send call;
|
||||||
|
2. ``MCPAgentClientSession.call_tool`` — the session-side wrapper around
|
||||||
meta merge, permission handling, progress callback factory, and the
|
meta merge, permission handling, progress callback factory, and the
|
||||||
send_request invocation itself;
|
send_request invocation itself;
|
||||||
3. ``MCPAggregator._execute_on_server`` — the aggregator's setup around
|
3. ``MCPAggregator._execute_on_server`` — the aggregator's setup around
|
||||||
the client call (server lookup, session factory, tracer span,
|
the client call (server lookup, session factory, tracer span,
|
||||||
``try_execute`` harness).
|
``try_execute`` harness).
|
||||||
|
|
||||||
Any one wrapper being triggered while the other two stay silent pinpoints
|
Any one wrapper being triggered while the other two stay silent
|
||||||
which frame is swallowing the exception, which is how we debug opaque
|
pinpoints which frame is swallowing the exception, which is how we
|
||||||
transport failures.
|
debug opaque transport failures.
|
||||||
|
|
||||||
|
B. **Sub-agent activity relay** — see ``pallas.mcp_subagent_relay``.
|
||||||
|
|
||||||
|
When Alan calls a peer agent (research/jeffrey/ann/...) the peer's
|
||||||
|
own ``assistant_chunk`` notifications stop at fast-agent's
|
||||||
|
``MCPAgentClientSession``. We wrap two surfaces to fix that:
|
||||||
|
|
||||||
|
1. ``MCPAggregator._create_session_factory`` — the per-server factory
|
||||||
|
that builds peer ``MCPAgentClientSession`` instances. We
|
||||||
|
post-process the returned session by attaching our
|
||||||
|
``logging_callback`` (replacing the default no-op) so peer
|
||||||
|
``notifications/message`` flow through our relay filter and get
|
||||||
|
re-emitted on the parent's MCP context.
|
||||||
|
2. ``MCPAggregator.call_tool`` — the entry point fast-agent's tool
|
||||||
|
runner reaches when invoking a peer. We push/pop the
|
||||||
|
``(tool_use_id, server_name)`` in-flight stack on the peer session
|
||||||
|
so the relay callback (which runs on the SDK reader task and
|
||||||
|
therefore can't see contextvars) can attribute incoming chunks to
|
||||||
|
the right parent tool_call.
|
||||||
|
|
||||||
|
The relay's upstream-send callable is set per-request from
|
||||||
|
``multimodal_server.py``'s ``register_agent_tools`` via
|
||||||
|
``set_upstream_send_for_request`` — Pallas knows the parent ``ctx``
|
||||||
|
only inside a tool call, so the upstream binding is request-scoped.
|
||||||
|
|
||||||
Historical note: this file used to also carry the bearer-forwarding patch
|
Historical note: this file used to also carry the bearer-forwarding patch
|
||||||
that propagated inbound ``Authorization`` headers to opted-in downstream
|
that propagated inbound ``Authorization`` headers to opted-in downstream
|
||||||
@@ -35,16 +63,105 @@ server's ``headers.Authorization`` is what fast-agent sends, full stop.

from __future__ import annotations

import contextvars
import logging
import time
from typing import Any, Awaitable, Callable

from fast_agent.mcp import mcp_aggregator as _magg
from fast_agent.mcp import mcp_agent_client_session as _macs

from pallas import metrics as _pallas_metrics
from pallas import mcp_subagent_relay as _relay

logger = logging.getLogger("pallas.forward")
_trace_logger = logging.getLogger("pallas.forward.trace")
_relay_logger = logging.getLogger("pallas.subagent_relay")


|
# Per-request upstream-send callable. Set by ``multimodal_server.py`` for
# the duration of one parent ``send_message`` MCP tool call so the relay's
# ``logging_callback`` can re-emit chunks on the parent's MCP session.
# A contextvar (rather than a plain module global) is required because
# multiple concurrent send_message requests share the same Pallas process
# and would otherwise cross-talk. Async tasks that fast-agent spawns inherit
# the contextvar, including the SDK reader task it starts inside the
# session's ``__aenter__``; the session itself does not, because it lives
# across many requests. See ``_install_logging_callback_on_session`` for
# how we bridge the contextvar into the session-scoped reader.
_upstream_send: contextvars.ContextVar[
    Callable[[dict[str, Any]], Awaitable[None]] | None
] = contextvars.ContextVar("_pallas_relay_upstream_send", default=None)

# Per-request max relay depth (overrides DEFAULT_MAX_RELAY_DEPTH).
_max_relay_depth: contextvars.ContextVar[int | None] = contextvars.ContextVar(
    "_pallas_relay_max_depth", default=None
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_upstream_send_for_request(
    send: Callable[[dict[str, Any]], Awaitable[None]] | None,
    *,
    max_depth: int | None = None,
) -> tuple[contextvars.Token, contextvars.Token]:
    """Bind the relay's upstream-send for the current request task.

    Returns a ``(send_token, depth_token)`` pair. Callers should pass each
    token to the matching ``contextvars.ContextVar.reset`` when the request
    scope ends. Prefer the ``relay_upstream_scope`` context manager below
    for the common enter-tool / restore-on-exit pattern.
    """
    send_token = _upstream_send.set(send)
    depth_token = _max_relay_depth.set(max_depth)
    return send_token, depth_token
|
||||||
|
|
||||||
|
|
||||||
|
class relay_upstream_scope:
    """Context manager binding the relay's upstream-send for one request.

    Use inside the per-request ``send_message`` body so the contextvar
    is set on entry and restored on exit even if the tool call raises::

        with relay_upstream_scope(send=ctx_send, max_depth=3):
            await agent.send(payload, request_params=request_params)

    Equivalent to a ``try/finally`` around ``set_upstream_send_for_request``.
    """

    def __init__(
        self,
        *,
        send: Callable[[dict[str, Any]], Awaitable[None]] | None,
        max_depth: int | None = None,
    ) -> None:
        self._send = send
        self._max_depth = max_depth
        self._tokens: tuple[Any, Any] | None = None

    def __enter__(self) -> "relay_upstream_scope":
        self._tokens = set_upstream_send_for_request(
            self._send, max_depth=self._max_depth
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self._tokens is not None:
            send_token, depth_token = self._tokens
            _upstream_send.reset(send_token)
            _max_relay_depth.reset(depth_token)
            self._tokens = None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_upstream_send():
    return _upstream_send.get()


def _resolve_max_relay_depth() -> int:
    override = _max_relay_depth.get()
    if isinstance(override, int) and override > 0:
        return override
    return _relay.DEFAULT_MAX_RELAY_DEPTH
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ── send_request traceback capture ───────────────────────────────────────────
|
# ── send_request traceback capture ───────────────────────────────────────────
|
||||||
@@ -160,14 +277,200 @@ def _patch_execute_on_server() -> None:
|
|||||||
logger.info("aggregator._execute_on_server traceback-capture patch installed")
|
logger.info("aggregator._execute_on_server traceback-capture patch installed")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sub-agent activity relay: session factory + call_tool wrappers ───────────
#
# We hook two surfaces:
#
# * ``_create_session_factory`` (returns a closure that builds peer
#   ``MCPAgentClientSession`` instances). We post-process each session
#   right after construction to install our ``logging_callback``. At
#   construction time we do NOT yet know the per-request upstream-send
#   callable (that is only set inside ``send_message``), so the callback
#   reads it from the contextvar each time a notification arrives.
# * ``MCPAggregator.call_tool`` (called by fast-agent's tool runner with
#   ``tool_use_id`` and the namespaced tool name). We split the server
#   prefix off the tool name, look up the live session for that server,
#   and push ``(tool_use_id, server_name)`` onto the session's in-flight
#   stack so chunks arriving on the SDK reader task can be attributed.
#
# Both wrappers are idempotent (same ``_pallas_relay_patched`` attribute
# guard pattern) so ``install()`` is safe to call multiple times.
_original_create_session_factory = _magg.MCPAggregator._create_session_factory
|
||||||
|
|
||||||
|
|
||||||
|
def _create_session_factory_with_relay(self, server_name: str):
    inner_factory = _original_create_session_factory(self, server_name)

    def factory(read_stream, write_stream, read_timeout, **kwargs):
        session = inner_factory(read_stream, write_stream, read_timeout, **kwargs)
        try:
            _install_logging_callback_on_session(session)
        except Exception:
            # A failed install must never break session construction.
            _relay_logger.warning(
                "subagent_relay_install_callback_failed server=%s",
                server_name,
                exc_info=True,
            )
        return session

    return factory
|
||||||
|
|
||||||
|
|
||||||
|
def _install_logging_callback_on_session(session: Any) -> None:
    """Attach the relay's ``logging_callback`` to a freshly built peer session.

    fast-agent doesn't pass a ``logging_callback`` when constructing
    sessions, so the SDK's ``ClientSession`` falls back to the default
    no-op handler that drops every ``notifications/message``. We swap
    that for one that filters on Pallas's ``assistant_stream`` logger
    and re-emits matching chunks upstream, but only when an upstream
    binding is active (i.e. inside a ``send_message`` request scope).
    """

    # Named ``_send_upstream`` so it does not shadow the module-level
    # ``_upstream_send`` contextvar.
    async def _send_upstream(payload: dict[str, Any]) -> None:
        upstream = _resolve_upstream_send()
        if upstream is None:
            # No active request: drop. This is the steady state when
            # peer sessions are pinged by fast-agent's connection
            # manager outside any user tool call.
            return
        await upstream(payload)

    callback = _relay.make_logging_callback(
        upstream_send=_send_upstream,
        session=session,
        # NOTE: resolved once, when the session is built, so a per-request
        # ``max_depth`` override only applies if it is bound at that moment.
        max_depth=_resolve_max_relay_depth(),
    )

    # The SDK's ``ClientSession`` stores the logging callback on a private
    # ``_logging_callback`` attribute. Only that attribute is set here, so
    # an SDK change that moves the slot would silently disable the relay.
    setattr(session, "_logging_callback", callback)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_create_session_factory() -> None:
    if getattr(
        _magg.MCPAggregator._create_session_factory,
        "_pallas_relay_patched",
        False,
    ):
        return
    _create_session_factory_with_relay._pallas_relay_patched = True  # type: ignore[attr-defined]
    _magg.MCPAggregator._create_session_factory = _create_session_factory_with_relay  # type: ignore[assignment]
    logger.info("aggregator._create_session_factory subagent-relay patch installed")
|
||||||
|
|
||||||
|
|
||||||
|
_original_aggregator_call_tool = _magg.MCPAggregator.call_tool
|
||||||
|
|
||||||
|
|
||||||
|
async def _aggregator_call_tool_with_relay(self, *args, **kwargs):
    """Push/pop the relay's in-flight entry around each peer tool call.

    We extract the tool name (``args[0]`` or ``kwargs['name']``) and the
    LLM's ``tool_use_id`` if present. When we can identify a peer
    server prefix on the tool name, we look up the active session for
    that server in the persistent connection manager and push the
    in-flight entry there. Without a session lookup hit (e.g. the
    server uses bare tool names, or the connection isn't tracked yet)
    we fall through to the unwrapped behaviour; that loses nesting
    metadata but never affects correctness.

    This wraps ``MCPAggregator.call_tool``, the entry point fast-agent's
    tool runner uses. Note: this is **not** the same as
    ``_session.call_tool`` (the lower-level SDK call) wrapped above for
    traceback capture; the names collide in fast-agent's API.
    """
    name = args[0] if args else kwargs.get("name")
    tool_use_id = kwargs.get("tool_use_id")

    session = _resolve_peer_session(self, name) if isinstance(name, str) else None

    pushed = False
    if session is not None and isinstance(tool_use_id, str) and tool_use_id:
        peer_name = _split_server_from_tool_name(name)
        if peer_name:
            try:
                _relay.push_inflight(session, tool_use_id, peer_name)
                pushed = True
            except Exception:
                _relay_logger.debug(
                    "subagent_relay_push_failed",
                    exc_info=True,
                )

    try:
        return await _original_aggregator_call_tool(self, *args, **kwargs)
    finally:
        # Always pop what we pushed, even if the tool call raised.
        if pushed and session is not None and isinstance(tool_use_id, str):
            try:
                _relay.pop_inflight(session, tool_use_id)
            except Exception:
                _relay_logger.debug(
                    "subagent_relay_pop_failed",
                    exc_info=True,
                )
|
||||||
|
|
||||||
|
|
||||||
|
def _split_server_from_tool_name(tool_name: str) -> str | None:
    """Return the server prefix from a ``server-tool`` namespaced name.

    fast-agent uses ``-`` as the SEP between server and tool name (see
    ``fast_agent.mcp.common.SEP``). Bare tool names (no separator)
    have no peer prefix to attribute, so the relay is a no-op for them.
    Only the text before the first ``-`` is returned, so a server name
    that itself contains ``-`` is truncated at that point.
    """
    if not tool_name or "-" not in tool_name:
        return None
    prefix, _ = tool_name.split("-", 1)
    return prefix or None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_peer_session(aggregator, tool_name: str) -> Any | None:
    """Look up the live ``MCPAgentClientSession`` for the peer of this tool.

    fast-agent stores per-server sessions on
    ``aggregator._persistent_connection_manager.running_servers``
    (see ``mcp_connection_manager.py``). We probe that path defensively:
    any miss returns ``None`` and the caller skips relay attribution.
    """
    server_name = _split_server_from_tool_name(tool_name)
    if not server_name:
        return None
    manager = getattr(aggregator, "_persistent_connection_manager", None)
    if manager is None:
        return None
    running = getattr(manager, "running_servers", None)
    if not isinstance(running, dict):
        return None
    server_conn = running.get(server_name)
    if server_conn is None:
        return None
    # ServerConnection exposes ``session`` once initialised.
    return getattr(server_conn, "session", None)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_aggregator_call_tool() -> None:
    if getattr(
        _magg.MCPAggregator.call_tool, "_pallas_relay_patched", False
    ):
        return
    _aggregator_call_tool_with_relay._pallas_relay_patched = True  # type: ignore[attr-defined]
    _magg.MCPAggregator.call_tool = _aggregator_call_tool_with_relay  # type: ignore[assignment]
    logger.info("aggregator.call_tool subagent-relay patch installed")
|
||||||
|
|
||||||
|
|
||||||
def install() -> None:
|
def install() -> None:
|
||||||
"""Install all three trace-capture wrappers.
|
"""Install all trace-capture and relay wrappers.
|
||||||
|
|
||||||
Each ``_patch_*`` helper is individually idempotent (guarded on a
|
Each ``_patch_*`` helper is individually idempotent (guarded on a
|
||||||
``_pallas_trace_patched`` attribute), so ``install()`` is safe to call
|
``_pallas_*_patched`` attribute), so ``install()`` is safe to call
|
||||||
repeatedly — e.g. from ``pallas/__init__.py`` on import + again from
|
repeatedly — e.g. from ``pallas/__init__.py`` on import + again from
|
||||||
a test harness — without stacking wrappers.
|
a test harness — without stacking wrappers.
|
||||||
"""
|
"""
|
||||||
_patch_send_request()
|
_patch_send_request()
|
||||||
_patch_session_call_tool()
|
_patch_session_call_tool()
|
||||||
_patch_execute_on_server()
|
_patch_execute_on_server()
|
||||||
|
_patch_create_session_factory()
|
||||||
|
_patch_aggregator_call_tool()
|
||||||
|
|
||||||
|
|||||||
385
pallas/assistant_stream.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""
Mid-turn assistant chunk streaming over MCP.

The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of
turn. When fast-agent runs a multi-iteration tool loop, every intermediate
assistant message (including substantive natural-language replies the LLM
emits before deciding to call a tool) stays inside fast-agent's
``message_history`` and never crosses the MCP boundary. The caller (Daedalus)
sees only the final ``last_text()``, which is often a thin wrap-up sentence
while the substantive answer was produced two iterations earlier.

This module fixes that by emitting one MCP ``notifications/message``
(``LoggingMessageNotification``) per assistant turn, carrying the structured
content blocks as opaque JSON in the ``data`` field. The transport is
already plumbed through StreamableHTTP (the SDK delivers it to the client's
``logging_callback``), so no SDK changes are needed on either side.

Each chunk also carries an enriched ``tool_calls`` list so the consumer can
show *which* tools the model decided to invoke on that iteration (with arg
previews) and, once results arrive, attach a result preview + duration.
The shared progress manager (``progress.py``) computes the same previews
for the live ``notifications/progress`` ticker; we reuse those helpers.

The hooks install onto fast-agent's ``ToolRunnerHooks`` (``after_llm_call``,
``before_tool_call``, ``after_tool_call``). Pallas merges these with whatever
hooks the agent (or fast-agent itself) already has; see
``multimodal_server.py`` ``_install_assistant_stream_hook`` for the
merge-and-restore wiring.
"""
|
||||||
|
|
||||||
|
from __future__ import annotations

import logging
import time
from typing import Any, TYPE_CHECKING

from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from mcp.types import ImageContent, TextContent

from pallas.progress import format_args_preview, format_result_preview

if TYPE_CHECKING:
    from fastmcp import Context as MCPContext

logger = logging.getLogger(__name__)

# Logger string the client uses to filter our notifications from any other
# log messages the server might emit. Stays narrow on both sides.
LOGGER_NAME = "pallas.assistant_stream"

# Discriminator inside ``data`` so this transport stays multiplexable later.
KIND = "assistant_chunk"

# Bumped if the on-wire payload shape changes incompatibly. Daedalus reads
# this and ignores chunks whose schema_version it doesn't recognise.
#
# v1: text + bare tool-call ids/names.
# v2: tool_calls enriched with server, arguments_preview, and (after the
#     tool runs) result_preview / duration_ms / ok. Late-arriving "tool
#     results" notifications use ``kind == "assistant_chunk_results"`` so
#     a v1 client that ignores them still gets clean v2 behaviour for the
#     primary chunk.
SCHEMA_VERSION = 2

# Discriminator for the late-arriving tool-results companion notification.
KIND_RESULTS = "assistant_chunk_results"
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantChunkEmitter:
    """Per-request emitter that ships each assistant turn over MCP.

    One instance per ``send_message`` MCP call. Tracks an iteration counter
    so consumers can order chunks within a single turn (and ignore stale
    ones from an aborted retry), plus the in-flight tool-call timings so
    ``after_tool_call`` can attach result previews + durations to the
    iteration the consumer already received.
    """

    def __init__(
        self,
        ctx: "MCPContext",
        *,
        agent_name: str,
        conversation_id: str | None = None,
    ) -> None:
        self._ctx = ctx
        self._agent_name = agent_name
        self._conversation_id = conversation_id
        self._iteration = 0
        # Maps tool-call id → (iteration, start_time_perf_counter).
        # Filled in ``before_tool_call`` and consumed in ``after_tool_call``
        # so we can pair results with the iteration that produced the call.
        self._pending_tool_calls: dict[str, tuple[int, float]] = {}

    def as_after_llm_call_hook(self):
        """Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``."""

        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            await self._emit_assistant_chunk(message)

        return _hook

    def as_before_tool_call_hook(self):
        """Return an async callable suitable for ``ToolRunnerHooks.before_tool_call``.

        Records the start time for each tool call in the request so
        ``after_tool_call`` can report a duration alongside the result.
        """

        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            now = time.perf_counter()
            for call_id in (message.tool_calls or {}).keys():
                # The iteration that *requested* these calls is whatever
                # we just emitted — the LLM call hook ran immediately
                # before this one for the same assistant turn.
                self._pending_tool_calls[call_id] = (self._iteration, now)

        return _hook

    def as_after_tool_call_hook(self):
        """Return an async callable suitable for ``ToolRunnerHooks.after_tool_call``.

        Emits a follow-up ``assistant_chunk_results`` notification for the
        iteration that originally requested these tools, carrying result
        previews + per-call durations so the consumer can enrich the
        already-rendered iteration card.
        """

        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            await self._emit_tool_results(message)

        return _hook

    async def _emit_assistant_chunk(self, message: PromptMessageExtended) -> None:
        """Serialize one assistant turn and ship it as a log notification.

        Failures here must never break the agent turn — wrap everything and
        log a warning. The user-facing consequence of a dropped chunk is
        "the live bubble didn't update once"; the canonical final message
        still arrives via the tools/call response.
        """
        self._iteration += 1
        try:
            payload = self._build_payload(message)
        except Exception:
            logger.warning(
                "assistant_stream payload build failed",
                exc_info=True,
                extra={
                    "agent": self._agent_name,
                    "conversation_id": self._conversation_id,
                    "iteration": self._iteration,
                },
            )
            return

        # Skip truly empty turns (no text or image content and no tool calls).
        # The tool-call lifecycle is already surfaced via
        # notifications/progress, so an empty assistant_chunk would just be
        # noise on the wire and the live bubble.
        if not payload["content"] and not payload["tool_calls"]:
            return

        await self._send(payload)

    async def _emit_tool_results(self, message: PromptMessageExtended) -> None:
        """Ship a companion notification with result previews + durations."""
        try:
            results = message.tool_results or {}
            if not results:
                return
            now = time.perf_counter()
            entries: list[dict[str, Any]] = []
            iteration_for_results: int | None = None
            for call_id, result in results.items():
                iteration, started = self._pending_tool_calls.pop(
                    call_id, (self._iteration, now)
                )
                # All results from one ``after_tool_call`` correspond to the
                # same iteration (they're the response to that iteration's
                # tool_calls), so keep the first non-None we see.
                if iteration_for_results is None:
                    iteration_for_results = iteration
                duration_ms = int(round((now - started) * 1000))
                ok = not bool(getattr(result, "isError", False))
                content_blocks = list(getattr(result, "content", []) or [])
                entries.append(
                    {
                        "id": call_id,
                        "ok": ok,
                        "duration_ms": duration_ms,
                        "result_preview": format_result_preview(content_blocks),
                    }
                )

            if not entries:
                return

            payload = {
                "kind": KIND_RESULTS,
                "schema_version": SCHEMA_VERSION,
                "agent": self._agent_name,
                "conversation_id": self._conversation_id,
                # The iteration the *call* belonged to — the consumer keys
                # by this to enrich the iteration card it already drew.
                "iteration": iteration_for_results,
                "results": entries,
            }
        except Exception:
            logger.warning(
                "assistant_stream tool-results payload build failed",
                exc_info=True,
                extra={
                    "agent": self._agent_name,
                    "conversation_id": self._conversation_id,
                    "iteration": self._iteration,
                },
            )
            return

        await self._send(payload)

    async def _send(self, payload: dict[str, Any]) -> None:
        try:
            session = self._ctx.session
            related_request_id = getattr(self._ctx, "request_id", None)
            await session.send_log_message(
                level="info",
                data=payload,
                logger=LOGGER_NAME,
                related_request_id=related_request_id,
            )
        except Exception:
            logger.warning(
                "assistant_stream send_log_message failed",
                exc_info=True,
                extra={
                    "agent": self._agent_name,
                    "conversation_id": self._conversation_id,
                    "iteration": self._iteration,
                    "kind": payload.get("kind"),
                },
            )

    def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]:
        content_blocks: list[dict[str, Any]] = []
        for block in message.content or []:
            if isinstance(block, TextContent) and block.text:
                content_blocks.append({"type": "text", "text": block.text})
            elif isinstance(block, ImageContent):
                # Images on assistant turns are unusual but legal. We send
                # the raw base64 — same shape the final tools/call result
                # already uses so the client can render it identically.
                content_blocks.append(
                    {
                        "type": "image",
                        "data": block.data,
                        "mime_type": block.mimeType,
                    }
                )
            # Other block types (resources, embedded resources, etc.) are
            # intentionally omitted from the live stream — they're rare on
            # assistant turns and the canonical final message carries them.

        tool_calls: list[dict[str, Any]] = []
        for call_id, call in (message.tool_calls or {}).items():
            params = getattr(call, "params", None)
            name = getattr(params, "name", None)
            arguments = getattr(params, "arguments", None)
            server = _split_server(name)
            tool_calls.append(
                {
                    "id": call_id,
                    "name": name,
                    "server": server,
                    "arguments_preview": format_args_preview(arguments),
                }
            )

        stop_reason = message.stop_reason
        if stop_reason is not None:
            # LlmStopReason is a str-Enum; .value is the wire form fast-agent
            # already uses ("toolUse", "endTurn", ...).
            stop_reason_str = getattr(stop_reason, "value", str(stop_reason))
        else:
            stop_reason_str = None

        return {
            "kind": KIND,
            "schema_version": SCHEMA_VERSION,
            "agent": self._agent_name,
            "conversation_id": self._conversation_id,
            "iteration": self._iteration,
            "stop_reason": stop_reason_str,
            "content": content_blocks,
            "tool_calls": tool_calls,
        }


def _split_server(name: str | None) -> str | None:
    """Extract the server prefix from a fast-agent tool name.

    Fast-agent namespaces aggregator tools as ``server-tool`` (e.g.
    ``argos-search_web``). Single-server agents may use a bare name.
    Returns the prefix when present, else ``None``.
    """
    if not name or "-" not in name:
        return None
    prefix, _ = name.split("-", 1)
    return prefix or None


def install_for_request(
    agent: Any,
    *,
    ctx: "MCPContext",
    agent_name: str,
    conversation_id: str | None,
):
    """Install the assistant-stream hooks on a request-scoped agent instance.

    Returns a no-arg callable that restores the agent's previous
    ``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks
    don't leak across requests if the underlying instance is ever reused.

    Pallas runs ``instance_scope="request"`` so in normal operation the
    instance is disposed immediately after; the restore call is defensive
    against config changes or future shared-instance modes.
    """
    emitter = AssistantChunkEmitter(
        ctx, agent_name=agent_name, conversation_id=conversation_id
    )

    extra = ToolRunnerHooks(
        after_llm_call=emitter.as_after_llm_call_hook(),
        before_tool_call=emitter.as_before_tool_call_hook(),
        after_tool_call=emitter.as_after_tool_call_hook(),
    )

    previous = getattr(agent, "tool_runner_hooks", None)
    merged = _merge_hooks(previous, extra)
    agent.tool_runner_hooks = merged

    def restore() -> None:
        agent.tool_runner_hooks = previous

    return restore


def _merge_hooks(
    base: ToolRunnerHooks | None, extra: ToolRunnerHooks
) -> ToolRunnerHooks:
    """Compose ``extra``'s hooks after each matching ``base`` hook.

    Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
    so we don't depend on its signature. Each slot composes independently;
    when ``base`` is ``None`` the ``extra`` hook is used directly.
    """
    if base is None:
        return extra

    return ToolRunnerHooks(
        before_llm_call=_compose(base.before_llm_call, extra.before_llm_call),
        after_llm_call=_compose(base.after_llm_call, extra.after_llm_call),
        before_tool_call=_compose(base.before_tool_call, extra.before_tool_call),
        after_tool_call=_compose(base.after_tool_call, extra.after_tool_call),
        after_turn_complete=_compose(
            base.after_turn_complete, extra.after_turn_complete
        ),
    )


def _compose(base_hook, extra_hook):
    """Return an async callable that runs ``base_hook`` then ``extra_hook``."""
    if base_hook is None:
        return extra_hook
    if extra_hook is None:  # pragma: no cover - extra always sets the slots we care about
        return base_hook

    async def merged(runner, message):
        await base_hook(runner, message)
        await extra_hook(runner, message)

    return merged
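The slot-by-slot composition used by `_merge_hooks` and `_compose` can be tried in isolation. The sketch below is only an illustration under stated assumptions: `Hooks` is a hypothetical two-slot stand-in for fast-agent's `ToolRunnerHooks`, and `compose`, `merge`, and `demo` are illustrative names, not part of Pallas or fast-agent.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List, Optional

Hook = Callable[[Any, Any], Awaitable[None]]


@dataclass
class Hooks:
    """Hypothetical stand-in for ToolRunnerHooks (only two slots modelled)."""

    before_tool_call: Optional[Hook] = None
    after_tool_call: Optional[Hook] = None


def compose(base: Optional[Hook], extra: Optional[Hook]) -> Optional[Hook]:
    """Run ``base`` then ``extra``; pass through whichever side is set alone."""
    if base is None:
        return extra
    if extra is None:
        return base

    async def merged(runner: Any, message: Any) -> None:
        await base(runner, message)
        await extra(runner, message)

    return merged


def merge(base: Optional[Hooks], extra: Hooks) -> Hooks:
    """Compose each slot independently, as the diff's ``_merge_hooks`` does."""
    if base is None:
        return extra
    return Hooks(
        before_tool_call=compose(base.before_tool_call, extra.before_tool_call),
        after_tool_call=compose(base.after_tool_call, extra.after_tool_call),
    )


def demo() -> List[str]:
    """Record the order in which composed hooks fire."""
    calls: List[str] = []

    def recorder(label: str) -> Hook:
        async def hook(_runner: Any, _message: Any) -> None:
            calls.append(label)

        return hook

    stream = Hooks(after_tool_call=recorder("stream-after"))
    guard = Hooks(
        before_tool_call=recorder("guard-before"),
        after_tool_call=recorder("guard-after"),
    )
    merged = merge(stream, guard)

    async def run() -> None:
        await merged.before_tool_call(None, None)
        await merged.after_tool_call(None, None)

    asyncio.run(run())
    return calls
```

In this sketch a slot that only one side fills is passed through unchanged, and a slot both sides fill runs the earlier-installed hook first, which matches the ordering the docstrings above describe.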
@@ -90,6 +90,16 @@ class _StaticFieldsFilter(logging.Filter):
         return True
 
 
+# Standard ``LogRecord`` attributes — everything else on ``record.__dict__`` is
+# an ``extra={...}`` field a caller attached and wants serialized. ``message``
+# and ``asctime`` are populated during formatting; ``taskName`` exists on 3.12+.
+_STANDARD_LOGRECORD_KEYS = set(logging.makeLogRecord({}).__dict__) | {
+    "message",
+    "asctime",
+    "taskName",
+}
+
+
 class _JSONFormatter(logging.Formatter):
     """Single-line JSON formatter compatible with Alloy's ``| json`` pipeline.
 
@@ -123,6 +133,12 @@ class _JSONFormatter(logging.Formatter):
             "project": getattr(record, "project", _PROJECT),
             "component": getattr(record, "component", _COMPONENT_CTX.get()),
         }
+        # Merge caller-supplied ``extra={...}`` fields (anything on the record
+        # that isn't a standard LogRecord attribute or already emitted above).
+        for key, value in record.__dict__.items():
+            if key in _STANDARD_LOGRECORD_KEYS or key in payload:
+                continue
+            payload[key] = value
         if record.exc_info:
             if not record.exc_text:
                 record.exc_text = self.formatException(record.exc_info)
@@ -131,7 +147,8 @@ class _JSONFormatter(logging.Formatter):
             payload["traceback"] = record.exc_text
         if record.stack_info:
             payload["stack"] = self.formatStack(record.stack_info)
-        return json.dumps(payload)
+        # default=str keeps a non-serializable extra value from crashing logging.
+        return json.dumps(payload, default=str)
 
 
 class _HealthAccessFilter(logging.Filter):
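The formatter change above merges caller-supplied `extra={...}` fields into the JSON line and falls back to `str` for values `json` cannot encode. A minimal, self-contained sketch of the same idea follows; `ExtraAwareJSONFormatter` and `render` are hypothetical names used only for illustration, and the payload is reduced to two base fields rather than the full set the real `_JSONFormatter` emits.

```python
import json
import logging

# Attributes every LogRecord carries. Anything else on the record was
# attached by a caller through extra={...}.
STANDARD_KEYS = set(logging.makeLogRecord({}).__dict__) | {
    "message",
    "asctime",
    "taskName",
}


class ExtraAwareJSONFormatter(logging.Formatter):
    """Hypothetical, reduced stand-in for the patched ``_JSONFormatter``."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {"level": record.levelname, "message": record.getMessage()}
        for key, value in record.__dict__.items():
            if key in STANDARD_KEYS or key in payload:
                continue
            payload[key] = value
        # default=str stringifies values json cannot encode natively.
        return json.dumps(payload, default=str)


def render(extra: dict) -> dict:
    """Format one synthetic record and parse the JSON line back."""
    record = logging.makeLogRecord(
        {"msg": "agentic loop halted", "levelname": "WARNING", **extra}
    )
    return json.loads(ExtraAwareJSONFormatter().format(record))
```

With this shape, a call such as `logger.warning("...", extra={"event": "loop_halt"})` would surface `event` as a top-level JSON key, which is what the loop guard's structured warning relies on.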
231
pallas/loop_guard.py
Normal file
@@ -0,0 +1,231 @@
"""Runaway-loop detection for the agentic tool loop.

A small model occasionally gets stuck emitting the *identical* tool call
every iteration — typically because an upstream MCP server returned a
contradictory or malformed result the model keeps trying to reconcile.
Left alone the loop burns LLM turns and context until the client times
out and the user sees ``empty_response``.

This guard installs ``ToolRunnerHooks`` on the request-scoped agent that
track a per-turn signature of ``(tool, normalized_args) -> result_hash``.
When the same signature repeats ``threshold`` times consecutively the
loop is halted **immediately** — we don't ask the model to troubleshoot,
because the fault is almost always upstream and self-recovery is slow,
unpredictable, and token-hungry. Instead we:

* force fast-agent's own ``max_iterations`` termination path so the turn
  ends after the current tool result with a partial answer rather than a
  client timeout;
* append an honest, user-facing explanation to the dangling turn and mark
  it as ended;
* log the offending tool, arguments, and result with full detail so the
  real (upstream) bug can be fixed durably; and
* increment ``pallas_agent_loop_aborted_total`` for alerting.

The hooks compose after any already-installed hooks (e.g. the assistant
stream) via the same merge strategy ``assistant_stream`` uses.
"""

from __future__ import annotations

import hashlib
import json
import logging
from typing import Any

from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import TextContent

from pallas import metrics as _pallas_metrics
from pallas.assistant_stream import _merge_hooks, _split_server
from pallas.progress import format_args_preview, format_result_preview

logger = logging.getLogger("pallas.loop_guard")

DEFAULT_THRESHOLD = 3


def _normalize_args(arguments: Any) -> str:
    """Deterministically serialize tool arguments for signature comparison."""
    try:
        return json.dumps(arguments, sort_keys=True, default=str, ensure_ascii=False)
    except (TypeError, ValueError):
        return repr(arguments)


def _result_hash(result: Any) -> str:
    """Stable hash of a tool result's content for repeat detection."""
    parts: list[str] = []
    for block in getattr(result, "content", None) or []:
        text = getattr(block, "text", None)
        parts.append(text if isinstance(text, str) else repr(block))
    parts.append(f"isError={bool(getattr(result, 'isError', False))}")
    digest = hashlib.sha256("\x00".join(parts).encode("utf-8", "replace"))
    return digest.hexdigest()


class LoopGuard:
    """Per-request consecutive-identical-tool-call detector."""

    def __init__(
        self, *, agent_name: str, conversation_id: str | None, threshold: int
    ) -> None:
        self._agent_name = agent_name
        self._conversation_id = conversation_id
        self._threshold = threshold
        # call_id -> (tool_name, normalized_args, raw_arguments), staged by
        # before_tool_call and consumed by the matching after_tool_call.
        self._pending: dict[str, tuple[str | None, str, Any]] = {}
        self._last_signature: str | None = None
        self._repeat_count = 0
        self._halted = False

    def as_before_tool_call_hook(self):
        async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
            for call_id, call in (message.tool_calls or {}).items():
                params = getattr(call, "params", None)
                name = getattr(params, "name", None)
                arguments = getattr(params, "arguments", None)
                self._pending[call_id] = (name, _normalize_args(arguments), arguments)

        return _hook

    def as_after_tool_call_hook(self):
        async def _hook(runner: Any, message: PromptMessageExtended) -> None:
            if self._halted:
                return
            try:
                self._evaluate(runner, message)
            except Exception:  # never let the guard break a live turn
                logger.warning(
                    "loop_guard evaluation failed",
                    exc_info=True,
                    extra={
                        "agent": self._agent_name,
                        "conversation_id": self._conversation_id,
                    },
                )

        return _hook

    def _evaluate(self, runner: Any, message: PromptMessageExtended) -> None:
        results = message.tool_results or {}
        if not results:
            return

        components: list[tuple[str, str, str]] = []
        names: list[str] = []
        raw_args: list[Any] = []
        for call_id, result in results.items():
            name, args_sig, arguments = self._pending.pop(call_id, (None, "null", None))
            components.append((name or "", args_sig, _result_hash(result)))
            if name:
                names.append(name)
            raw_args.append(arguments)
        components.sort()
        signature = "|".join(f"{n}:{a}:{r}" for n, a, r in components)

        if signature == self._last_signature:
            self._repeat_count += 1
        else:
            self._last_signature = signature
            self._repeat_count = 1

        if self._repeat_count >= self._threshold:
            tool_label = ", ".join(dict.fromkeys(names)) or "unknown"
            args_preview = format_args_preview(raw_args[0]) if raw_args else ""
            self._halt(runner, results, tool_label, args_preview)

    def _halt(
        self,
        runner: Any,
        results: dict[str, Any],
        tool_label: str,
        args_preview: str,
    ) -> None:
        self._halted = True

        first = next(iter(results.values()), None)
        result_preview = (
            format_result_preview(list(getattr(first, "content", []) or []))
            if first is not None
            else ""
        )

        # Force fast-agent's own termination check (tool_runner: `_iteration >
        # max_iterations`) to fire after this tool result — no further LLM
        # call, partial answer returned instead of a client timeout.
        params = getattr(runner, "request_params", None)
        iteration = getattr(runner, "iteration", 0)
        if params is not None:
            params.max_iterations = iteration

        # Append an honest final note to the dangling tool-use turn and mark
        # it as ended, so the client gets an explanation rather than an
        # empty/truncated turn.
        note = (
            f"Halted: the '{tool_label}' tool returned an identical result "
            f"{self._repeat_count} times in a row, so the agent was looping "
            "without making progress. This is usually an upstream data or "
            "tool-server issue rather than a problem with the request. "
            "Stopping early to avoid a runaway loop."
        )
        last = getattr(runner, "last_message", None)
        if last is not None:
            if last.content is None:
                last.content = []
            last.content.append(TextContent(type="text", text=note))
            last.stop_reason = LlmStopReason.END_TURN

        logger.warning(
            "agentic loop halted: identical tool call repeated",
            extra={
                "event": "loop_halt",
                "agent": self._agent_name,
                "conversation_id": self._conversation_id,
                "tool": tool_label,
                "server": _split_server(tool_label),
                "repeat_count": self._repeat_count,
                "threshold": self._threshold,
                "iteration": iteration,
                "arguments_preview": args_preview,
                "result_preview": result_preview,
            },
        )
        _pallas_metrics.record_loop_abort(self._agent_name, "repeat")


def install_for_request(
    agent: Any,
    *,
    agent_name: str,
    conversation_id: str | None,
    threshold: int = DEFAULT_THRESHOLD,
):
    """Install the loop guard on a request-scoped agent instance.

    Returns a no-arg ``restore`` callable that reinstates the agent's prior
    ``tool_runner_hooks`` — call it in a ``finally``. A non-positive
    ``threshold`` disables the guard (returns a no-op restore).
    """
    if threshold is None or threshold < 1:
        return lambda: None

    guard = LoopGuard(
        agent_name=agent_name,
        conversation_id=conversation_id,
        threshold=threshold,
    )

    extra = ToolRunnerHooks(
        before_tool_call=guard.as_before_tool_call_hook(),
        after_tool_call=guard.as_after_tool_call_hook(),
    )

    previous = getattr(agent, "tool_runner_hooks", None)
    agent.tool_runner_hooks = _merge_hooks(previous, extra)

    def restore() -> None:
        agent.tool_runner_hooks = previous

    return restore
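The detection rule in `loop_guard.py` boils down to "same `(tool, normalized args, result hash)` signature `threshold` times in a row". The sketch below isolates that rule with no fast-agent dependency. It is a simplification under stated assumptions: `signature` and `RepeatCounter` are hypothetical names, results are modelled as plain text, and a round is a single tool call rather than the sorted multi-call signature the real guard builds.

```python
import hashlib
import json
from typing import Any, Optional


def signature(tool: str, arguments: Any, result_text: str) -> str:
    """Key-order-insensitive signature of one tool round (simplified)."""
    args = json.dumps(arguments, sort_keys=True, default=str, ensure_ascii=False)
    digest = hashlib.sha256(result_text.encode("utf-8", "replace")).hexdigest()
    return f"{tool}:{args}:{digest}"


class RepeatCounter:
    """Counts consecutive identical signatures; any change resets the run."""

    def __init__(self, threshold: int) -> None:
        self.threshold = threshold
        self.last: Optional[str] = None
        self.count = 0

    def observe(self, sig: str) -> bool:
        """Return True once ``sig`` has been seen ``threshold`` times in a row."""
        if sig == self.last:
            self.count += 1
        else:
            self.last, self.count = sig, 1
        return self.count >= self.threshold
```

Because arguments are serialized with `sort_keys=True`, two calls that differ only in key order count as the same round, while any differing call in between resets the counter, so only strictly consecutive repeats trip the guard.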
213
pallas/mcp_subagent_relay.py
Normal file
@@ -0,0 +1,213 @@
"""Sub-agent activity relay over the MCP logging channel.

When Alan (or any peer-orchestrator agent) calls a sub-agent like ``research``,
that sub-agent is itself a Pallas server emitting its own ``assistant_chunk``
notifications. Those notifications are delivered to fast-agent's
``MCPAgentClientSession`` for the peer — and stop there: fast-agent's default
session has no ``logging_callback``, so the only thing that ever crosses
back up to Daedalus is the final ``CallToolResult`` text. When a sub-agent
does substantive multi-turn reasoning, all that reasoning disappears into
server logs.

This module installs a ``logging_callback`` on every peer session Pallas
opens and re-emits the relevant notifications upstream, decorated so the
consumer can render them as **nested iterations** under the parent
tool-call entry.

Two pieces of provenance are added:

* ``via_agent`` — call-stack list (e.g. ``["research"]`` for chunks the
  ``research`` peer emitted; ``["research", "search"]`` if research itself
  relayed deeper). The parent agent is implicit (it's whose message
  these attach to) so the list contains only relayed-through peers.
* ``parent_call_id`` — the LLM ``tool_use_id`` (e.g. ``toolu_42``) that
  the parent's LLM emitted to invoke this peer. The Daedalus collector
  uses it to nest the relayed iterations under the matching tool_call
  entry on the parent's iteration card.

The schema_version is bumped to **3** when relaying. Top-level chunks
(produced by the parent agent's own assistant_stream hooks, no
``parent_call_id``) are unchanged at v2 — adding a new optional field on
top-level entries doesn't change the wire shape.

Depth cap (``max_relay_depth``, default 3) drops chunks that would push
``len(via_agent)`` past the cap. This guards against runaway nested-pyramid
JSON in ``Message.iterations`` if a future agent wired up indefinite peer
chains; today the realistic depth is 2 (Alan → research → search/knowledge).
"""

from __future__ import annotations

import logging
from collections import deque
from typing import Any, Awaitable, Callable, Deque, Tuple

from mcp.types import LoggingMessageNotificationParams

logger = logging.getLogger(__name__)

# Logger string Pallas tags assistant-stream notifications with.
_ASSISTANT_STREAM_LOGGER = "pallas.assistant_stream"
_KIND = "assistant_chunk"
_KIND_RESULTS = "assistant_chunk_results"
_KINDS = frozenset({_KIND, _KIND_RESULTS})

# Schema produced by this relay. The relay only re-emits payloads at v2+
# (the version that carries enriched tool_calls + companion results); the
# bump to 3 signals "may carry via_agent / parent_call_id".
RELAY_SCHEMA_VERSION = 3

# Default depth cap. Override at install time via ``max_relay_depth``.
DEFAULT_MAX_RELAY_DEPTH = 3


# Per-session in-flight stack: (parent_call_id, peer_server_name).
# Pushed when ``MCPAggregator.call_tool`` enters; popped when it exits.
# Stored as a deque so concurrent same-server tool calls (rare — fast-agent
# supports parallel tool execution within an iteration, see
# ``tool_agent.py`` ``run_tools`` ``should_parallel`` branch) get attributed
# to whichever push is currently most recent. This is a best-effort
# attribution; the realistic concurrent shape is "different servers" not
# "same server".
_INFLIGHT_ATTR = "_pallas_relay_inflight"


def push_inflight(session: Any, parent_call_id: str, peer_name: str) -> None:
    """Record an in-flight peer call so the relay can tag its chunks.

    The session is the ``MCPAgentClientSession`` instance that fast-agent's
    aggregator opened to the peer. The deque is attached lazily on the
    first push, so sessions that never make a peer call carry no extra state.
    """
    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if stack is None:
        stack = deque()
        setattr(session, _INFLIGHT_ATTR, stack)
    stack.append((parent_call_id, peer_name))


def pop_inflight(session: Any, parent_call_id: str) -> None:
    """Remove the most-recent matching entry for this parent_call_id.

    Match by id rather than blind ``pop()`` so out-of-order completions
    (when fast-agent runs parallel tool calls) don't strand entries.
    """
    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if not stack:
        return
    # Walk the deque from the right (newest first) to find the match.
    for i in range(len(stack) - 1, -1, -1):
        if stack[i][0] == parent_call_id:
            del stack[i]
            return


def current_inflight(session: Any) -> Tuple[str, str] | None:
    """Return ``(parent_call_id, peer_name)`` for the most-recent in-flight call.

    Returns ``None`` when no peer call is currently in flight on this
    session — a chunk arriving in that window came from server-side
    bookkeeping (e.g. the peer's own startup chatter) that we shouldn't
    attribute to any specific parent tool_call.
    """
    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if not stack:
        return None
    return stack[-1]


def make_logging_callback(
    *,
    upstream_send: Callable[[dict[str, Any]], Awaitable[None]],
    session: Any,
    max_depth: int = DEFAULT_MAX_RELAY_DEPTH,
) -> Callable[[LoggingMessageNotificationParams], Awaitable[None]]:
    """Build a ``logging_callback`` that relays peer assistant-stream chunks.

    Parameters
    ----------
    upstream_send:
        Async callable that ships a payload upward via the parent's MCP
        session (typically wraps ``ctx.session.send_log_message`` from
        the parent's per-request ``MCPContext``). Failures inside it
        must not propagate — the relay logs and moves on.
    session:
        The peer ``MCPAgentClientSession`` this callback is attached to.
        Used to look up the in-flight ``(parent_call_id, peer_name)`` for
        attribution.
    max_depth:
        Hard cap on ``len(via_agent)`` before relaying. Chunks that
        would exceed it are dropped with a debug log.
    """

    async def _callback(params: LoggingMessageNotificationParams) -> None:
        # 1) Filter — only Pallas assistant-stream notifications get relayed;
        #    every other log message routes through the existing trace
        #    handlers and isn't our concern.
        if params.logger != _ASSISTANT_STREAM_LOGGER:
            return
        data = params.data
        if not isinstance(data, dict):
            return
        if data.get("kind") not in _KINDS:
            return
        schema = data.get("schema_version")
        if not isinstance(schema, int) or schema < 2:
            # v1 chunks don't carry the enriched tool_calls / results
            # shape the consumer needs to render nested iterations.
            return

        # 2) Look up the in-flight tool call this chunk belongs to. No
        #    in-flight call → drop (the chunk arrived during peer
        #    bookkeeping outside any tool_call, which we can't attribute).
        inflight = current_inflight(session)
        if inflight is None:
            return
        parent_call_id, peer_name = inflight

        # 3) Enforce the depth cap. ``via_agent`` may already exist if
        #    the peer was itself relaying chunks from a deeper sub-agent.
|
||||||
|
existing_via: list[str] = list(data.get("via_agent") or [])
|
||||||
|
if len(existing_via) >= max_depth:
|
||||||
|
logger.debug(
|
||||||
|
"subagent_relay_depth_capped",
|
||||||
|
extra={
|
||||||
|
"depth": len(existing_via),
|
||||||
|
"max_depth": max_depth,
|
||||||
|
"peer": peer_name,
|
||||||
|
"parent_call_id": parent_call_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4) Decorate the payload and re-emit upstream. Mutate a shallow
|
||||||
|
# copy so we don't change a dict the SDK might still reference.
|
||||||
|
relayed = dict(data)
|
||||||
|
relayed["via_agent"] = existing_via + [peer_name]
|
||||||
|
relayed["parent_call_id"] = parent_call_id
|
||||||
|
relayed["schema_version"] = RELAY_SCHEMA_VERSION
|
||||||
|
|
||||||
|
try:
|
||||||
|
await upstream_send(relayed)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"subagent_relay_upstream_send_failed",
|
||||||
|
exc_info=True,
|
||||||
|
extra={
|
||||||
|
"peer": peer_name,
|
||||||
|
"parent_call_id": parent_call_id,
|
||||||
|
"kind": relayed.get("kind"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return _callback
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RELAY_SCHEMA_VERSION",
|
||||||
|
"DEFAULT_MAX_RELAY_DEPTH",
|
||||||
|
"current_inflight",
|
||||||
|
"make_logging_callback",
|
||||||
|
"pop_inflight",
|
||||||
|
"push_inflight",
|
||||||
|
]
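The attribution and depth-cap steps of the relay can be exercised without an MCP session. The following is a minimal sketch, not the module's actual code: `pop_by_id` and `decorate` are hypothetical helper names, the `kind` value and the agent names are placeholders, and the depth cap of 4 is an assumed value rather than the real `DEFAULT_MAX_RELAY_DEPTH`.

```python
from __future__ import annotations

from collections import deque
from typing import Any


def pop_by_id(stack: deque, call_id: str) -> None:
    """Remove the newest entry whose call id matches (hypothetical helper)."""
    for i in range(len(stack) - 1, -1, -1):
        if stack[i][0] == call_id:
            del stack[i]
            return


def decorate(
    data: dict[str, Any],
    inflight: tuple[str, str] | None,
    max_depth: int,
) -> dict[str, Any] | None:
    """Return a decorated copy of the chunk, or None when it must be dropped."""
    if inflight is None:
        return None  # nothing in flight, so the chunk cannot be attributed
    call_id, peer = inflight
    via = list(data.get("via_agent") or [])
    if len(via) >= max_depth:
        return None  # depth cap reached
    relayed = dict(data)  # shallow copy; the caller's dict stays untouched
    relayed["via_agent"] = via + [peer]
    relayed["parent_call_id"] = call_id
    return relayed


# Two parallel peer calls are in flight; the older one completes first.
stack = deque([("call-1", "argos"), ("call-2", "kairos")])
pop_by_id(stack, "call-1")

chunk = {"kind": "placeholder_kind", "via_agent": ["hermes"]}
relayed = decorate(chunk, stack[-1] if stack else None, max_depth=4)
capped = decorate({"via_agent": ["a", "b", "c", "d"]}, stack[-1], max_depth=4)
```

Matching by id keeps `call-2` on the stack even though `call-1` finished out of order, and the copy-then-decorate step leaves the original `chunk` unchanged.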
@@ -131,10 +131,22 @@ agent_health_status = Gauge(
     registry=REGISTRY,
 )

+agent_loop_aborted_total = Counter(
+    "pallas_agent_loop_aborted_total",
+    "Agentic loops force-stopped by a runtime guard",
+    labelnames=["agent", "reason"],  # reason: repeat
+    registry=REGISTRY,
+)
+

 # ── Helpers ──────────────────────────────────────────────────────────────────


+def record_loop_abort(agent: str, reason: str) -> None:
+    """Record one agentic loop aborted by a runtime guard."""
+    agent_loop_aborted_total.labels(agent=agent, reason=reason).inc()
+
+
 def set_agent_info(agents: dict[str, dict]) -> None:
     """Record the deployment's configured agents (called once at startup)."""
     for name, agent in agents.items():
@@ -28,6 +28,8 @@ from fast_agent.core.logging.logger import get_logger
 from fast_agent.mcp.server import AgentMCPServer
 from fast_agent.types import PromptMessageExtended, RequestParams

+from pallas.assistant_stream import install_for_request as _install_assistant_stream
+from pallas.loop_guard import install_for_request as _install_loop_guard
 from pallas.progress import EnrichedMCPToolProgressManager
 from pallas import metrics as _pallas_metrics
 from fastmcp import Context as MCPContext
@@ -220,6 +222,33 @@ class MultimodalAgentMCPServer(AgentMCPServer):
             agent_context = getattr(agent, "context", None)
             metrics_start = time.perf_counter()
             metrics_outcome = "ok"
+
+            # Install per-request after_llm_call hook that ships every
+            # intermediate assistant turn over MCP as a notifications/message.
+            # Without this, only the final ``agent.send()`` return value
+            # crosses the MCP boundary — substantive assistant text emitted
+            # in earlier loop iterations stays trapped inside fast-agent's
+            # ``message_history`` and the user sees a spinner that ends with
+            # a thin wrap-up sentence.
+            restore_stream = _install_assistant_stream(
+                agent,
+                ctx=ctx,
+                agent_name=agent_name,
+                conversation_id=conversation_id,
+            )
+            # Compose the loop guard on top: it halts the agentic loop the
+            # moment a tool call repeats with an identical result, before
+            # the turn runs to the iteration cap or client timeout.
+            restore_guard = _install_loop_guard(
+                agent,
+                agent_name=agent_name,
+                conversation_id=conversation_id,
+                threshold=self._request_limits.get("loop_repeat_threshold", 3),
+            )
+
+            def restore_hooks() -> None:
+                restore_guard()
+                restore_stream()
             try:
                 # Seed the freshly-created instance's message_history from the
                 # caller-supplied history so the agent sees the full
@@ -313,6 +342,14 @@ class MultimodalAgentMCPServer(AgentMCPServer):
                 _pallas_metrics.send_message_total.labels(
                     agent=agent_name, outcome=metrics_outcome
                 ).inc()
+                # Restore the agent's prior tool_runner_hooks before the
+                # instance is released — defensive against any future
+                # shared-instance mode where leaking per-request hooks
+                # across requests would mis-attribute notifications.
+                try:
+                    restore_hooks()
+                except Exception:
+                    pass
                 await self._release_instance(ctx, instance)
@@ -65,6 +65,7 @@ def _build_agents_table(config: dict) -> dict[str, dict]:
             "max_iterations": agent.get("max_iterations"),
             "streaming_timeout": agent.get("streaming_timeout"),
             "turn_timeout": agent.get("turn_timeout"),
+            "loop_repeat_threshold": agent.get("loop_repeat_threshold"),
         }
         for name, agent in config["agents"].items()
     }
@@ -264,7 +265,12 @@ async def _start_agent(name: str, agents: dict[str, dict]) -> None:
     # and the LLM sees exactly what Daedalus asks it to see.
     request_limits = {
         k: entry[k]
-        for k in ("max_iterations", "streaming_timeout", "turn_timeout")
+        for k in (
+            "max_iterations",
+            "streaming_timeout",
+            "turn_timeout",
+            "loop_repeat_threshold",
+        )
         if entry.get(k) is not None
     }
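The `is not None` filter matters for the new key: a configured `loop_repeat_threshold` of `0` must reach the server, and a truthiness check would drop it. A minimal sketch of that behaviour follows; `LIMIT_KEYS` and `build_request_limits` are hypothetical names, and the `entry` values are invented for illustration.

```python
LIMIT_KEYS = (
    "max_iterations",
    "streaming_timeout",
    "turn_timeout",
    "loop_repeat_threshold",
)


def build_request_limits(entry: dict) -> dict:
    """Keep only the limits that were actually configured (not None)."""
    return {k: entry[k] for k in LIMIT_KEYS if entry.get(k) is not None}


# Illustrative agent entry: one limit unset, one explicitly set to 0.
entry = {
    "port": 9001,
    "max_iterations": 20,
    "streaming_timeout": None,
    "loop_repeat_threshold": 0,
}
limits = build_request_limits(entry)

# The server-side lookup falls back to 3 only when the key is absent.
threshold = limits.get("loop_repeat_threshold", 3)
default_threshold = build_request_limits({}).get("loop_repeat_threshold", 3)
```

Here `limits` keeps the explicit `0`, while an agent entry that omits the key falls back to the default of `3` used in the server hunk.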
@@ -294,7 +300,7 @@ async def _wait_for_agent(
     import httpx

     port = agents[name]["port"]
-    url = f"http://127.0.0.1:{port}/mcp"
+    url = f"http://127.0.0.1:{port}/ready"
     deadline = asyncio.get_event_loop().time() + timeout
     while asyncio.get_event_loop().time() < deadline:
         try:
@@ -1,10 +1,10 @@
 [project]
 name = "pallas-mcp"
-version = "0.3.0"
+version = "0.5.0"
 description = "FastAgent MCP Bridge — generic runtime for serving FastAgent agents over StreamableHTTP"
 requires-python = ">=3.13.5"
 dependencies = [
-    "fast-agent-mcp>=0.6.10",
+    "fast-agent-mcp>=0.7.15",
     "httpx",
     "prometheus-client",
     "pyyaml",
402
tests/test_assistant_stream.py
Normal file
@@ -0,0 +1,402 @@
"""Tests for ``pallas.assistant_stream``.

Drives the ``after_llm_call`` / ``before_tool_call`` / ``after_tool_call``
hooks with handcrafted ``PromptMessageExtended`` objects and asserts the
resulting MCP ``send_log_message`` payload shape. No fast-agent runtime is
involved — the hooks are pure async functions and the MCP context is faked.

Tests use ``asyncio.run`` directly to match the convention in
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
pytest-asyncio dependency).
"""
from __future__ import annotations

import asyncio
from typing import Any

from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import (
    CallToolRequest,
    CallToolRequestParams,
    CallToolResult,
    ImageContent,
    TextContent,
)

from pallas.assistant_stream import (
    KIND,
    KIND_RESULTS,
    LOGGER_NAME,
    SCHEMA_VERSION,
    AssistantChunkEmitter,
    install_for_request,
)


def _run(coro):
    return asyncio.run(coro)


# ── Fakes ────────────────────────────────────────────────────────────────────


class _FakeSession:
    """Records every call to ``send_log_message`` for later assertion."""

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []
        self.fail_with: Exception | None = None

    async def send_log_message(
        self,
        *,
        level: str,
        data: Any,
        logger: str | None = None,
        related_request_id: Any = None,
    ) -> None:
        if self.fail_with is not None:
            raise self.fail_with
        self.calls.append(
            {
                "level": level,
                "data": data,
                "logger": logger,
                "related_request_id": related_request_id,
            }
        )


class _FakeContext:
    def __init__(self, request_id: str = "req-1") -> None:
        self.session = _FakeSession()
        self.request_id = request_id


class _FakeAgent:
    """Minimal stand-in for a fast-agent agent — only carries hooks."""

    def __init__(self) -> None:
        self.tool_runner_hooks: ToolRunnerHooks | None = None


def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
    return CallToolRequest(
        method="tools/call",
        params=CallToolRequestParams(name=name, arguments=arguments or {}),
    )


def _tool_result(
    text: str | None = "ok", *, is_error: bool = False
) -> CallToolResult:
    content: list = []
    if text is not None:
        content.append(TextContent(type="text", text=text))
    return CallToolResult(content=content, isError=is_error)


# ── Tests ────────────────────────────────────────────────────────────────────


def test_emit_text_only_iteration() -> None:
    """A pure-text assistant turn produces one log message with one text block."""
    ctx = _FakeContext(request_id="req-text")
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Fair. You can't size value...")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit_assistant_chunk(msg))

    assert len(ctx.session.calls) == 1
    call = ctx.session.calls[0]
    assert call["level"] == "info"
    assert call["logger"] == LOGGER_NAME
    assert call["related_request_id"] == "req-text"

    data = call["data"]
    assert data["kind"] == KIND
    assert data["schema_version"] == SCHEMA_VERSION
    assert data["agent"] == "alan"
    assert data["conversation_id"] == "conv-1"
    assert data["iteration"] == 1
    assert data["stop_reason"] == "endTurn"
    assert data["content"] == [{"type": "text", "text": "Fair. You can't size value..."}]
    assert data["tool_calls"] == []


def test_emit_text_with_tool_call_iteration_carries_args_preview() -> None:
    """An iteration that requests a tool ships name, server prefix, and args preview."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Logging this…")],
        tool_calls={
            "toolu_1": _tool_call(
                "argos-search_web",
                {"query": "ducati v2 vs v4 reliability"},
            ),
        },
        stop_reason=LlmStopReason.TOOL_USE,
    )

    _run(emitter._emit_assistant_chunk(msg))

    assert len(ctx.session.calls) == 1
    data = ctx.session.calls[0]["data"]
    assert data["stop_reason"] == "toolUse"
    assert data["content"] == [{"type": "text", "text": "Logging this…"}]
    assert data["tool_calls"] == [
        {
            "id": "toolu_1",
            "name": "argos-search_web",
            "server": "argos",
            "arguments_preview": "ducati v2 vs v4 reliability",
        }
    ]


def test_emit_skips_completely_empty_iteration() -> None:
    """A turn with no text blocks and no tool calls emits nothing."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(role="assistant", content=[])

    _run(emitter._emit_assistant_chunk(msg))

    assert ctx.session.calls == []
    # Iteration counter still bumps so subsequent chunks aren't mis-numbered.
    assert emitter._iteration == 1


def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
    """ImageContent blocks are serialized with ``mime_type`` (snake-case)."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None)

    msg = PromptMessageExtended(
        role="assistant",
        content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit_assistant_chunk(msg))

    data = ctx.session.calls[0]["data"]
    assert data["content"] == [
        {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
    ]
    assert data["conversation_id"] is None


def test_emit_iterations_are_numbered_in_order() -> None:
    """A multi-iteration loop produces sequentially numbered chunks."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    async def drive() -> None:
        await emitter._emit_assistant_chunk(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="first")],
                stop_reason=LlmStopReason.TOOL_USE,
            )
        )
        await emitter._emit_assistant_chunk(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="second")],
                stop_reason=LlmStopReason.TOOL_USE,
            )
        )
        await emitter._emit_assistant_chunk(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="third — done")],
                stop_reason=LlmStopReason.END_TURN,
            )
        )

    _run(drive())

    assert [c["data"]["iteration"] for c in ctx.session.calls] == [1, 2, 3]
    assert [c["data"]["stop_reason"] for c in ctx.session.calls] == [
        "toolUse",
        "toolUse",
        "endTurn",
    ]


def test_emit_swallows_session_failure() -> None:
    """If ``send_log_message`` raises, the hook does not propagate."""
    ctx = _FakeContext()
    ctx.session.fail_with = RuntimeError("transport closed")
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit_assistant_chunk(msg))
    assert ctx.session.calls == []


def test_emit_tool_results_pairs_call_id_with_iteration() -> None:
    """``after_tool_call`` ships a results payload keyed to the iteration that called."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    # Iteration 1: text + tool request.
    iter1 = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Searching…")],
        tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
        stop_reason=LlmStopReason.TOOL_USE,
    )

    async def drive() -> None:
        # Simulate the runner: after_llm_call → before_tool_call → after_tool_call.
        await emitter._emit_assistant_chunk(iter1)
        await emitter.as_before_tool_call_hook()(None, iter1)
        # The "user" message produced by the tool runner carries tool_results.
        tool_msg = PromptMessageExtended(
            role="user",
            content=[],
            tool_results={"toolu_1": _tool_result("12 results found")},
        )
        await emitter.as_after_tool_call_hook()(None, tool_msg)

    _run(drive())

    assert len(ctx.session.calls) == 2
    chunk = ctx.session.calls[0]["data"]
    results = ctx.session.calls[1]["data"]
    assert chunk["kind"] == KIND
    assert chunk["iteration"] == 1
    assert results["kind"] == KIND_RESULTS
    assert results["iteration"] == 1
    assert results["agent"] == "alan"
    assert results["conversation_id"] == "conv-1"
    assert len(results["results"]) == 1
    entry = results["results"][0]
    assert entry["id"] == "toolu_1"
    assert entry["ok"] is True
    assert entry["result_preview"] == "12 results found"
    assert isinstance(entry["duration_ms"], int)
    assert entry["duration_ms"] >= 0


def test_emit_tool_results_marks_error() -> None:
    """A failing tool call surfaces ``ok: False`` in the results entry."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")

    iter1 = PromptMessageExtended(
        role="assistant",
        content=[],
        tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
        stop_reason=LlmStopReason.TOOL_USE,
    )

    async def drive() -> None:
        await emitter._emit_assistant_chunk(iter1)
        await emitter.as_before_tool_call_hook()(None, iter1)
        tool_msg = PromptMessageExtended(
            role="user",
            content=[],
            tool_results={
                "toolu_1": _tool_result("connection refused", is_error=True)
            },
        )
        await emitter.as_after_tool_call_hook()(None, tool_msg)

    _run(drive())

    # iter1 is a pure-tool turn (no text + has tool_calls) → still emits an
    # assistant_chunk because tool_calls is non-empty.
    assert len(ctx.session.calls) == 2
    results = ctx.session.calls[1]["data"]
    assert results["results"][0]["ok"] is False
    assert results["results"][0]["result_preview"] == "connection refused"


def test_emit_tool_results_empty_when_no_results() -> None:
    """An ``after_tool_call`` carrying no tool_results emits nothing."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")

    msg = PromptMessageExtended(role="user", content=[], tool_results=None)
    _run(emitter.as_after_tool_call_hook()(None, msg))
    assert ctx.session.calls == []


def test_install_for_request_merges_with_existing_after_llm_call() -> None:
    """A pre-existing ``after_llm_call`` hook is composed, not replaced."""
    ctx = _FakeContext()
    agent = _FakeAgent()

    seen: list[str] = []

    async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
        seen.append("base")

    agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)

    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id="conv-1"
    )

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(agent.tool_runner_hooks.after_llm_call(None, msg))

    assert seen == ["base"]
    assert len(ctx.session.calls) == 1
    assert ctx.session.calls[0]["data"]["content"] == [
        {"type": "text", "text": "hi"}
    ]

    # All four hook slots used by the emitter are now bound; the restore
    # call puts the original hooks back exactly.
    assert agent.tool_runner_hooks.before_tool_call is not None
    assert agent.tool_runner_hooks.after_tool_call is not None

    restore()
    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is base_after
    assert agent.tool_runner_hooks.before_tool_call is None
    assert agent.tool_runner_hooks.after_tool_call is None


def test_install_for_request_with_no_existing_hooks() -> None:
    """When the agent has no prior hooks, ours is installed cleanly."""
    ctx = _FakeContext()
    agent = _FakeAgent()
    assert agent.tool_runner_hooks is None

    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id=None
    )

    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is not None
    assert agent.tool_runner_hooks.before_tool_call is not None
    assert agent.tool_runner_hooks.after_tool_call is not None

    restore()
    assert agent.tool_runner_hooks is None
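The install/restore contract these tests check (compose with an existing hook, then put the original object back) can be sketched with the standard library alone. This is an illustration under assumptions, not the module's implementation: `Hooks` is a hypothetical stand-in for fast-agent's `ToolRunnerHooks`, `install` is a hypothetical name, and running the pre-existing hook before the new one is an assumed ordering that the tests above do not pin down.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional

Hook = Callable[[Any, Any], Awaitable[None]]


@dataclass
class Hooks:
    """Hypothetical stand-in for a hooks container with one slot."""

    after_llm_call: Optional[Hook] = None


class Agent:
    def __init__(self) -> None:
        self.tool_runner_hooks: Optional[Hooks] = None


def install(agent: Agent, extra: Hook) -> Callable[[], None]:
    """Compose `extra` with any existing hook and return a restore callable."""
    previous = agent.tool_runner_hooks
    base = previous.after_llm_call if previous is not None else None

    async def composed(runner: Any, message: Any) -> None:
        if base is not None:
            await base(runner, message)  # assumed order: existing hook first
        await extra(runner, message)

    agent.tool_runner_hooks = Hooks(after_llm_call=composed)

    def restore() -> None:
        # Put back the exact object that was there before, even if it was None.
        agent.tool_runner_hooks = previous

    return restore


seen: list[str] = []


async def base_after(_runner: Any, _message: Any) -> None:
    seen.append("base")


async def stream_after(_runner: Any, _message: Any) -> None:
    seen.append("stream")


agent = Agent()
original = Hooks(after_llm_call=base_after)
agent.tool_runner_hooks = original

restore = install(agent, stream_after)
asyncio.run(agent.tool_runner_hooks.after_llm_call(None, "hi"))
restore()
```

Because `restore` captures the previous container rather than rebuilding it, stacking a second installer on top and undoing both in reverse order returns the agent to its starting state.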
55
tests/test_log.py
Normal file
@@ -0,0 +1,55 @@
"""Tests for ``pallas.log._JSONFormatter``.

The formatter must serialize caller-supplied ``extra={...}`` fields (the
loop guard and other diagnostics rely on this) while never emitting the
internal ``LogRecord`` bookkeeping attributes.
"""
from __future__ import annotations

import json
import logging

from pallas.log import _JSONFormatter


def _format(msg: str, **extra) -> dict:
    record = logging.LogRecord(
        name="pallas.test",
        level=logging.WARNING,
        pathname=__file__,
        lineno=1,
        msg=msg,
        args=(),
        exc_info=None,
    )
    for key, value in extra.items():
        setattr(record, key, value)
    return json.loads(_JSONFormatter().format(record))


def test_extra_fields_are_serialized():
    out = _format(
        "agentic loop halted",
        event="loop_halt",
        tool="kairos-update_task",
        repeat_count=3,
        result_preview="COMPLETED but 0%",
    )
    assert out["message"] == "agentic loop halted"
    assert out["event"] == "loop_halt"
    assert out["tool"] == "kairos-update_task"
    assert out["repeat_count"] == 3
    assert out["result_preview"] == "COMPLETED but 0%"


def test_standard_attributes_are_not_leaked():
    out = _format("plain message")
    for noise in ("msg", "args", "levelno", "pathname", "lineno", "funcName"):
        assert noise not in out
    assert out["level"] == "WARNING"
    assert out["logger"] == "pallas.test"


def test_non_serializable_extra_does_not_crash():
    out = _format("with object", obj=object())
    assert "obj" in out  # coerced via default=str, not dropped or raised
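The diff shown here does not include `pallas/log.py`, so the following is only a sketch of one way a formatter could satisfy these three tests. `JSONFormatter` and `_STANDARD_ATTRS` are hypothetical names; the approach assumed is to treat every attribute that a freshly built `LogRecord` already carries as bookkeeping and everything else as caller-supplied `extra`.

```python
import json
import logging

# Attributes present on any LogRecord; whatever else appears on a record
# was attached by the caller via extra={...}.
_STANDARD_ATTRS = set(vars(logging.LogRecord("", 0, "", 0, "", (), None))) | {
    "message",
    "asctime",
}


class JSONFormatter(logging.Formatter):
    """Hypothetical formatter: fixed fields plus caller-supplied extras."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        for key, value in vars(record).items():
            if key not in _STANDARD_ATTRS:
                payload[key] = value
        # default=str coerces values json cannot encode instead of raising.
        return json.dumps(payload, default=str)


record = logging.LogRecord(
    "pallas.test", logging.WARNING, "example.py", 1, "agentic loop halted", (), None
)
record.event = "loop_halt"
record.obj = object()
out = json.loads(JSONFormatter().format(record))
```

Deriving the exclusion set from a blank record, rather than hard-coding names, keeps the sketch in step with attributes that newer Python versions add to `LogRecord`.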
|
||||||
190
tests/test_loop_guard.py
Normal file
190
tests/test_loop_guard.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""Tests for ``pallas.loop_guard``.
|
||||||
|
|
||||||
|
Drives the ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted
|
||||||
|
``PromptMessageExtended`` objects against a fake ToolRunner and asserts the
|
||||||
|
halt behaviour: the runner's ``max_iterations`` is collapsed to the current
|
||||||
|
iteration (so fast-agent terminates on its next check), the dangling turn is
|
||||||
|
annotated with an explanation, and the abort metric is incremented.
|
||||||
|
|
||||||
|
No fast-agent runtime is involved — the hooks are pure async functions.
|
||||||
|
Uses ``asyncio.run`` directly to match the convention in the other test
|
||||||
|
modules (pallas has no pytest-asyncio dependency).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fast_agent.types import PromptMessageExtended
|
||||||
|
from fast_agent.types.llm_stop_reason import LlmStopReason
|
||||||
|
from mcp.types import (
|
||||||
|
CallToolRequest,
|
||||||
|
CallToolRequestParams,
|
||||||
|
CallToolResult,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pallas import metrics as _pallas_metrics
|
||||||
|
from pallas.loop_guard import LoopGuard, install_for_request
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
|
||||||
|
return CallToolRequest(
|
||||||
|
method="tools/call",
|
||||||
|
params=CallToolRequestParams(name=name, arguments=arguments or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_result(text: str = "ok", *, is_error: bool = False) -> CallToolResult:
|
||||||
|
return CallToolResult(
|
||||||
|
content=[TextContent(type="text", text=text)], isError=is_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _request(name: str, arguments: dict, call_id: str = "toolu_1"):
|
||||||
|
return PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[],
|
||||||
|
tool_calls={call_id: _tool_call(name, arguments)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _result(text: str, call_id: str = "toolu_1"):
|
||||||
|
return PromptMessageExtended(
|
||||||
|
role="user", content=[], tool_results={call_id: _tool_result(text)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRunner:
|
||||||
|
def __init__(self, *, iteration: int = 5, max_iterations: int = 30) -> None:
|
||||||
|
self.request_params = SimpleNamespace(max_iterations=max_iterations)
|
||||||
|
self.iteration = iteration
|
||||||
|
self.last_message = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[],
|
||||||
|
stop_reason=LlmStopReason.TOOL_USE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _drive(guard: LoopGuard, runner: _FakeRunner, name, args, result) -> None:
    """Run one tool round through the guard's hooks."""
    before = guard.as_before_tool_call_hook()
    after = guard.as_after_tool_call_hook()
    await before(runner, _request(name, args))
    await after(runner, _result(result))


def _abort_count(agent: str) -> float:
    return (
        _pallas_metrics.agent_loop_aborted_total.labels(agent=agent, reason="repeat")
        ._value.get()
    )


def _last_text(runner: _FakeRunner) -> str:
    return "".join(
        b.text for b in (runner.last_message.content or []) if hasattr(b, "text")
    )


def test_halts_on_third_identical_round():
    guard = LoopGuard(agent_name="shawn", conversation_id="c1", threshold=3)
    runner = _FakeRunner(iteration=12, max_iterations=30)
    before = _abort_count("shawn")

    async def go():
        for _ in range(2):
            await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")
        # not yet halted after rounds 1 and 2
        assert runner.request_params.max_iterations == 30
        assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE
        # third identical round trips the guard
        await _drive(guard, runner, "kairos-update_task", {"task_id": 494}, "same")

    _run(go())

    # max_iterations collapsed to the current iteration -> fast-agent stops
    # on its next `_iteration > max_iterations` check, no further LLM call.
    assert runner.request_params.max_iterations == 12
    assert runner.last_message.stop_reason == LlmStopReason.END_TURN
    assert "Halted" in _last_text(runner)
    assert "kairos-update_task" in _last_text(runner)
    assert _abort_count("shawn") == before + 1


def test_no_halt_when_result_changes():
    guard = LoopGuard(agent_name="a1", conversation_id=None, threshold=3)
    runner = _FakeRunner()

    async def go():
        for i in range(6):
            await _drive(
                guard, runner, "kairos-update_task", {"task_id": 494}, f"r{i}"
            )

    _run(go())
    assert runner.request_params.max_iterations == 30
    assert runner.last_message.stop_reason == LlmStopReason.TOOL_USE


def test_no_halt_when_args_change():
    guard = LoopGuard(agent_name="a2", conversation_id=None, threshold=3)
    runner = _FakeRunner()

    async def go():
        for i in range(6):
            await _drive(guard, runner, "kairos-update_task", {"task_id": i}, "same")

    _run(go())
    assert runner.request_params.max_iterations == 30


def test_threshold_respected():
    guard = LoopGuard(agent_name="a3", conversation_id=None, threshold=5)
    runner = _FakeRunner()

    async def go():
        for _ in range(4):
            await _drive(guard, runner, "t", {"x": 1}, "same")

    _run(go())
    # 4 identical rounds, threshold 5 -> still running
    assert runner.request_params.max_iterations == 30


def test_halt_fires_once():
    guard = LoopGuard(agent_name="a4", conversation_id=None, threshold=3)
    runner = _FakeRunner()
    before = _abort_count("a4")

    async def go():
        for _ in range(6):
            await _drive(guard, runner, "t", {"x": 1}, "same")

    _run(go())
    assert _abort_count("a4") == before + 1


def test_install_disabled_with_nonpositive_threshold():
    agent = SimpleNamespace(tool_runner_hooks="sentinel")
    restore = install_for_request(
        agent, agent_name="a", conversation_id=None, threshold=0
    )
    assert agent.tool_runner_hooks == "sentinel"  # untouched
    restore()  # no-op, must not raise


def test_install_merges_and_restores():
    agent = SimpleNamespace(tool_runner_hooks=None)
    restore = install_for_request(
        agent, agent_name="a", conversation_id=None, threshold=3
    )
    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_tool_call is not None
    restore()
    assert agent.tool_runner_hooks is None