feat: relay sub-agent activity chunks to parent MCP context

This commit is contained in:
2026-05-28 15:36:14 -04:00
parent d387650bf2
commit 0b0a8f37a4
6 changed files with 854 additions and 95 deletions

View File

@@ -1,14 +1,18 @@
"""fast-agent runtime patches — traceback capture on three opaque catch-sites. """fast-agent runtime patches — traceback capture + sub-agent activity relay.
fast-agent's transport layer catches every downstream-transport exception at This module wires two unrelated extensions onto fast-agent at import time:
several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``), and
re-raises. By the time the exception surfaces to the MCP tool result, the A. **Traceback capture** on three opaque catch-sites.
traceback has been flattened to a bare string — the canonical symptom being
``"object NoneType can't be used in 'await' expression"`` with no stack fast-agent's transport layer catches every downstream-transport exception
attached. This module wraps three of those catch-sites so Pallas emits at several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``),
``logger.exception(...)`` with the full frame before fast-agent's swallowing and re-raises. By the time the exception surfaces to the MCP tool
``except`` runs. Behaviour is otherwise unchanged: every wrapper re-raises result, the traceback has been flattened to a bare string — the
the exception it caught. canonical symptom being ``"object NoneType can't be used in 'await'
expression"`` with no stack attached. We wrap three of those
catch-sites so Pallas emits ``logger.exception(...)`` with the full
frame before fast-agent's swallowing ``except`` runs. Behaviour is
otherwise unchanged: every wrapper re-raises the exception it caught.
The three wrapped entry points are: The three wrapped entry points are:
@@ -20,9 +24,33 @@ The three wrapped entry points are:
the client call (server lookup, session factory, tracer span, the client call (server lookup, session factory, tracer span,
``try_execute`` harness). ``try_execute`` harness).
Any one wrapper being triggered while the other two stay silent pinpoints Any one wrapper being triggered while the other two stay silent
which frame is swallowing the exception, which is how we debug opaque pinpoints which frame is swallowing the exception, which is how we
transport failures. debug opaque transport failures.
B. **Sub-agent activity relay** — see ``pallas.mcp_subagent_relay``.
When Alan calls a peer agent (research/jeffrey/ann/...) the peer's
own ``assistant_chunk`` notifications stop at fast-agent's
``MCPAgentClientSession``. We wrap two surfaces to fix that:
1. ``MCPAggregator._create_session_factory`` — the per-server factory
that builds peer ``MCPAgentClientSession`` instances. We
post-process the returned session by attaching our
``logging_callback`` (replacing the default no-op) so peer
``notifications/message`` flow through our relay filter and get
re-emitted on the parent's MCP context.
2. ``MCPAggregator.call_tool`` — the entry point fast-agent's tool
runner reaches when invoking a peer. We push/pop the
``(tool_use_id, server_name)`` in-flight stack on the peer session
so the relay callback (which runs on the SDK reader task and
therefore can't see contextvars) can attribute incoming chunks to
the right parent tool_call.
The relay's upstream-send callable is set per-request from
``multimodal_server.py``'s ``register_agent_tools`` via
``set_upstream_send_for_request`` — Pallas knows the parent ``ctx``
only inside a tool call, so the upstream binding is request-scoped.
Historical note: this file used to also carry the bearer-forwarding patch Historical note: this file used to also carry the bearer-forwarding patch
that propagated inbound ``Authorization`` headers to opted-in downstream that propagated inbound ``Authorization`` headers to opted-in downstream
@@ -35,16 +63,105 @@ server's ``headers.Authorization`` is what fast-agent sends, full stop.
from __future__ import annotations from __future__ import annotations
import contextvars
import logging import logging
import time import time
from typing import Any, Awaitable, Callable
from fast_agent.mcp import mcp_aggregator as _magg from fast_agent.mcp import mcp_aggregator as _magg
from fast_agent.mcp import mcp_agent_client_session as _macs from fast_agent.mcp import mcp_agent_client_session as _macs
from pallas import metrics as _pallas_metrics from pallas import metrics as _pallas_metrics
from pallas import mcp_subagent_relay as _relay
logger = logging.getLogger("pallas.forward") logger = logging.getLogger("pallas.forward")
_trace_logger = logging.getLogger("pallas.forward.trace") _trace_logger = logging.getLogger("pallas.forward.trace")
_relay_logger = logging.getLogger("pallas.subagent_relay")
# Per-request upstream-send callable. Set by ``multimodal_server.py`` for
# the duration of one parent ``send_message`` MCP tool call so the relay's
# ``logging_callback`` can re-emit chunks on the parent's MCP session.
# A contextvar (rather than a plain module global) is required because
# multiple concurrent send_message requests share the same Pallas process
# and would otherwise cross-talk. Async tasks inherit a *copy* of the
# context that was current when they were spawned, so tasks fast-agent
# starts inside a request see this binding — but a peer session (and the
# SDK reader task started inside the session __aenter__) lives across many
# requests and only sees whatever was bound when it was opened. See
# ``_install_logging_callback_on_session`` for how the relay callback
# looks the contextvar up each time a notification arrives.
# Request-scoped upstream-send callable; ``None`` (the default) means no
# binding is active in the current context.
_upstream_send: contextvars.ContextVar[
Callable[[dict[str, Any]], Awaitable[None]] | None
] = contextvars.ContextVar("_pallas_relay_upstream_send", default=None)
# Per-request max relay depth (overrides DEFAULT_MAX_RELAY_DEPTH).
# ``None`` (the default) means no override; ``_resolve_max_relay_depth``
# then falls back to ``_relay.DEFAULT_MAX_RELAY_DEPTH``.
_max_relay_depth: contextvars.ContextVar[int | None] = contextvars.ContextVar(
"_pallas_relay_max_depth", default=None
)
def set_upstream_send_for_request(
    send: Callable[[dict[str, Any]], Awaitable[None]] | None,
    *,
    max_depth: int | None = None,
):
    """Bind the relay's upstream-send (and depth override) in the current context.

    Args:
        send: Async callable the relay should use to re-emit payloads, or
            ``None`` to bind "no upstream".
        max_depth: Optional per-request relay depth override; ``None``
            leaves the default depth in effect.

    Returns:
        A ``(send_token, depth_token)`` pair. Pass each token to the
        matching ``ContextVar.reset`` when the request scope ends, or use
        the ``relay_upstream_scope`` context manager, which does the
        set-on-enter / reset-on-exit dance for you.
    """
    # The tuple is built left to right, so the send binding is installed
    # before the depth binding.
    return _upstream_send.set(send), _max_relay_depth.set(max_depth)
class relay_upstream_scope:
    """Context manager binding the relay's upstream-send for one request.

    Use inside the per-request ``send_message`` body so the contextvars
    are set on entry and restored on exit even if the tool call raises::

        with relay_upstream_scope(send=ctx_send, max_depth=3):
            await agent.send(payload, request_params=request_params)

    Equivalent to a ``try/finally`` around ``set_upstream_send_for_request``.
    """

    def __init__(
        self,
        *,
        send: Callable[[dict[str, Any]], Awaitable[None]] | None,
        max_depth: int | None = None,
    ) -> None:
        """Remember the bindings to install; nothing is bound until ``__enter__``.

        Args:
            send: Async callable used to re-emit payloads upstream, or ``None``.
            max_depth: Optional relay depth override for this request.
        """
        self._send = send
        self._max_depth = max_depth
        # ``(send_token, depth_token)`` while the scope is active, else None.
        self._tokens: tuple[Any, Any] | None = None

    def __enter__(self) -> "relay_upstream_scope":
        """Install both bindings and keep the reset tokens for ``__exit__``."""
        self._tokens = set_upstream_send_for_request(
            self._send, max_depth=self._max_depth
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Restore both contextvars; never suppresses the in-flight exception."""
        # Detach the tokens first so a second ``__exit__`` (or one that
        # fails part-way) cannot try to reuse an already-consumed token.
        tokens, self._tokens = self._tokens, None
        if tokens is None:
            return
        send_token, depth_token = tokens
        try:
            _upstream_send.reset(send_token)
        finally:
            # Restore the depth override even if resetting the send
            # binding raised, so one failure doesn't leak the other.
            _max_relay_depth.reset(depth_token)
def _resolve_upstream_send():
    """Return the upstream-send callable bound in the current context.

    Yields ``None`` (the contextvar's default) when no request scope has
    bound one.
    """
    return _upstream_send.get()
def _resolve_max_relay_depth() -> int:
    """Return the relay depth cap that applies in the current context.

    A positive integer bound via ``_max_relay_depth`` wins; anything else
    (unset, ``None``, zero, negative, non-int) falls back to
    ``_relay.DEFAULT_MAX_RELAY_DEPTH``.
    """
    requested = _max_relay_depth.get()
    usable = isinstance(requested, int) and requested > 0
    return requested if usable else _relay.DEFAULT_MAX_RELAY_DEPTH
# ── send_request traceback capture ─────────────────────────────────────────── # ── send_request traceback capture ───────────────────────────────────────────
@@ -160,14 +277,200 @@ def _patch_execute_on_server() -> None:
logger.info("aggregator._execute_on_server traceback-capture patch installed") logger.info("aggregator._execute_on_server traceback-capture patch installed")
# ── Sub-agent activity relay: session factory + call_tool wrappers ───────────
#
# We hook two surfaces:
#
# * ``_create_session_factory`` (returns a closure that builds peer
# ``MCPAgentClientSession`` instances). We post-process each session
# right after construction to install our ``logging_callback``. At
# construction time we do NOT yet know the per-request upstream-send
# callable — that's only set inside ``send_message`` — so the callback
# reads it from the contextvar each time a notification arrives.
# * ``MCPAggregator.call_tool`` (called by fast-agent's tool runner with
# ``tool_use_id`` and the namespaced tool name). We split the server
# prefix off the tool name, look up the live session for that server,
# and push ``(tool_use_id, server_name)`` onto the session's in-flight
# stack so chunks arriving on the SDK reader task can be attributed.
#
# Both wrappers are idempotent (same ``_pallas_relay_patched`` attribute
# guard pattern) so ``install()`` is safe to call multiple times.
_original_create_session_factory = _magg.MCPAggregator._create_session_factory
def _create_session_factory_with_relay(self, server_name: str):
    """Wrap the aggregator's session factory so new sessions get the relay callback.

    Delegates to the original ``_create_session_factory`` to obtain the
    real factory, then returns a closure with the same call signature
    that builds the session and attaches the relay's logging callback
    before handing it back.
    """
    build_session = _original_create_session_factory(self, server_name)

    def factory(read_stream, write_stream, read_timeout, **kwargs):
        peer_session = build_session(
            read_stream, write_stream, read_timeout, **kwargs
        )
        try:
            _install_logging_callback_on_session(peer_session)
        except Exception:
            # Best effort: a relay wiring failure must not prevent the
            # session itself from being returned to fast-agent.
            _relay_logger.warning(
                "subagent_relay_install_callback_failed server=%s",
                server_name,
                exc_info=True,
            )
        return peer_session

    return factory
def _install_logging_callback_on_session(session: Any) -> None:
    """Attach the relay's ``logging_callback`` to a freshly built peer session.

    fast-agent doesn't pass a ``logging_callback`` when constructing
    sessions, so the SDK's ``ClientSession`` falls back to the default
    no-op handler that drops every ``notifications/message``. We swap
    that for the callback built by ``_relay.make_logging_callback``,
    which forwards payloads upstream — but only when an upstream binding
    is active (i.e. inside a ``send_message`` request scope).

    NOTE(review): ``max_depth`` is resolved once, here, at install time.
    The depth override is request-scoped, so a session that outlives the
    request it was created in keeps whatever depth was in effect then —
    confirm that is intended.
    """

    async def _forward_to_parent(payload: dict[str, Any]) -> None:
        # Resolved per notification because the binding is not known when
        # the session is constructed.
        # NOTE(review): this reads the contextvar from whichever task runs
        # the callback; the module docstring says that task cannot see
        # contextvars — confirm the request's binding is visible here.
        parent_send = _resolve_upstream_send()
        if parent_send is None:
            # No active request — drop. This is the steady-state when
            # peer sessions are pinged by fast-agent's connection
            # manager outside any user tool call.
            return
        await parent_send(payload)

    relay_callback = _relay.make_logging_callback(
        upstream_send=_forward_to_parent,
        session=session,
        max_depth=_resolve_max_relay_depth(),
    )
    # Only the private ``_logging_callback`` slot is written; presumably
    # that is where the SDK's ``ClientSession`` keeps its logging callback
    # (TODO confirm against the installed SDK). No public alias is set,
    # so an SDK change that moves the slot would silently disable the
    # relay.
    session._logging_callback = relay_callback
def _patch_create_session_factory() -> None:
    """Swap in the relay-aware session factory on ``MCPAggregator`` (idempotent).

    The ``_pallas_relay_patched`` marker on the currently installed
    attribute tells us whether a previous call already did the swap.
    """
    installed = _magg.MCPAggregator._create_session_factory
    if getattr(installed, "_pallas_relay_patched", False):
        return
    # Mark the wrapper before installing it so the guard above fires on
    # any later call.
    _create_session_factory_with_relay._pallas_relay_patched = True  # type: ignore[attr-defined]
    _magg.MCPAggregator._create_session_factory = _create_session_factory_with_relay  # type: ignore[assignment]
    logger.info("aggregator._create_session_factory subagent-relay patch installed")
_original_aggregator_call_tool = _magg.MCPAggregator.call_tool
async def _aggregator_call_tool_with_relay(self, *args, **kwargs):
    """Bracket a peer tool call with the relay's in-flight bookkeeping.

    The tool name is taken from ``args[0]`` (or ``kwargs['name']``) and
    the LLM's ``tool_use_id`` from ``kwargs``. When the name carries a
    server prefix for which a live session can be found, and a non-empty
    string ``tool_use_id`` is present, ``(tool_use_id, server)`` is pushed
    onto that session's in-flight stack for the duration of the call and
    popped afterwards. If any of that is missing the call simply runs
    unwrapped — attribution metadata is lost, the result is unaffected.

    This wraps ``MCPAggregator.call_tool``, the entry point fast-agent's
    tool runner uses. Note: this is **not** the same as
    ``_session.call_tool`` (the lower-level SDK call) wrapped for
    traceback capture; the names collide in fast-agent's API.

    Returns:
        Whatever the original ``MCPAggregator.call_tool`` returns; its
        exceptions propagate unchanged.
    """
    tool_name = args[0] if args else kwargs.get("name")
    call_id = kwargs.get("tool_use_id")

    peer_session = None
    if isinstance(tool_name, str):
        peer_session = _resolve_peer_session(self, tool_name)

    tracked = False
    can_attribute = (
        peer_session is not None and isinstance(call_id, str) and bool(call_id)
    )
    if can_attribute:
        peer = _split_server_from_tool_name(tool_name)
        if peer:
            try:
                _relay.push_inflight(peer_session, call_id, peer)
            except Exception:
                # Bookkeeping is best-effort; never fail the tool call.
                _relay_logger.debug(
                    "subagent_relay_push_failed",
                    exc_info=True,
                )
            else:
                tracked = True

    try:
        return await _original_aggregator_call_tool(self, *args, **kwargs)
    finally:
        # ``tracked`` is only ever set after a successful push, which in
        # turn implies a live session and a string call id.
        if tracked:
            try:
                _relay.pop_inflight(peer_session, call_id)
            except Exception:
                _relay_logger.debug(
                    "subagent_relay_pop_failed",
                    exc_info=True,
                )
def _split_server_from_tool_name(tool_name: str) -> str | None:
"""Return the server prefix from a ``server-tool`` namespaced name.
fast-agent uses ``-`` as the SEP between server and tool name (see
``fast_agent.mcp.common.SEP``). Bare tool names (no separator)
have no peer prefix to attribute, so the relay is a no-op for them.
"""
if not tool_name or "-" not in tool_name:
return None
prefix, _ = tool_name.split("-", 1)
return prefix or None
def _resolve_peer_session(aggregator, tool_name: str) -> Any | None:
    """Find the live peer session that serves ``tool_name``, if any.

    Probes ``aggregator._persistent_connection_manager.running_servers``
    (a dict keyed by server name) for the tool's server prefix and
    returns that entry's ``session`` attribute. Every step is defensive:
    a missing prefix, manager, dict, entry, or session attribute yields
    ``None`` and the caller skips relay attribution.
    """
    peer = _split_server_from_tool_name(tool_name)
    if not peer:
        return None

    # A missing manager collapses into the same "not a dict" miss below.
    manager = getattr(aggregator, "_persistent_connection_manager", None)
    servers = getattr(manager, "running_servers", None)
    if not isinstance(servers, dict):
        return None

    connection = servers.get(peer)
    if connection is None:
        return None
    # The entry may not have a ``session`` yet, hence the default.
    return getattr(connection, "session", None)
def _patch_aggregator_call_tool() -> None:
    """Swap in the relay-aware ``call_tool`` on ``MCPAggregator`` (idempotent).

    The ``_pallas_relay_patched`` marker on the currently installed
    attribute tells us whether a previous call already did the swap.
    """
    installed = _magg.MCPAggregator.call_tool
    if getattr(installed, "_pallas_relay_patched", False):
        return
    # Mark the wrapper before installing it so the guard above fires on
    # any later call.
    _aggregator_call_tool_with_relay._pallas_relay_patched = True  # type: ignore[attr-defined]
    _magg.MCPAggregator.call_tool = _aggregator_call_tool_with_relay  # type: ignore[assignment]
    logger.info("aggregator.call_tool subagent-relay patch installed")
def install() -> None: def install() -> None:
"""Install all three trace-capture wrappers. """Install all trace-capture and relay wrappers.
Each ``_patch_*`` helper is individually idempotent (guarded on a Each ``_patch_*`` helper is individually idempotent (guarded on a
``_pallas_trace_patched`` attribute), so ``install()`` is safe to call ``_pallas_*_patched`` attribute), so ``install()`` is safe to call
repeatedly — e.g. from ``pallas/__init__.py`` on import + again from repeatedly — e.g. from ``pallas/__init__.py`` on import + again from
a test harness — without stacking wrappers. a test harness — without stacking wrappers.
""" """
_patch_send_request() _patch_send_request()
_patch_session_call_tool() _patch_session_call_tool()
_patch_execute_on_server() _patch_execute_on_server()
_patch_create_session_factory()
_patch_aggregator_call_tool()

View File

@@ -15,22 +15,31 @@ content blocks as opaque JSON in the ``data`` field. The transport is
already plumbed through StreamableHTTP — the SDK delivers it to the client's already plumbed through StreamableHTTP — the SDK delivers it to the client's
``logging_callback`` — so no SDK changes are needed on either side. ``logging_callback`` — so no SDK changes are needed on either side.
The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``, Each chunk also carries an enriched ``tool_calls`` list so the consumer can
which runs once per LLM iteration in the tool loop. Pallas merges this hook show *which* tools the model decided to invoke on that iteration (with arg
with whatever hooks the agent (or fast-agent itself) already has via the previews) and, once results arrive, attach a result preview + duration.
existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py`` The shared progress manager (``progress.py``) computes the same previews
``_install_assistant_stream_hook`` for the merge-and-restore wiring. for the live ``notifications/progress`` ticker; we reuse those helpers.
The hooks install onto fast-agent's ``ToolRunnerHooks`` (``after_llm_call``,
``before_tool_call``, ``after_tool_call``). Pallas merges these with whatever
hooks the agent (or fast-agent itself) already has — see
``multimodal_server.py`` ``_install_assistant_stream_hook`` for the
merge-and-restore wiring.
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from fast_agent.agents.tool_runner import ToolRunnerHooks from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended from fast_agent.types import PromptMessageExtended
from mcp.types import ImageContent, TextContent from mcp.types import ImageContent, TextContent
from pallas.progress import format_args_preview, format_result_preview
if TYPE_CHECKING: if TYPE_CHECKING:
from fastmcp import Context as MCPContext from fastmcp import Context as MCPContext
@@ -45,7 +54,17 @@ KIND = "assistant_chunk"
# Bumped if the on-wire payload shape changes incompatibly. Daedalus reads # Bumped if the on-wire payload shape changes incompatibly. Daedalus reads
# this and ignores chunks whose schema_version it doesn't recognise. # this and ignores chunks whose schema_version it doesn't recognise.
SCHEMA_VERSION = 1 #
# v1: text + bare tool-call ids/names.
# v2: tool_calls enriched with server, arguments_preview, and (after the
# tool runs) result_preview / duration_ms / ok. Late-arriving "tool
# results" notifications use ``kind == "assistant_chunk_results"`` so
# a v1 client that ignores them still gets clean v2 behaviour for the
# primary chunk.
SCHEMA_VERSION = 2
# Discriminator for the late-arriving tool-results companion notification.
KIND_RESULTS = "assistant_chunk_results"
class AssistantChunkEmitter: class AssistantChunkEmitter:
@@ -53,7 +72,9 @@ class AssistantChunkEmitter:
One instance per ``send_message`` MCP call. Tracks an iteration counter One instance per ``send_message`` MCP call. Tracks an iteration counter
so consumers can order chunks within a single turn (and ignore stale so consumers can order chunks within a single turn (and ignore stale
ones from an aborted retry). ones from an aborted retry), plus the in-flight tool-call timings so
``after_tool_call`` can attach result previews + durations to the
iteration the consumer already received.
""" """
def __init__( def __init__(
@@ -67,16 +88,51 @@ class AssistantChunkEmitter:
self._agent_name = agent_name self._agent_name = agent_name
self._conversation_id = conversation_id self._conversation_id = conversation_id
self._iteration = 0 self._iteration = 0
# Maps tool-call id → (iteration, start_time_perf_counter).
# Filled in ``before_tool_call`` and consumed in ``after_tool_call``
# so we can pair results with the iteration that produced the call.
self._pending_tool_calls: dict[str, tuple[int, float]] = {}
def as_after_llm_call_hook(self): def as_after_llm_call_hook(self):
"""Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``.""" """Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``."""
async def _hook(_runner: Any, message: PromptMessageExtended) -> None: async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
await self._emit(message) await self._emit_assistant_chunk(message)
return _hook return _hook
def as_before_tool_call_hook(self):
    """Return an async callable suitable for ``ToolRunnerHooks.before_tool_call``.

    The hook stamps every tool-call id in the message with the current
    iteration number and a shared ``time.perf_counter()`` reading, so
    ``after_tool_call`` can later report a duration alongside the result.
    """

    async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
        started = time.perf_counter()
        # ``self._iteration`` is recorded as the iteration that requested
        # these calls; presumably the LLM-call hook for the same assistant
        # turn has already advanced it (not visible from here).
        self._pending_tool_calls.update(
            {
                call_id: (self._iteration, started)
                for call_id in (message.tool_calls or {})
            }
        )

    return _hook
def as_after_tool_call_hook(self):
    """Return an async callable suitable for ``ToolRunnerHooks.after_tool_call``.

    The returned hook ignores the runner argument and hands the message
    to ``_emit_tool_results``, which ships the follow-up
    ``assistant_chunk_results`` notification (result previews + per-call
    durations) for the iteration that requested the tools.
    """

    async def _after_tool_call(
        _runner: Any, message: PromptMessageExtended
    ) -> None:
        await self._emit_tool_results(message)

    return _after_tool_call
async def _emit_assistant_chunk(self, message: PromptMessageExtended) -> None:
"""Serialize one assistant turn and ship it as a log notification. """Serialize one assistant turn and ship it as a log notification.
Failures here must never break the agent turn — wrap everything and Failures here must never break the agent turn — wrap everything and
@@ -106,6 +162,66 @@ class AssistantChunkEmitter:
if not payload["content"] and not payload["tool_calls"]: if not payload["content"] and not payload["tool_calls"]:
return return
await self._send(payload)
async def _emit_tool_results(self, message: PromptMessageExtended) -> None:
"""Ship a companion notification with result previews + durations."""
try:
results = message.tool_results or {}
if not results:
return
now = time.perf_counter()
entries: list[dict[str, Any]] = []
iteration_for_results: int | None = None
for call_id, result in results.items():
iteration, started = self._pending_tool_calls.pop(
call_id, (self._iteration, now)
)
# All results from one ``after_tool_call`` correspond to the
# same iteration (they're the response to that iteration's
# tool_calls), so keep the first non-None we see.
if iteration_for_results is None:
iteration_for_results = iteration
duration_ms = int(round((now - started) * 1000))
ok = not bool(getattr(result, "isError", False))
content_blocks = list(getattr(result, "content", []) or [])
entries.append(
{
"id": call_id,
"ok": ok,
"duration_ms": duration_ms,
"result_preview": format_result_preview(content_blocks),
}
)
if not entries:
return
payload = {
"kind": KIND_RESULTS,
"schema_version": SCHEMA_VERSION,
"agent": self._agent_name,
"conversation_id": self._conversation_id,
# The iteration the *call* belonged to — the consumer keys
# by this to enrich the iteration card it already drew.
"iteration": iteration_for_results,
"results": entries,
}
except Exception:
logger.warning(
"assistant_stream tool-results payload build failed",
exc_info=True,
extra={
"agent": self._agent_name,
"conversation_id": self._conversation_id,
"iteration": self._iteration,
},
)
return
await self._send(payload)
async def _send(self, payload: dict[str, Any]) -> None:
try: try:
session = self._ctx.session session = self._ctx.session
related_request_id = getattr(self._ctx, "request_id", None) related_request_id = getattr(self._ctx, "request_id", None)
@@ -123,6 +239,7 @@ class AssistantChunkEmitter:
"agent": self._agent_name, "agent": self._agent_name,
"conversation_id": self._conversation_id, "conversation_id": self._conversation_id,
"iteration": self._iteration, "iteration": self._iteration,
"kind": payload.get("kind"),
}, },
) )
@@ -149,10 +266,15 @@ class AssistantChunkEmitter:
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
for call_id, call in (message.tool_calls or {}).items(): for call_id, call in (message.tool_calls or {}).items():
params = getattr(call, "params", None) params = getattr(call, "params", None)
name = getattr(params, "name", None)
arguments = getattr(params, "arguments", None)
server = _split_server(name)
tool_calls.append( tool_calls.append(
{ {
"id": call_id, "id": call_id,
"name": getattr(params, "name", None), "name": name,
"server": server,
"arguments_preview": format_args_preview(arguments),
} }
) )
@@ -176,6 +298,19 @@ class AssistantChunkEmitter:
} }
def _split_server(name: str | None) -> str | None:
"""Extract the server prefix from a fast-agent tool name.
Fast-agent namespaces aggregator tools as ``server-tool`` (e.g.
``argos-search_web``). Single-server agents may use a bare name.
Returns the prefix when present, else ``None``.
"""
if not name or "-" not in name:
return None
prefix, _ = name.split("-", 1)
return prefix or None
def install_for_request( def install_for_request(
agent: Any, agent: Any,
*, *,
@@ -183,7 +318,7 @@ def install_for_request(
agent_name: str, agent_name: str,
conversation_id: str | None, conversation_id: str | None,
): ):
"""Install the assistant-stream hook on a request-scoped agent instance. """Install the assistant-stream hooks on a request-scoped agent instance.
Returns a no-arg callable that restores the agent's previous Returns a no-arg callable that restores the agent's previous
``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks ``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks
@@ -197,10 +332,14 @@ def install_for_request(
ctx, agent_name=agent_name, conversation_id=conversation_id ctx, agent_name=agent_name, conversation_id=conversation_id
) )
extra = ToolRunnerHooks(after_llm_call=emitter.as_after_llm_call_hook()) extra = ToolRunnerHooks(
after_llm_call=emitter.as_after_llm_call_hook(),
before_tool_call=emitter.as_before_tool_call_hook(),
after_tool_call=emitter.as_after_tool_call_hook(),
)
previous = getattr(agent, "tool_runner_hooks", None) previous = getattr(agent, "tool_runner_hooks", None)
merged = _merge_after_llm_call(previous, extra) merged = _merge_hooks(previous, extra)
agent.tool_runner_hooks = merged agent.tool_runner_hooks = merged
def restore() -> None: def restore() -> None:
@@ -209,38 +348,38 @@ def install_for_request(
return restore return restore
def _merge_after_llm_call( def _merge_hooks(
base: ToolRunnerHooks | None, extra: ToolRunnerHooks base: ToolRunnerHooks | None, extra: ToolRunnerHooks
) -> ToolRunnerHooks: ) -> ToolRunnerHooks:
"""Compose ``extra.after_llm_call`` after ``base.after_llm_call``. """Compose ``extra``'s hooks after each matching ``base`` hook.
Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private) Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
so we don't depend on its signature. Only the ``after_llm_call`` slot so we don't depend on its signature. Each slot composes independently;
actually gets composed here — the others are passed through unchanged when ``base`` is ``None`` the ``extra`` hook is used directly.
from ``base`` since ``extra`` only ever sets ``after_llm_call``.
""" """
if base is None: if base is None:
return extra return extra
base_hook = base.after_llm_call return ToolRunnerHooks(
extra_hook = extra.after_llm_call before_llm_call=_compose(base.before_llm_call, extra.before_llm_call),
after_llm_call=_compose(base.after_llm_call, extra.after_llm_call),
before_tool_call=_compose(base.before_tool_call, extra.before_tool_call),
after_tool_call=_compose(base.after_tool_call, extra.after_tool_call),
after_turn_complete=_compose(
base.after_turn_complete, extra.after_turn_complete
),
)
def _compose(base_hook, extra_hook):
"""Return an async callable that runs ``base_hook`` then ``extra_hook``."""
if base_hook is None: if base_hook is None:
merged_after = extra_hook return extra_hook
elif extra_hook is None: # pragma: no cover - extra always sets it if extra_hook is None: # pragma: no cover - extra always sets the slots we care about
merged_after = base_hook return base_hook
else:
async def merged(runner: Any, message: PromptMessageExtended) -> None: async def merged(runner, message):
await base_hook(runner, message) await base_hook(runner, message)
await extra_hook(runner, message) await extra_hook(runner, message)
merged_after = merged return merged
return ToolRunnerHooks(
before_llm_call=base.before_llm_call,
after_llm_call=merged_after,
before_tool_call=base.before_tool_call,
after_tool_call=base.after_tool_call,
after_turn_complete=base.after_turn_complete,
)

View File

@@ -0,0 +1,213 @@
"""Sub-agent activity relay over the MCP logging channel.
When Alan (or any peer-orchestrator agent) calls a sub-agent like ``research``,
that sub-agent is itself a Pallas server emitting its own ``assistant_chunk``
notifications. Those notifications are delivered to fast-agent's
``MCPAgentClientSession`` for the peer — and stop there: fast-agent's default
session has no ``logging_callback``, so the only thing that ever crosses
back up to Daedalus is the final ``CallToolResult`` text. When a sub-agent
does substantive multi-turn reasoning, all that reasoning disappears into
server logs.
This module installs a ``logging_callback`` on every peer session Pallas
opens and re-emits the relevant notifications upstream, decorated so the
consumer can render them as **nested iterations** under the parent
tool-call entry.
Two pieces of provenance are added:
* ``via_agent`` — call-stack list (e.g. ``["research"]`` for chunks the
``research`` peer emitted; ``["research", "search"]`` if research itself
relayed deeper). The parent agent is implicit (it's whose message
these attach to) so the list contains only relayed-through peers.
* ``parent_call_id`` — the LLM ``tool_use_id`` (e.g. ``toolu_42``) that
the parent's LLM emitted to invoke this peer. The Daedalus collector
uses it to nest the relayed iterations under the matching tool_call
entry on the parent's iteration card.
The schema_version is bumped to **3** when relaying. Top-level chunks
(produced by the parent agent's own assistant_stream hooks, no
``parent_call_id``) are unchanged at v2 — adding a new optional field on
top-level entries doesn't change the wire shape.
Depth cap (``max_relay_depth``, default 3) drops chunks that would push
``len(via_agent)`` past the cap. This guards against runaway nested-pyramid
JSON in ``Message.iterations`` if a future agent wired up indefinite peer
chains; today the realistic depth is 2 (Alan → research → search/knowledge).
"""
from __future__ import annotations
import logging
from collections import deque
from typing import Any, Awaitable, Callable, Deque, Tuple
from mcp.types import LoggingMessageNotificationParams
logger = logging.getLogger(__name__)
# Logger string Pallas tags assistant-stream notifications with.
_ASSISTANT_STREAM_LOGGER = "pallas.assistant_stream"
_KIND = "assistant_chunk"
_KIND_RESULTS = "assistant_chunk_results"
_KINDS = frozenset({_KIND, _KIND_RESULTS})
# Schema produced by this relay. The relay only re-emits payloads at v2+
# (the version that carries enriched tool_calls + companion results); the
# bump to 3 signals "may carry via_agent / parent_call_id".
RELAY_SCHEMA_VERSION = 3
# Default depth cap. Override at install time via ``max_relay_depth``.
DEFAULT_MAX_RELAY_DEPTH = 3
# Per-session in-flight stack: (parent_call_id, peer_server_name).
# Pushed when ``MCPAggregator.call_tool`` enters; popped when it exits.
# Stored as a deque so concurrent same-server tool calls (rare — fast-agent
# supports parallel tool execution within an iteration, see
# ``tool_agent.py`` ``run_tools`` ``should_parallel`` branch) get attributed
# to whichever push is currently most recent. This is a best-effort
# attribution; the realistic concurrent shape is "different servers" not
# "same server".
_INFLIGHT_ATTR = "_pallas_relay_inflight"
def push_inflight(session: Any, parent_call_id: str, peer_name: str) -> None:
    """Record an in-flight peer call so the relay can tag its chunks.

    Parameters
    ----------
    session:
        The ``MCPAgentClientSession`` instance that fast-agent's aggregator
        opened to the peer. Any object that accepts attribute assignment
        works; the stack is stored on it under ``_INFLIGHT_ATTR``.
    parent_call_id:
        Id of the parent's tool call that triggered this peer call.
    peer_name:
        Name of the peer server being called.

    The deque is created lazily on the first push, so a session that never
    hosts a peer call never gains the extra attribute.
    """
    # ``getattr`` falls back to ``None`` when nothing has been pushed yet,
    # hence the optional annotation (matching the sibling helpers).
    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if stack is None:
        stack = deque()
        setattr(session, _INFLIGHT_ATTR, stack)
    # Newest entry lives at the right end; ``current_inflight`` reads it.
    stack.append((parent_call_id, peer_name))
def pop_inflight(session: Any, parent_call_id: str) -> None:
    """Drop the newest in-flight entry recorded under ``parent_call_id``.

    Entries are matched on their call id instead of blindly popping the
    right end, so tool calls that complete out of order (fast-agent can run
    tool calls in parallel) each remove their own entry rather than
    stranding someone else's. Does nothing when the session has no
    in-flight stack, the stack is empty, or no entry matches.
    """
    inflight = getattr(session, _INFLIGHT_ATTR, None)
    if inflight:
        # Pair every entry with its index and scan from the newest backwards.
        newest_first = zip(reversed(range(len(inflight))), reversed(inflight))
        match_index = next(
            (idx for idx, entry in newest_first if entry[0] == parent_call_id),
            None,
        )
        if match_index is not None:
            del inflight[match_index]
def current_inflight(session: Any) -> Tuple[str, str] | None:
    """Peek at the newest in-flight ``(parent_call_id, peer_name)`` pair.

    The stack is read, never modified. ``None`` comes back when the session
    carries no in-flight stack or the stack is empty, i.e. no peer call is
    in flight on this session. The relay callback treats that as "cannot
    attribute this chunk to any parent tool_call" and drops the chunk.
    """
    inflight = getattr(session, _INFLIGHT_ATTR, None)
    # A missing attribute and an empty deque are both falsy.
    return inflight[-1] if inflight else None
def make_logging_callback(
    *,
    upstream_send: Callable[[dict[str, Any]], Awaitable[None]],
    session: Any,
    max_depth: int = DEFAULT_MAX_RELAY_DEPTH,
) -> Callable[[LoggingMessageNotificationParams], Awaitable[None]]:
    """Build a ``logging_callback`` that relays peer assistant-stream chunks.

    Parameters
    ----------
    upstream_send:
        Async callable that ships a payload upward via the parent's MCP
        session (typically wraps ``ctx.session.send_log_message`` from
        the parent's per-request ``MCPContext``). Exceptions it raises are
        caught here, logged as a warning and not propagated.
    session:
        The peer ``MCPAgentClientSession`` this callback is attached to.
        Used to look up the in-flight ``(parent_call_id, peer_name)`` for
        attribution.
    max_depth:
        Hard cap on ``len(via_agent)`` after relaying. A chunk whose
        incoming ``via_agent`` already holds ``max_depth`` entries is
        dropped with a debug log.

    Returns
    -------
    An async callback taking the notification params. It never returns a
    value; a chunk is either re-emitted through ``upstream_send`` or
    silently dropped.
    """

    async def _callback(params: LoggingMessageNotificationParams) -> None:
        # 1) Filter — only Pallas assistant-stream notifications get relayed;
        #    every other log message routes through the existing trace
        #    handlers and isn't our concern.
        if params.logger != _ASSISTANT_STREAM_LOGGER:
            return
        data = params.data
        if not isinstance(data, dict):
            return
        if data.get("kind") not in _KINDS:
            return
        schema = data.get("schema_version")
        if not isinstance(schema, int) or schema < 2:
            # v1 chunks don't carry the enriched tool_calls / results
            # shape the consumer needs to render nested iterations.
            return

        # 2) Look up the in-flight tool call this chunk belongs to. No
        #    in-flight call → drop (the chunk arrived during peer
        #    bookkeeping outside any tool_call, which we can't attribute).
        inflight = current_inflight(session)
        if inflight is None:
            return
        parent_call_id, peer_name = inflight

        # 3) Validate ``via_agent``. It may already exist if the peer was
        #    itself relaying chunks from a deeper sub-agent, but the payload
        #    comes off the wire, so its shape is not guaranteed: ``list()``
        #    on a string would explode it into characters and on an int
        #    would raise out of this callback. Missing / falsy → no hops.
        raw_via = data.get("via_agent") or []
        if not isinstance(raw_via, (list, tuple)):
            logger.debug(
                "subagent_relay_malformed_via_agent",
                extra={
                    "via_agent_type": type(raw_via).__name__,
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                },
            )
            return
        existing_via: list[str] = list(raw_via)

        # 4) Enforce the depth cap: appending our hop must not push the
        #    chain past ``max_depth``.
        if len(existing_via) >= max_depth:
            logger.debug(
                "subagent_relay_depth_capped",
                extra={
                    "depth": len(existing_via),
                    "max_depth": max_depth,
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                },
            )
            return

        # 5) Decorate the payload and re-emit upstream. Mutate a shallow
        #    copy so we don't change a dict the SDK might still reference.
        relayed = dict(data)
        relayed["via_agent"] = existing_via + [peer_name]
        relayed["parent_call_id"] = parent_call_id
        relayed["schema_version"] = RELAY_SCHEMA_VERSION
        try:
            await upstream_send(relayed)
        except Exception:
            # Best-effort relay: a failed upstream send must not break the
            # peer session's notification handling.
            logger.warning(
                "subagent_relay_upstream_send_failed",
                exc_info=True,
                extra={
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                    "kind": relayed.get("kind"),
                },
            )

    return _callback
# Names exported by ``from <module> import *``; the underscore-prefixed
# constants above are deliberately left out.
__all__ = [
    "RELAY_SCHEMA_VERSION",
    "DEFAULT_MAX_RELAY_DEPTH",
    "current_inflight",
    "make_logging_callback",
    "pop_inflight",
    "push_inflight",
]

View File

@@ -294,7 +294,7 @@ async def _wait_for_agent(
import httpx import httpx
port = agents[name]["port"] port = agents[name]["port"]
url = f"http://127.0.0.1:{port}/mcp" url = f"http://127.0.0.1:{port}/ready"
deadline = asyncio.get_event_loop().time() + timeout deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline: while asyncio.get_event_loop().time() < deadline:
try: try:

View File

@@ -1,9 +1,9 @@
"""Tests for ``pallas.assistant_stream``. """Tests for ``pallas.assistant_stream``.
Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended`` Drives the ``after_llm_call`` / ``before_tool_call`` / ``after_tool_call``
objects and asserts the resulting MCP ``send_log_message`` payload shape. hooks with handcrafted ``PromptMessageExtended`` objects and asserts the
No fast-agent runtime is involved — the hook is a pure async function and resulting MCP ``send_log_message`` payload shape. No fast-agent runtime is
the MCP context is faked. involved — the hooks are pure async functions and the MCP context is faked.
Tests use ``asyncio.run`` directly to match the convention in Tests use ``asyncio.run`` directly to match the convention in
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no ``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
@@ -20,12 +20,14 @@ from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import ( from mcp.types import (
CallToolRequest, CallToolRequest,
CallToolRequestParams, CallToolRequestParams,
CallToolResult,
ImageContent, ImageContent,
TextContent, TextContent,
) )
from pallas.assistant_stream import ( from pallas.assistant_stream import (
KIND, KIND,
KIND_RESULTS,
LOGGER_NAME, LOGGER_NAME,
SCHEMA_VERSION, SCHEMA_VERSION,
AssistantChunkEmitter, AssistantChunkEmitter,
@@ -80,13 +82,22 @@ class _FakeAgent:
self.tool_runner_hooks: ToolRunnerHooks | None = None self.tool_runner_hooks: ToolRunnerHooks | None = None
def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest: def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
return CallToolRequest( return CallToolRequest(
method="tools/call", method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments or {}), params=CallToolRequestParams(name=name, arguments=arguments or {}),
) )
def _tool_result(
text: str | None = "ok", *, is_error: bool = False
) -> CallToolResult:
content: list = []
if text is not None:
content.append(TextContent(type="text", text=text))
return CallToolResult(content=content, isError=is_error)
# ── Tests ──────────────────────────────────────────────────────────────────── # ── Tests ────────────────────────────────────────────────────────────────────
@@ -101,7 +112,7 @@ def test_emit_text_only_iteration() -> None:
stop_reason=LlmStopReason.END_TURN, stop_reason=LlmStopReason.END_TURN,
) )
_run(emitter._emit(msg)) _run(emitter._emit_assistant_chunk(msg))
assert len(ctx.session.calls) == 1 assert len(ctx.session.calls) == 1
call = ctx.session.calls[0] call = ctx.session.calls[0]
@@ -120,8 +131,8 @@ def test_emit_text_only_iteration() -> None:
assert data["tool_calls"] == [] assert data["tool_calls"] == []
def test_emit_text_with_tool_call_iteration() -> None: def test_emit_text_with_tool_call_iteration_carries_args_preview() -> None:
"""An assistant turn that emits text and then calls a tool ships both.""" """An iteration that requests a tool ships name, server prefix, and args preview."""
ctx = _FakeContext() ctx = _FakeContext()
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
@@ -129,33 +140,38 @@ def test_emit_text_with_tool_call_iteration() -> None:
role="assistant", role="assistant",
content=[TextContent(type="text", text="Logging this…")], content=[TextContent(type="text", text="Logging this…")],
tool_calls={ tool_calls={
"toolu_1": _tool_call("toolu_1", "time__get_current_time"), "toolu_1": _tool_call(
"argos-search_web",
{"query": "ducati v2 vs v4 reliability"},
),
}, },
stop_reason=LlmStopReason.TOOL_USE, stop_reason=LlmStopReason.TOOL_USE,
) )
_run(emitter._emit(msg)) _run(emitter._emit_assistant_chunk(msg))
assert len(ctx.session.calls) == 1 assert len(ctx.session.calls) == 1
data = ctx.session.calls[0]["data"] data = ctx.session.calls[0]["data"]
assert data["stop_reason"] == "toolUse" assert data["stop_reason"] == "toolUse"
assert data["content"] == [{"type": "text", "text": "Logging this…"}] assert data["content"] == [{"type": "text", "text": "Logging this…"}]
assert data["tool_calls"] == [{"id": "toolu_1", "name": "time__get_current_time"}] assert data["tool_calls"] == [
{
"id": "toolu_1",
"name": "argos-search_web",
"server": "argos",
"arguments_preview": "ducati v2 vs v4 reliability",
}
]
def test_emit_skips_completely_empty_iteration() -> None: def test_emit_skips_completely_empty_iteration() -> None:
"""A turn with no text blocks and no tool calls emits nothing. """A turn with no text blocks and no tool calls emits nothing."""
Tool-call lifecycle is already covered by notifications/progress. An
empty assistant_chunk would just be noise on the wire and a no-op
update on the live bubble.
"""
ctx = _FakeContext() ctx = _FakeContext()
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
msg = PromptMessageExtended(role="assistant", content=[]) msg = PromptMessageExtended(role="assistant", content=[])
_run(emitter._emit(msg)) _run(emitter._emit_assistant_chunk(msg))
assert ctx.session.calls == [] assert ctx.session.calls == []
# Iteration counter still bumps so subsequent chunks aren't mis-numbered. # Iteration counter still bumps so subsequent chunks aren't mis-numbered.
@@ -173,7 +189,7 @@ def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
stop_reason=LlmStopReason.END_TURN, stop_reason=LlmStopReason.END_TURN,
) )
_run(emitter._emit(msg)) _run(emitter._emit_assistant_chunk(msg))
data = ctx.session.calls[0]["data"] data = ctx.session.calls[0]["data"]
assert data["content"] == [ assert data["content"] == [
@@ -188,21 +204,21 @@ def test_emit_iterations_are_numbered_in_order() -> None:
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
async def drive() -> None: async def drive() -> None:
await emitter._emit( await emitter._emit_assistant_chunk(
PromptMessageExtended( PromptMessageExtended(
role="assistant", role="assistant",
content=[TextContent(type="text", text="first")], content=[TextContent(type="text", text="first")],
stop_reason=LlmStopReason.TOOL_USE, stop_reason=LlmStopReason.TOOL_USE,
) )
) )
await emitter._emit( await emitter._emit_assistant_chunk(
PromptMessageExtended( PromptMessageExtended(
role="assistant", role="assistant",
content=[TextContent(type="text", text="second")], content=[TextContent(type="text", text="second")],
stop_reason=LlmStopReason.TOOL_USE, stop_reason=LlmStopReason.TOOL_USE,
) )
) )
await emitter._emit( await emitter._emit_assistant_chunk(
PromptMessageExtended( PromptMessageExtended(
role="assistant", role="assistant",
content=[TextContent(type="text", text="third — done")], content=[TextContent(type="text", text="third — done")],
@@ -232,9 +248,96 @@ def test_emit_swallows_session_failure() -> None:
stop_reason=LlmStopReason.END_TURN, stop_reason=LlmStopReason.END_TURN,
) )
# Must not raise. _run(emitter._emit_assistant_chunk(msg))
_run(emitter._emit(msg)) assert ctx.session.calls == []
# And no successful calls were recorded (fail_with raised before append).
def test_emit_tool_results_pairs_call_id_with_iteration() -> None:
"""``after_tool_call`` ships a results payload keyed to the iteration that called."""
ctx = _FakeContext()
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
# Iteration 1: text + tool request.
iter1 = PromptMessageExtended(
role="assistant",
content=[TextContent(type="text", text="Searching…")],
tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
stop_reason=LlmStopReason.TOOL_USE,
)
async def drive() -> None:
# Simulate the runner: after_llm_call → before_tool_call → after_tool_call.
await emitter._emit_assistant_chunk(iter1)
await emitter.as_before_tool_call_hook()(None, iter1)
# The "user" message produced by the tool runner carries tool_results.
tool_msg = PromptMessageExtended(
role="user",
content=[],
tool_results={"toolu_1": _tool_result("12 results found")},
)
await emitter.as_after_tool_call_hook()(None, tool_msg)
_run(drive())
assert len(ctx.session.calls) == 2
chunk = ctx.session.calls[0]["data"]
results = ctx.session.calls[1]["data"]
assert chunk["kind"] == KIND
assert chunk["iteration"] == 1
assert results["kind"] == KIND_RESULTS
assert results["iteration"] == 1
assert results["agent"] == "alan"
assert results["conversation_id"] == "conv-1"
assert len(results["results"]) == 1
entry = results["results"][0]
assert entry["id"] == "toolu_1"
assert entry["ok"] is True
assert entry["result_preview"] == "12 results found"
assert isinstance(entry["duration_ms"], int)
assert entry["duration_ms"] >= 0
def test_emit_tool_results_marks_error() -> None:
"""A failing tool call surfaces ``ok: False`` in the results entry."""
ctx = _FakeContext()
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")
iter1 = PromptMessageExtended(
role="assistant",
content=[],
tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
stop_reason=LlmStopReason.TOOL_USE,
)
async def drive() -> None:
await emitter._emit_assistant_chunk(iter1)
await emitter.as_before_tool_call_hook()(None, iter1)
tool_msg = PromptMessageExtended(
role="user",
content=[],
tool_results={
"toolu_1": _tool_result("connection refused", is_error=True)
},
)
await emitter.as_after_tool_call_hook()(None, tool_msg)
_run(drive())
# iter1 is a pure-tool turn (no text + has tool_calls) → still emits an
# assistant_chunk because tool_calls is non-empty.
assert len(ctx.session.calls) == 2
results = ctx.session.calls[1]["data"]
assert results["results"][0]["ok"] is False
assert results["results"][0]["result_preview"] == "connection refused"
def test_emit_tool_results_empty_when_no_results() -> None:
"""An ``after_tool_call`` carrying no tool_results emits nothing."""
ctx = _FakeContext()
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")
msg = PromptMessageExtended(role="user", content=[], tool_results=None)
_run(emitter.as_after_tool_call_hook()(None, msg))
assert ctx.session.calls == [] assert ctx.session.calls == []
@@ -262,23 +365,22 @@ def test_install_for_request_merges_with_existing_after_llm_call() -> None:
_run(agent.tool_runner_hooks.after_llm_call(None, msg)) _run(agent.tool_runner_hooks.after_llm_call(None, msg))
# Base ran, and the assistant-stream emitter shipped the chunk.
assert seen == ["base"] assert seen == ["base"]
assert len(ctx.session.calls) == 1 assert len(ctx.session.calls) == 1
assert ctx.session.calls[0]["data"]["content"] == [ assert ctx.session.calls[0]["data"]["content"] == [
{"type": "text", "text": "hi"} {"type": "text", "text": "hi"}
] ]
# Other hook slots stay untouched. # All four hook slots used by the emitter are now bound; the restore
assert agent.tool_runner_hooks.before_llm_call is None # call puts the original hooks back exactly.
assert agent.tool_runner_hooks.before_tool_call is None assert agent.tool_runner_hooks.before_tool_call is not None
assert agent.tool_runner_hooks.after_tool_call is None assert agent.tool_runner_hooks.after_tool_call is not None
assert agent.tool_runner_hooks.after_turn_complete is None
# Restore puts the original hooks back exactly.
restore() restore()
assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks is not None
assert agent.tool_runner_hooks.after_llm_call is base_after assert agent.tool_runner_hooks.after_llm_call is base_after
assert agent.tool_runner_hooks.before_tool_call is None
assert agent.tool_runner_hooks.after_tool_call is None
def test_install_for_request_with_no_existing_hooks() -> None: def test_install_for_request_with_no_existing_hooks() -> None:
@@ -293,6 +395,8 @@ def test_install_for_request_with_no_existing_hooks() -> None:
assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks is not None
assert agent.tool_runner_hooks.after_llm_call is not None assert agent.tool_runner_hooks.after_llm_call is not None
assert agent.tool_runner_hooks.before_tool_call is not None
assert agent.tool_runner_hooks.after_tool_call is not None
restore() restore()
assert agent.tool_runner_hooks is None assert agent.tool_runner_hooks is None

2
uv.lock generated
View File

@@ -1562,7 +1562,7 @@ wheels = [
[[package]] [[package]]
name = "pallas-mcp" name = "pallas-mcp"
version = "0.2.0" version = "0.4.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "fast-agent-mcp" }, { name = "fast-agent-mcp" },