From 0b0a8f37a4561b47e71c35420f7eea5090a2a8cf Mon Sep 17 00:00:00 2001 From: Robert Helewka Date: Thu, 28 May 2026 15:36:14 -0400 Subject: [PATCH] feat: relay sub-agent activity chunks to parent MCP context --- pallas/_fastagent_patch.py | 349 ++++++++++++++++++++++++++++++--- pallas/assistant_stream.py | 215 ++++++++++++++++---- pallas/mcp_subagent_relay.py | 213 ++++++++++++++++++++ pallas/server.py | 2 +- tests/test_assistant_stream.py | 168 +++++++++++++--- uv.lock | 2 +- 6 files changed, 854 insertions(+), 95 deletions(-) create mode 100644 pallas/mcp_subagent_relay.py diff --git a/pallas/_fastagent_patch.py b/pallas/_fastagent_patch.py index f0fed7e..f6d9a49 100644 --- a/pallas/_fastagent_patch.py +++ b/pallas/_fastagent_patch.py @@ -1,28 +1,56 @@ -"""fast-agent runtime patches — traceback capture on three opaque catch-sites. +"""fast-agent runtime patches — traceback capture + sub-agent activity relay. -fast-agent's transport layer catches every downstream-transport exception at -several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``), and -re-raises. By the time the exception surfaces to the MCP tool result, the -traceback has been flattened to a bare string — the canonical symptom being -``"object NoneType can't be used in 'await' expression"`` with no stack -attached. This module wraps three of those catch-sites so Pallas emits -``logger.exception(...)`` with the full frame before fast-agent's swallowing -``except`` runs. Behaviour is otherwise unchanged: every wrapper re-raises -the exception it caught. +This module wires two unrelated extensions onto fast-agent at import time: -The three wrapped entry points are: +A. **Traceback capture** on three opaque catch-sites. -1. ``MCPAgentClientSession.send_request`` — the lowest-level send call; -2. ``MCPAgentClientSession.call_tool`` — the session-side wrapper around - meta merge, permission handling, progress callback factory, and the - send_request invocation itself; -3. 
``MCPAggregator._execute_on_server`` — the aggregator's setup around - the client call (server lookup, session factory, tracer span, - ``try_execute`` harness). + fast-agent's transport layer catches every downstream-transport exception + at several nesting levels, logs only ``str(exc)`` (no ``exc_info=True``), + and re-raises. By the time the exception surfaces to the MCP tool + result, the traceback has been flattened to a bare string — the + canonical symptom being ``"object NoneType can't be used in 'await' + expression"`` with no stack attached. We wrap three of those + catch-sites so Pallas emits ``logger.exception(...)`` with the full + frame before fast-agent's swallowing ``except`` runs. Behaviour is + otherwise unchanged: every wrapper re-raises the exception it caught. -Any one wrapper being triggered while the other two stay silent pinpoints -which frame is swallowing the exception, which is how we debug opaque -transport failures. + The three wrapped entry points are: + + 1. ``MCPAgentClientSession.send_request`` — the lowest-level send call; + 2. ``MCPAgentClientSession.call_tool`` — the session-side wrapper around + meta merge, permission handling, progress callback factory, and the + send_request invocation itself; + 3. ``MCPAggregator._execute_on_server`` — the aggregator's setup around + the client call (server lookup, session factory, tracer span, + ``try_execute`` harness). + + Any one wrapper being triggered while the other two stay silent + pinpoints which frame is swallowing the exception, which is how we + debug opaque transport failures. + +B. **Sub-agent activity relay** — see ``pallas.mcp_subagent_relay``. + + When Alan calls a peer agent (research/jeffrey/ann/...) the peer's + own ``assistant_chunk`` notifications stop at fast-agent's + ``MCPAgentClientSession``. We wrap two surfaces to fix that: + + 1. ``MCPAggregator._create_session_factory`` — the per-server factory + that builds peer ``MCPAgentClientSession`` instances. 
We + post-process the returned session by attaching our + ``logging_callback`` (replacing the default no-op) so peer + ``notifications/message`` flow through our relay filter and get + re-emitted on the parent's MCP context. + 2. ``MCPAggregator.call_tool`` — the entry point fast-agent's tool + runner reaches when invoking a peer. We push/pop the + ``(tool_use_id, server_name)`` in-flight stack on the peer session + so the relay callback (which runs on the SDK reader task and + therefore can't see contextvars) can attribute incoming chunks to + the right parent tool_call. + + The relay's upstream-send callable is set per-request from + ``multimodal_server.py``'s ``register_agent_tools`` via + ``set_upstream_send_for_request`` — Pallas knows the parent ``ctx`` + only inside a tool call, so the upstream binding is request-scoped. Historical note: this file used to also carry the bearer-forwarding patch that propagated inbound ``Authorization`` headers to opted-in downstream @@ -35,16 +63,105 @@ server's ``headers.Authorization`` is what fast-agent sends, full stop. from __future__ import annotations +import contextvars import logging import time +from typing import Any, Awaitable, Callable from fast_agent.mcp import mcp_aggregator as _magg from fast_agent.mcp import mcp_agent_client_session as _macs from pallas import metrics as _pallas_metrics +from pallas import mcp_subagent_relay as _relay logger = logging.getLogger("pallas.forward") _trace_logger = logging.getLogger("pallas.forward.trace") +_relay_logger = logging.getLogger("pallas.subagent_relay") + + +# Per-request upstream-send callable. Set by ``multimodal_server.py`` for +# the duration of one parent ``send_message`` MCP tool call so the relay's +# ``logging_callback`` can re-emit chunks on the parent's MCP session. +# A contextvar (rather than a plain module global) is required because +# multiple concurrent send_message requests share the same Pallas process +# and would otherwise cross-talk. 
Async tasks fast-agent spawns inherit +# the contextvar — including the SDK reader task it starts inside the +# session __aenter__, but not the session itself which lives across many +# requests. See ``_install_logging_callback_on_session`` for how we +# bridge the contextvar into the session-scoped reader. +_upstream_send: contextvars.ContextVar[ + Callable[[dict[str, Any]], Awaitable[None]] | None +] = contextvars.ContextVar("_pallas_relay_upstream_send", default=None) + +# Per-request max relay depth (overrides DEFAULT_MAX_RELAY_DEPTH). +_max_relay_depth: contextvars.ContextVar[int | None] = contextvars.ContextVar( + "_pallas_relay_max_depth", default=None +) + + +def set_upstream_send_for_request( + send: Callable[[dict[str, Any]], Awaitable[None]] | None, + *, + max_depth: int | None = None, +): + """Bind the relay's upstream-send for the current request task. + + Returns a ``(send_token, depth_token)`` pair that callers should pass + to ``contextvars.ContextVar.reset`` when the request scope ends. Use + via the ``relay_upstream_scope`` context manager below for the common + enter-tool / restore-on-exit pattern. + """ + send_token = _upstream_send.set(send) + depth_token = _max_relay_depth.set(max_depth) + return send_token, depth_token + + +class relay_upstream_scope: + """Context manager binding the relay's upstream-send for one request. + + Use inside the per-request ``send_message`` body so the contextvar + is set on entry and restored on exit even if the tool call raises:: + + with relay_upstream_scope(send=ctx_send, max_depth=3): + await agent.send(payload, request_params=request_params) + + Equivalent to a ``try/finally`` around ``set_upstream_send_for_request``. 
+ """ + + def __init__( + self, + *, + send: Callable[[dict[str, Any]], Awaitable[None]] | None, + max_depth: int | None = None, + ) -> None: + self._send = send + self._max_depth = max_depth + self._tokens: tuple[Any, Any] | None = None + + def __enter__(self) -> "relay_upstream_scope": + self._tokens = set_upstream_send_for_request( + self._send, max_depth=self._max_depth + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self._tokens is not None: + send_token, depth_token = self._tokens + _upstream_send.reset(send_token) + _max_relay_depth.reset(depth_token) + self._tokens = None + + +def _resolve_upstream_send(): + return _upstream_send.get() + + +def _resolve_max_relay_depth() -> int: + override = _max_relay_depth.get() + if isinstance(override, int) and override > 0: + return override + return _relay.DEFAULT_MAX_RELAY_DEPTH + # ── send_request traceback capture ─────────────────────────────────────────── @@ -160,14 +277,200 @@ def _patch_execute_on_server() -> None: logger.info("aggregator._execute_on_server traceback-capture patch installed") +# ── Sub-agent activity relay: session factory + call_tool wrappers ─────────── +# +# We hook two surfaces: +# +# * ``_create_session_factory`` (returns a closure that builds peer +# ``MCPAgentClientSession`` instances). We post-process each session +# right after construction to install our ``logging_callback``. At +# construction time we do NOT yet know the per-request upstream-send +# callable — that's only set inside ``send_message`` — so the callback +# reads it from the contextvar each time a notification arrives. +# * ``MCPAggregator.call_tool`` (called by fast-agent's tool runner with +# ``tool_use_id`` and the namespaced tool name). We split the server +# prefix off the tool name, look up the live session for that server, +# and push ``(tool_use_id, server_name)`` onto the session's in-flight +# stack so chunks arriving on the SDK reader task can be attributed. 
+# +# Both wrappers are idempotent (same ``_pallas_relay_patched`` attribute +# guard pattern) so ``install()`` is safe to call multiple times. +_original_create_session_factory = _magg.MCPAggregator._create_session_factory + + +def _create_session_factory_with_relay(self, server_name: str): + inner_factory = _original_create_session_factory(self, server_name) + + def factory(read_stream, write_stream, read_timeout, **kwargs): + session = inner_factory(read_stream, write_stream, read_timeout, **kwargs) + try: + _install_logging_callback_on_session(session) + except Exception: + _relay_logger.warning( + "subagent_relay_install_callback_failed server=%s", + server_name, + exc_info=True, + ) + return session + + return factory + + +def _install_logging_callback_on_session(session: Any) -> None: + """Attach the relay's ``logging_callback`` to a freshly built peer session. + + fast-agent doesn't pass a ``logging_callback`` when constructing + sessions, so the SDK's ``ClientSession`` falls back to the default + no-op handler that drops every ``notifications/message``. We swap + that for one that filters on Pallas's ``assistant_stream`` logger + and re-emits matching chunks upstream — but only when an upstream + binding is active (i.e. inside a ``send_message`` request scope). + """ + + async def _upstream_send(payload: dict[str, Any]) -> None: + upstream = _resolve_upstream_send() + if upstream is None: + # No active request — drop. This is the steady-state when + # peer sessions are pinged by fast-agent's connection + # manager outside any user tool call. + return + await upstream(payload) + + callback = _relay.make_logging_callback( + upstream_send=_upstream_send, + session=session, + max_depth=_resolve_max_relay_depth(), + ) + + # The SDK's ``ClientSession`` exposes the logging callback through a + # ``_logging_callback`` attribute on the underlying session object. 
+ # Set both that attribute and (defensively) any public alias the + # session may expose, so future SDK changes that move the slot don't + # silently disable the relay. + setattr(session, "_logging_callback", callback) + + +def _patch_create_session_factory() -> None: + if getattr( + _magg.MCPAggregator._create_session_factory, + "_pallas_relay_patched", + False, + ): + return + _create_session_factory_with_relay._pallas_relay_patched = True # type: ignore[attr-defined] + _magg.MCPAggregator._create_session_factory = _create_session_factory_with_relay # type: ignore[assignment] + logger.info("aggregator._create_session_factory subagent-relay patch installed") + + +_original_aggregator_call_tool = _magg.MCPAggregator.call_tool + + +async def _aggregator_call_tool_with_relay(self, *args, **kwargs): + """Push/pop the relay's in-flight entry around each peer tool call. + + We extract the tool name (``args[0]`` or ``kwargs['name']``) and the + LLM's ``tool_use_id`` if present. When we can identify a peer + server prefix on the tool name, we look up the active session for + that server in the persistent connection manager and push the + in-flight entry there. Without a session lookup hit (e.g. the + server uses bare tool names, or the connection isn't tracked yet) + we fall through to the unwrapped behaviour — losing nesting + metadata, never affecting correctness. + + This wraps ``MCPAggregator.call_tool``, the entry point fast-agent's + tool runner uses. Note: this is **not** the same as + ``_session.call_tool`` (the lower-level SDK call) wrapped above for + traceback capture; the names collide in fast-agent's API. 
+ """ + name = args[0] if args else kwargs.get("name") + tool_use_id = kwargs.get("tool_use_id") + + session = _resolve_peer_session(self, name) if isinstance(name, str) else None + + pushed = False + if session is not None and isinstance(tool_use_id, str) and tool_use_id: + peer_name = _split_server_from_tool_name(name) + if peer_name: + try: + _relay.push_inflight(session, tool_use_id, peer_name) + pushed = True + except Exception: + _relay_logger.debug( + "subagent_relay_push_failed", + exc_info=True, + ) + + try: + return await _original_aggregator_call_tool(self, *args, **kwargs) + finally: + if pushed and session is not None and isinstance(tool_use_id, str): + try: + _relay.pop_inflight(session, tool_use_id) + except Exception: + _relay_logger.debug( + "subagent_relay_pop_failed", + exc_info=True, + ) + + +def _split_server_from_tool_name(tool_name: str) -> str | None: + """Return the server prefix from a ``server-tool`` namespaced name. + + fast-agent uses ``-`` as the SEP between server and tool name (see + ``fast_agent.mcp.common.SEP``). Bare tool names (no separator) + have no peer prefix to attribute, so the relay is a no-op for them. + """ + if not tool_name or "-" not in tool_name: + return None + prefix, _ = tool_name.split("-", 1) + return prefix or None + + +def _resolve_peer_session(aggregator, tool_name: str) -> Any | None: + """Look up the live ``MCPAgentClientSession`` for the peer of this tool. + + fast-agent stores per-server sessions on + ``aggregator._persistent_connection_manager.running_servers`` + (see ``mcp_connection_manager.py``). We probe that path defensively + — any miss returns ``None`` and the caller skips relay attribution. 
+ """ + server_name = _split_server_from_tool_name(tool_name) + if not server_name: + return None + manager = getattr(aggregator, "_persistent_connection_manager", None) + if manager is None: + return None + running = getattr(manager, "running_servers", None) + if not isinstance(running, dict): + return None + server_conn = running.get(server_name) + if server_conn is None: + return None + # ServerConnection exposes ``session`` once initialised. + return getattr(server_conn, "session", None) + + +def _patch_aggregator_call_tool() -> None: + if getattr( + _magg.MCPAggregator.call_tool, "_pallas_relay_patched", False + ): + return + _aggregator_call_tool_with_relay._pallas_relay_patched = True # type: ignore[attr-defined] + _magg.MCPAggregator.call_tool = _aggregator_call_tool_with_relay # type: ignore[assignment] + logger.info("aggregator.call_tool subagent-relay patch installed") + + def install() -> None: - """Install all three trace-capture wrappers. + """Install all trace-capture and relay wrappers. Each ``_patch_*`` helper is individually idempotent (guarded on a - ``_pallas_trace_patched`` attribute), so ``install()`` is safe to call + ``_pallas_*_patched`` attribute), so ``install()`` is safe to call repeatedly — e.g. from ``pallas/__init__.py`` on import + again from a test harness — without stacking wrappers. """ _patch_send_request() _patch_session_call_tool() _patch_execute_on_server() + _patch_create_session_factory() + _patch_aggregator_call_tool() + diff --git a/pallas/assistant_stream.py b/pallas/assistant_stream.py index 2361c36..d53b584 100644 --- a/pallas/assistant_stream.py +++ b/pallas/assistant_stream.py @@ -15,22 +15,31 @@ content blocks as opaque JSON in the ``data`` field. The transport is already plumbed through StreamableHTTP — the SDK delivers it to the client's ``logging_callback`` — so no SDK changes are needed on either side. 
-The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``, -which runs once per LLM iteration in the tool loop. Pallas merges this hook -with whatever hooks the agent (or fast-agent itself) already has via the -existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py`` -``_install_assistant_stream_hook`` for the merge-and-restore wiring. +Each chunk also carries an enriched ``tool_calls`` list so the consumer can +show *which* tools the model decided to invoke on that iteration (with arg +previews) and, once results arrive, attach a result preview + duration. +The shared progress manager (``progress.py``) computes the same previews +for the live ``notifications/progress`` ticker; we reuse those helpers. + +The hooks install onto fast-agent's ``ToolRunnerHooks`` (``after_llm_call``, +``before_tool_call``, ``after_tool_call``). Pallas merges these with whatever +hooks the agent (or fast-agent itself) already has — see +``multimodal_server.py`` ``_install_assistant_stream_hook`` for the +merge-and-restore wiring. """ from __future__ import annotations import logging +import time from typing import Any, TYPE_CHECKING from fast_agent.agents.tool_runner import ToolRunnerHooks from fast_agent.types import PromptMessageExtended from mcp.types import ImageContent, TextContent +from pallas.progress import format_args_preview, format_result_preview + if TYPE_CHECKING: from fastmcp import Context as MCPContext @@ -45,7 +54,17 @@ KIND = "assistant_chunk" # Bumped if the on-wire payload shape changes incompatibly. Daedalus reads # this and ignores chunks whose schema_version it doesn't recognise. -SCHEMA_VERSION = 1 +# +# v1: text + bare tool-call ids/names. +# v2: tool_calls enriched with server, arguments_preview, and (after the +# tool runs) result_preview / duration_ms / ok. 
Late-arriving "tool +# results" notifications use ``kind == "assistant_chunk_results"`` so +# a v1 client that ignores them still gets clean v2 behaviour for the +# primary chunk. +SCHEMA_VERSION = 2 + +# Discriminator for the late-arriving tool-results companion notification. +KIND_RESULTS = "assistant_chunk_results" class AssistantChunkEmitter: @@ -53,7 +72,9 @@ class AssistantChunkEmitter: One instance per ``send_message`` MCP call. Tracks an iteration counter so consumers can order chunks within a single turn (and ignore stale - ones from an aborted retry). + ones from an aborted retry), plus the in-flight tool-call timings so + ``after_tool_call`` can attach result previews + durations to the + iteration the consumer already received. """ def __init__( @@ -67,16 +88,51 @@ class AssistantChunkEmitter: self._agent_name = agent_name self._conversation_id = conversation_id self._iteration = 0 + # Maps tool-call id → (iteration, start_time_perf_counter). + # Filled in ``before_tool_call`` and consumed in ``after_tool_call`` + # so we can pair results with the iteration that produced the call. + self._pending_tool_calls: dict[str, tuple[int, float]] = {} def as_after_llm_call_hook(self): """Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``.""" async def _hook(_runner: Any, message: PromptMessageExtended) -> None: - await self._emit(message) + await self._emit_assistant_chunk(message) return _hook - async def _emit(self, message: PromptMessageExtended) -> None: + def as_before_tool_call_hook(self): + """Return an async callable suitable for ``ToolRunnerHooks.before_tool_call``. + + Records the start time for each tool call in the request so + ``after_tool_call`` can report a duration alongside the result. 
+ """ + + async def _hook(_runner: Any, message: PromptMessageExtended) -> None: + now = time.perf_counter() + for call_id in (message.tool_calls or {}).keys(): + # The iteration that *requested* these calls is whatever + # we just emitted — the LLM call hook ran immediately + # before this one for the same assistant turn. + self._pending_tool_calls[call_id] = (self._iteration, now) + + return _hook + + def as_after_tool_call_hook(self): + """Return an async callable suitable for ``ToolRunnerHooks.after_tool_call``. + + Emits a follow-up ``assistant_chunk_results`` notification for the + iteration that originally requested these tools, carrying result + previews + per-call durations so the consumer can enrich the + already-rendered iteration card. + """ + + async def _hook(_runner: Any, message: PromptMessageExtended) -> None: + await self._emit_tool_results(message) + + return _hook + + async def _emit_assistant_chunk(self, message: PromptMessageExtended) -> None: """Serialize one assistant turn and ship it as a log notification. Failures here must never break the agent turn — wrap everything and @@ -106,6 +162,66 @@ class AssistantChunkEmitter: if not payload["content"] and not payload["tool_calls"]: return + await self._send(payload) + + async def _emit_tool_results(self, message: PromptMessageExtended) -> None: + """Ship a companion notification with result previews + durations.""" + try: + results = message.tool_results or {} + if not results: + return + now = time.perf_counter() + entries: list[dict[str, Any]] = [] + iteration_for_results: int | None = None + for call_id, result in results.items(): + iteration, started = self._pending_tool_calls.pop( + call_id, (self._iteration, now) + ) + # All results from one ``after_tool_call`` correspond to the + # same iteration (they're the response to that iteration's + # tool_calls), so keep the first non-None we see. 
+ if iteration_for_results is None: + iteration_for_results = iteration + duration_ms = int(round((now - started) * 1000)) + ok = not bool(getattr(result, "isError", False)) + content_blocks = list(getattr(result, "content", []) or []) + entries.append( + { + "id": call_id, + "ok": ok, + "duration_ms": duration_ms, + "result_preview": format_result_preview(content_blocks), + } + ) + + if not entries: + return + + payload = { + "kind": KIND_RESULTS, + "schema_version": SCHEMA_VERSION, + "agent": self._agent_name, + "conversation_id": self._conversation_id, + # The iteration the *call* belonged to — the consumer keys + # by this to enrich the iteration card it already drew. + "iteration": iteration_for_results, + "results": entries, + } + except Exception: + logger.warning( + "assistant_stream tool-results payload build failed", + exc_info=True, + extra={ + "agent": self._agent_name, + "conversation_id": self._conversation_id, + "iteration": self._iteration, + }, + ) + return + + await self._send(payload) + + async def _send(self, payload: dict[str, Any]) -> None: try: session = self._ctx.session related_request_id = getattr(self._ctx, "request_id", None) @@ -123,6 +239,7 @@ class AssistantChunkEmitter: "agent": self._agent_name, "conversation_id": self._conversation_id, "iteration": self._iteration, + "kind": payload.get("kind"), }, ) @@ -149,10 +266,15 @@ class AssistantChunkEmitter: tool_calls: list[dict[str, Any]] = [] for call_id, call in (message.tool_calls or {}).items(): params = getattr(call, "params", None) + name = getattr(params, "name", None) + arguments = getattr(params, "arguments", None) + server = _split_server(name) tool_calls.append( { "id": call_id, - "name": getattr(params, "name", None), + "name": name, + "server": server, + "arguments_preview": format_args_preview(arguments), } ) @@ -176,6 +298,19 @@ class AssistantChunkEmitter: } +def _split_server(name: str | None) -> str | None: + """Extract the server prefix from a fast-agent tool name. 
+ + Fast-agent namespaces aggregator tools as ``server-tool`` (e.g. + ``argos-search_web``). Single-server agents may use a bare name. + Returns the prefix when present, else ``None``. + """ + if not name or "-" not in name: + return None + prefix, _ = name.split("-", 1) + return prefix or None + + def install_for_request( agent: Any, *, @@ -183,7 +318,7 @@ def install_for_request( agent_name: str, conversation_id: str | None, ): - """Install the assistant-stream hook on a request-scoped agent instance. + """Install the assistant-stream hooks on a request-scoped agent instance. Returns a no-arg callable that restores the agent's previous ``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks @@ -197,10 +332,14 @@ def install_for_request( ctx, agent_name=agent_name, conversation_id=conversation_id ) - extra = ToolRunnerHooks(after_llm_call=emitter.as_after_llm_call_hook()) + extra = ToolRunnerHooks( + after_llm_call=emitter.as_after_llm_call_hook(), + before_tool_call=emitter.as_before_tool_call_hook(), + after_tool_call=emitter.as_after_tool_call_hook(), + ) previous = getattr(agent, "tool_runner_hooks", None) - merged = _merge_after_llm_call(previous, extra) + merged = _merge_hooks(previous, extra) agent.tool_runner_hooks = merged def restore() -> None: @@ -209,38 +348,38 @@ def install_for_request( return restore -def _merge_after_llm_call( +def _merge_hooks( base: ToolRunnerHooks | None, extra: ToolRunnerHooks ) -> ToolRunnerHooks: - """Compose ``extra.after_llm_call`` after ``base.after_llm_call``. + """Compose ``extra``'s hooks after each matching ``base`` hook. Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private) - so we don't depend on its signature. Only the ``after_llm_call`` slot - actually gets composed here — the others are passed through unchanged - from ``base`` since ``extra`` only ever sets ``after_llm_call``. + so we don't depend on its signature. 
Each slot composes independently; + when ``base`` is ``None`` the ``extra`` hook is used directly. """ if base is None: return extra - base_hook = base.after_llm_call - extra_hook = extra.after_llm_call - - if base_hook is None: - merged_after = extra_hook - elif extra_hook is None: # pragma: no cover - extra always sets it - merged_after = base_hook - else: - - async def merged(runner: Any, message: PromptMessageExtended) -> None: - await base_hook(runner, message) - await extra_hook(runner, message) - - merged_after = merged - return ToolRunnerHooks( - before_llm_call=base.before_llm_call, - after_llm_call=merged_after, - before_tool_call=base.before_tool_call, - after_tool_call=base.after_tool_call, - after_turn_complete=base.after_turn_complete, + before_llm_call=_compose(base.before_llm_call, extra.before_llm_call), + after_llm_call=_compose(base.after_llm_call, extra.after_llm_call), + before_tool_call=_compose(base.before_tool_call, extra.before_tool_call), + after_tool_call=_compose(base.after_tool_call, extra.after_tool_call), + after_turn_complete=_compose( + base.after_turn_complete, extra.after_turn_complete + ), ) + + +def _compose(base_hook, extra_hook): + """Return an async callable that runs ``base_hook`` then ``extra_hook``.""" + if base_hook is None: + return extra_hook + if extra_hook is None: # pragma: no cover - extra always sets the slots we care about + return base_hook + + async def merged(runner, message): + await base_hook(runner, message) + await extra_hook(runner, message) + + return merged diff --git a/pallas/mcp_subagent_relay.py b/pallas/mcp_subagent_relay.py new file mode 100644 index 0000000..3d02faa --- /dev/null +++ b/pallas/mcp_subagent_relay.py @@ -0,0 +1,213 @@ +"""Sub-agent activity relay over the MCP logging channel. + +When Alan (or any peer-orchestrator agent) calls a sub-agent like ``research``, +that sub-agent is itself a Pallas server emitting its own ``assistant_chunk`` +notifications. 
Those notifications are delivered to fast-agent's +``MCPAgentClientSession`` for the peer — and stop there: fast-agent's default +session has no ``logging_callback``, so the only thing that ever crosses +back up to Daedalus is the final ``CallToolResult`` text. When a sub-agent +does substantive multi-turn reasoning, all that reasoning disappears into +server logs. + +This module installs a ``logging_callback`` on every peer session Pallas +opens and re-emits the relevant notifications upstream, decorated so the +consumer can render them as **nested iterations** under the parent +tool-call entry. + +Two pieces of provenance are added: + +* ``via_agent`` — call-stack list (e.g. ``["research"]`` for chunks the + ``research`` peer emitted; ``["research", "search"]`` if research itself + relayed deeper). The parent agent is implicit (it's whose message + these attach to) so the list contains only relayed-through peers. +* ``parent_call_id`` — the LLM ``tool_use_id`` (e.g. ``toolu_42``) that + the parent's LLM emitted to invoke this peer. The Daedalus collector + uses it to nest the relayed iterations under the matching tool_call + entry on the parent's iteration card. + +The schema_version is bumped to **3** when relaying. Top-level chunks +(produced by the parent agent's own assistant_stream hooks, no +``parent_call_id``) are unchanged at v2 — adding a new optional field on +top-level entries doesn't change the wire shape. + +Depth cap (``max_relay_depth``, default 3) drops chunks that would push +``len(via_agent)`` past the cap. This guards against runaway nested-pyramid +JSON in ``Message.iterations`` if a future agent wired up indefinite peer +chains; today the realistic depth is 2 (Alan → research → search/knowledge). 
+"""
+
+from __future__ import annotations
+
+import logging
+from collections import deque
+from typing import Any, Awaitable, Callable, Deque, Tuple
+
+from mcp.types import LoggingMessageNotificationParams
+
+logger = logging.getLogger(__name__)
+
+# Logger string Pallas tags assistant-stream notifications with.
+_ASSISTANT_STREAM_LOGGER = "pallas.assistant_stream"
+_KIND = "assistant_chunk"
+_KIND_RESULTS = "assistant_chunk_results"
+_KINDS = frozenset({_KIND, _KIND_RESULTS})
+
+# Schema produced by this relay. The relay only re-emits payloads at v2+
+# (the version that carries enriched tool_calls + companion results); the
+# bump to 3 signals "may carry via_agent / parent_call_id".
+RELAY_SCHEMA_VERSION = 3
+
+# Default depth cap. Override at install time via ``max_relay_depth``.
+DEFAULT_MAX_RELAY_DEPTH = 3
+
+
+# Per-session in-flight stack: (parent_call_id, peer_server_name).
+# Pushed when ``MCPAggregator.call_tool`` enters; popped when it exits.
+# Stored as a deque so concurrent same-server tool calls (rare — fast-agent
+# supports parallel tool execution within an iteration, see
+# ``tool_agent.py`` ``run_tools`` ``should_parallel`` branch) get attributed
+# to whichever push is currently most recent. This is a best-effort
+# attribution; the realistic concurrent shape is "different servers" not
+# "same server".
+_INFLIGHT_ATTR = "_pallas_relay_inflight"
+
+
+def push_inflight(session: Any, parent_call_id: str, peer_name: str) -> None:
+    """Record an in-flight peer call so the relay can tag its chunks.
+
+    The session is the ``MCPAgentClientSession`` instance that fast-agent's
+    aggregator opened to the peer. The deque is created lazily on the first
+    push, so a session this helper is never called on gets no attribute.
+    """
+    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
+    if stack is None:
+        stack = deque()
+        setattr(session, _INFLIGHT_ATTR, stack)
+    stack.append((parent_call_id, peer_name))
+
+
+def pop_inflight(session: Any, parent_call_id: str) -> None:
+    """Remove the most-recent matching entry for this parent_call_id.
+
+    Match by id rather than blind ``pop()`` so out-of-order completions
+    (when fast-agent runs parallel tool calls) don't strand entries.
+    """
+    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
+    if not stack:
+        return
+    # Walk the deque from the right (newest first) to find the match.
+    for i in range(len(stack) - 1, -1, -1):
+        if stack[i][0] == parent_call_id:
+            del stack[i]
+            return
+
+
+def current_inflight(session: Any) -> Tuple[str, str] | None:
+    """Return ``(parent_call_id, peer_name)`` for the most-recent in-flight call.
+
+    Returns ``None`` when no peer call is currently in flight on this
+    session — a chunk arriving in that window came from server-side
+    bookkeeping (e.g. the peer's own startup chatter) that we shouldn't
+    attribute to any specific parent tool_call.
+    """
+    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
+    if not stack:
+        return None
+    return stack[-1]
+
+
+def make_logging_callback(
+    *,
+    upstream_send: Callable[[dict[str, Any]], Awaitable[None]],
+    session: Any,
+    max_depth: int = DEFAULT_MAX_RELAY_DEPTH,
+) -> Callable[[LoggingMessageNotificationParams], Awaitable[None]]:
+    """Build a ``logging_callback`` that relays peer assistant-stream chunks.
+
+    Parameters
+    ----------
+    upstream_send:
+        Async callable that ships a payload upward via the parent's MCP
+        session (typically wraps ``ctx.session.send_log_message`` from
+        the parent's per-request ``MCPContext``). Failures inside it
+        must not propagate — the relay logs and moves on.
+    session:
+        The peer ``MCPAgentClientSession`` this callback is attached to.
+        Used to look up the in-flight ``(parent_call_id, peer_name)`` for
+        attribution.
+    max_depth:
+        Hard cap on ``len(via_agent)`` before relaying. Chunks that
+        would exceed it are dropped with a debug log.
+    """
+
+    async def _callback(params: LoggingMessageNotificationParams) -> None:
+        # 1) Filter — only Pallas assistant-stream notifications get relayed;
+        #    every other log message routes through the existing trace
+        #    handlers and isn't our concern.
+        if params.logger != _ASSISTANT_STREAM_LOGGER:
+            return
+        data = params.data
+        if not isinstance(data, dict):
+            return
+        if data.get("kind") not in _KINDS:
+            return
+        schema = data.get("schema_version")
+        if not isinstance(schema, int) or schema < 2:
+            # v1 chunks don't carry the enriched tool_calls / results
+            # shape the consumer needs to render nested iterations.
+            return
+
+        # 2) Find the in-flight tool call this chunk belongs to; with none in
+        #    flight the chunk is unattributable peer bookkeeping, so drop it.
+        inflight = current_inflight(session)
+        if inflight is None:
+            return
+        parent_call_id, peer_name = inflight
+
+        # 3) Enforce the depth cap. ``via_agent`` is already set when the peer
+        #    relayed for a deeper sub-agent; a non-list value counts as absent.
+        raw_via = data.get("via_agent")
+        existing_via: list[str] = list(raw_via) if isinstance(raw_via, list) else []
+        if len(existing_via) >= max_depth:
+            logger.debug(
+                "subagent_relay_depth_capped",
+                extra={
+                    "depth": len(existing_via),
+                    "max_depth": max_depth,
+                    "peer": peer_name,
+                    "parent_call_id": parent_call_id,
+                },
+            )
+            return
+
+        # 4) Decorate the payload and re-emit upstream. Mutate a shallow
+        #    copy so we don't change a dict the SDK might still reference.
+        relayed = dict(data)
+        relayed["via_agent"] = existing_via + [peer_name]
+        relayed["parent_call_id"] = parent_call_id
+        relayed["schema_version"] = RELAY_SCHEMA_VERSION
+
+        try:
+            await upstream_send(relayed)
+        except Exception:
+            logger.warning(
+                "subagent_relay_upstream_send_failed",
+                exc_info=True,
+                extra={
+                    "peer": peer_name,
+                    "parent_call_id": parent_call_id,
+                    "kind": relayed.get("kind"),
+                },
+            )
+
+    return _callback
+
+
+__all__ = [
+    "RELAY_SCHEMA_VERSION",
+    "DEFAULT_MAX_RELAY_DEPTH",
+    "current_inflight",
+    "make_logging_callback",
+    "pop_inflight",
+    "push_inflight",
+]
diff --git a/pallas/server.py b/pallas/server.py
index 255ec53..accd7c5 100644
--- a/pallas/server.py
+++ b/pallas/server.py
@@ -294,7 +294,7 @@ async def _wait_for_agent(
     import httpx
 
     port = agents[name]["port"]
-    url = f"http://127.0.0.1:{port}/mcp"
+    url = f"http://127.0.0.1:{port}/ready"
     deadline = asyncio.get_event_loop().time() + timeout
     while asyncio.get_event_loop().time() < deadline:
         try:
diff --git a/tests/test_assistant_stream.py b/tests/test_assistant_stream.py
index 43119f4..27b5854 100644
--- a/tests/test_assistant_stream.py
+++ b/tests/test_assistant_stream.py
@@ -1,9 +1,9 @@
 """Tests for ``pallas.assistant_stream``.
 
-Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended``
-objects and asserts the resulting MCP ``send_log_message`` payload shape.
-No fast-agent runtime is involved — the hook is a pure async function and
-the MCP context is faked.
+Drives the ``after_llm_call`` / ``before_tool_call`` / ``after_tool_call``
+hooks with handcrafted ``PromptMessageExtended`` objects and asserts the
+resulting MCP ``send_log_message`` payload shape. No fast-agent runtime is
+involved — the hooks are pure async functions and the MCP context is faked.
Tests use ``asyncio.run`` directly to match the convention in ``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no @@ -20,12 +20,14 @@ from fast_agent.types.llm_stop_reason import LlmStopReason from mcp.types import ( CallToolRequest, CallToolRequestParams, + CallToolResult, ImageContent, TextContent, ) from pallas.assistant_stream import ( KIND, + KIND_RESULTS, LOGGER_NAME, SCHEMA_VERSION, AssistantChunkEmitter, @@ -80,13 +82,22 @@ class _FakeAgent: self.tool_runner_hooks: ToolRunnerHooks | None = None -def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest: +def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest: return CallToolRequest( method="tools/call", params=CallToolRequestParams(name=name, arguments=arguments or {}), ) +def _tool_result( + text: str | None = "ok", *, is_error: bool = False +) -> CallToolResult: + content: list = [] + if text is not None: + content.append(TextContent(type="text", text=text)) + return CallToolResult(content=content, isError=is_error) + + # ── Tests ──────────────────────────────────────────────────────────────────── @@ -101,7 +112,7 @@ def test_emit_text_only_iteration() -> None: stop_reason=LlmStopReason.END_TURN, ) - _run(emitter._emit(msg)) + _run(emitter._emit_assistant_chunk(msg)) assert len(ctx.session.calls) == 1 call = ctx.session.calls[0] @@ -120,8 +131,8 @@ def test_emit_text_only_iteration() -> None: assert data["tool_calls"] == [] -def test_emit_text_with_tool_call_iteration() -> None: - """An assistant turn that emits text and then calls a tool ships both.""" +def test_emit_text_with_tool_call_iteration_carries_args_preview() -> None: + """An iteration that requests a tool ships name, server prefix, and args preview.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") @@ -129,33 +140,38 @@ def test_emit_text_with_tool_call_iteration() -> None: role="assistant", 
content=[TextContent(type="text", text="Logging this…")], tool_calls={ - "toolu_1": _tool_call("toolu_1", "time__get_current_time"), + "toolu_1": _tool_call( + "argos-search_web", + {"query": "ducati v2 vs v4 reliability"}, + ), }, stop_reason=LlmStopReason.TOOL_USE, ) - _run(emitter._emit(msg)) + _run(emitter._emit_assistant_chunk(msg)) assert len(ctx.session.calls) == 1 data = ctx.session.calls[0]["data"] assert data["stop_reason"] == "toolUse" assert data["content"] == [{"type": "text", "text": "Logging this…"}] - assert data["tool_calls"] == [{"id": "toolu_1", "name": "time__get_current_time"}] + assert data["tool_calls"] == [ + { + "id": "toolu_1", + "name": "argos-search_web", + "server": "argos", + "arguments_preview": "ducati v2 vs v4 reliability", + } + ] def test_emit_skips_completely_empty_iteration() -> None: - """A turn with no text blocks and no tool calls emits nothing. - - Tool-call lifecycle is already covered by notifications/progress. An - empty assistant_chunk would just be noise on the wire and a no-op - update on the live bubble. - """ + """A turn with no text blocks and no tool calls emits nothing.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") msg = PromptMessageExtended(role="assistant", content=[]) - _run(emitter._emit(msg)) + _run(emitter._emit_assistant_chunk(msg)) assert ctx.session.calls == [] # Iteration counter still bumps so subsequent chunks aren't mis-numbered. 
@@ -173,7 +189,7 @@ def test_emit_image_block_passes_through_with_mime_type_renamed() -> None: stop_reason=LlmStopReason.END_TURN, ) - _run(emitter._emit(msg)) + _run(emitter._emit_assistant_chunk(msg)) data = ctx.session.calls[0]["data"] assert data["content"] == [ @@ -188,21 +204,21 @@ def test_emit_iterations_are_numbered_in_order() -> None: emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") async def drive() -> None: - await emitter._emit( + await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="first")], stop_reason=LlmStopReason.TOOL_USE, ) ) - await emitter._emit( + await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="second")], stop_reason=LlmStopReason.TOOL_USE, ) ) - await emitter._emit( + await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="third — done")], @@ -232,9 +248,96 @@ def test_emit_swallows_session_failure() -> None: stop_reason=LlmStopReason.END_TURN, ) - # Must not raise. - _run(emitter._emit(msg)) - # And no successful calls were recorded (fail_with raised before append). + _run(emitter._emit_assistant_chunk(msg)) + assert ctx.session.calls == [] + + +def test_emit_tool_results_pairs_call_id_with_iteration() -> None: + """``after_tool_call`` ships a results payload keyed to the iteration that called.""" + ctx = _FakeContext() + emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") + + # Iteration 1: text + tool request. + iter1 = PromptMessageExtended( + role="assistant", + content=[TextContent(type="text", text="Searching…")], + tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})}, + stop_reason=LlmStopReason.TOOL_USE, + ) + + async def drive() -> None: + # Simulate the runner: after_llm_call → before_tool_call → after_tool_call. 
+ await emitter._emit_assistant_chunk(iter1) + await emitter.as_before_tool_call_hook()(None, iter1) + # The "user" message produced by the tool runner carries tool_results. + tool_msg = PromptMessageExtended( + role="user", + content=[], + tool_results={"toolu_1": _tool_result("12 results found")}, + ) + await emitter.as_after_tool_call_hook()(None, tool_msg) + + _run(drive()) + + assert len(ctx.session.calls) == 2 + chunk = ctx.session.calls[0]["data"] + results = ctx.session.calls[1]["data"] + assert chunk["kind"] == KIND + assert chunk["iteration"] == 1 + assert results["kind"] == KIND_RESULTS + assert results["iteration"] == 1 + assert results["agent"] == "alan" + assert results["conversation_id"] == "conv-1" + assert len(results["results"]) == 1 + entry = results["results"][0] + assert entry["id"] == "toolu_1" + assert entry["ok"] is True + assert entry["result_preview"] == "12 results found" + assert isinstance(entry["duration_ms"], int) + assert entry["duration_ms"] >= 0 + + +def test_emit_tool_results_marks_error() -> None: + """A failing tool call surfaces ``ok: False`` in the results entry.""" + ctx = _FakeContext() + emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c") + + iter1 = PromptMessageExtended( + role="assistant", + content=[], + tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})}, + stop_reason=LlmStopReason.TOOL_USE, + ) + + async def drive() -> None: + await emitter._emit_assistant_chunk(iter1) + await emitter.as_before_tool_call_hook()(None, iter1) + tool_msg = PromptMessageExtended( + role="user", + content=[], + tool_results={ + "toolu_1": _tool_result("connection refused", is_error=True) + }, + ) + await emitter.as_after_tool_call_hook()(None, tool_msg) + + _run(drive()) + + # iter1 is a pure-tool turn (no text + has tool_calls) → still emits an + # assistant_chunk because tool_calls is non-empty. 
+ assert len(ctx.session.calls) == 2 + results = ctx.session.calls[1]["data"] + assert results["results"][0]["ok"] is False + assert results["results"][0]["result_preview"] == "connection refused" + + +def test_emit_tool_results_empty_when_no_results() -> None: + """An ``after_tool_call`` carrying no tool_results emits nothing.""" + ctx = _FakeContext() + emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c") + + msg = PromptMessageExtended(role="user", content=[], tool_results=None) + _run(emitter.as_after_tool_call_hook()(None, msg)) assert ctx.session.calls == [] @@ -262,23 +365,22 @@ def test_install_for_request_merges_with_existing_after_llm_call() -> None: _run(agent.tool_runner_hooks.after_llm_call(None, msg)) - # Base ran, and the assistant-stream emitter shipped the chunk. assert seen == ["base"] assert len(ctx.session.calls) == 1 assert ctx.session.calls[0]["data"]["content"] == [ {"type": "text", "text": "hi"} ] - # Other hook slots stay untouched. - assert agent.tool_runner_hooks.before_llm_call is None - assert agent.tool_runner_hooks.before_tool_call is None - assert agent.tool_runner_hooks.after_tool_call is None - assert agent.tool_runner_hooks.after_turn_complete is None + # All four hook slots used by the emitter are now bound; the restore + # call puts the original hooks back exactly. + assert agent.tool_runner_hooks.before_tool_call is not None + assert agent.tool_runner_hooks.after_tool_call is not None - # Restore puts the original hooks back exactly. 
restore() assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks.after_llm_call is base_after + assert agent.tool_runner_hooks.before_tool_call is None + assert agent.tool_runner_hooks.after_tool_call is None def test_install_for_request_with_no_existing_hooks() -> None: @@ -293,6 +395,8 @@ def test_install_for_request_with_no_existing_hooks() -> None: assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks.after_llm_call is not None + assert agent.tool_runner_hooks.before_tool_call is not None + assert agent.tool_runner_hooks.after_tool_call is not None restore() assert agent.tool_runner_hooks is None diff --git a/uv.lock b/uv.lock index 75fae21..f34bee6 100644 --- a/uv.lock +++ b/uv.lock @@ -1562,7 +1562,7 @@ wheels = [ [[package]] name = "pallas-mcp" -version = "0.2.0" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "fast-agent-mcp" },