Files
pallas/pallas/mcp_subagent_relay.py

214 lines
8.5 KiB
Python

"""Sub-agent activity relay over the MCP logging channel.
When Alan (or any peer-orchestrator agent) calls a sub-agent like ``research``,
that sub-agent is itself a Pallas server emitting its own ``assistant_chunk``
notifications. Those notifications are delivered to fast-agent's
``MCPAgentClientSession`` for the peer — and stop there: fast-agent's default
session has no ``logging_callback``, so the only thing that ever crosses
back up to Daedalus is the final ``CallToolResult`` text. When a sub-agent
does substantive multi-turn reasoning, all that reasoning disappears into
server logs.
This module installs a ``logging_callback`` on every peer session Pallas
opens and re-emits the relevant notifications upstream, decorated so the
consumer can render them as **nested iterations** under the parent
tool-call entry.
Two pieces of provenance are added:
* ``via_agent`` — call-stack list (e.g. ``["research"]`` for chunks the
``research`` peer emitted; ``["research", "search"]`` if research itself
relayed deeper). The parent agent is implicit (it's whose message
these attach to) so the list contains only relayed-through peers.
* ``parent_call_id`` — the LLM ``tool_use_id`` (e.g. ``toolu_42``) that
the parent's LLM emitted to invoke this peer. The Daedalus collector
uses it to nest the relayed iterations under the matching tool_call
entry on the parent's iteration card.
The ``schema_version`` is bumped to **3** on relayed chunks. Top-level
chunks (produced by the parent agent's own assistant_stream hooks, which
carry no ``parent_call_id``) stay at v2: ``via_agent`` and
``parent_call_id`` are optional fields that only relayed chunks carry, so
the top-level wire shape is unchanged.
Depth cap (``max_relay_depth``, default 3) drops chunks that would push
``len(via_agent)`` past the cap. This guards against runaway nested-pyramid
JSON in ``Message.iterations`` if a future agent wired up indefinite peer
chains; today the realistic depth is 2 (Alan → research → search/knowledge).
"""
from __future__ import annotations
import logging
from collections import deque
from typing import Any, Awaitable, Callable, Deque, Tuple
from mcp.types import LoggingMessageNotificationParams
logger = logging.getLogger(__name__)

# Pallas sends every assistant-stream notification under this logger name;
# notifications logged under any other name are not relayed.
_ASSISTANT_STREAM_LOGGER = "pallas.assistant_stream"

# ``kind`` discriminators of the payloads the relay forwards.
_KIND = "assistant_chunk"
_KIND_RESULTS = "assistant_chunk_results"
_KINDS = frozenset((_KIND, _KIND_RESULTS))

# Version stamped onto relayed payloads. Only v2+ payloads (those carrying
# enriched tool_calls plus companion results) are relayed at all; stamping
# v3 tells the consumer the payload may carry ``via_agent`` and
# ``parent_call_id``.
RELAY_SCHEMA_VERSION = 3

# Default cap on ``len(via_agent)``; ``max_relay_depth`` overrides it at
# install time.
DEFAULT_MAX_RELAY_DEPTH = 3

# Attribute name under which each peer session stores its in-flight stack
# of ``(parent_call_id, peer_server_name)`` tuples. An entry is pushed as
# ``MCPAggregator.call_tool`` enters and popped as it exits. A deque is used
# so that parallel tool calls to the *same* server (fast-agent can run tools
# in parallel within one iteration — ``tool_agent.py`` ``run_tools``,
# ``should_parallel`` branch) are attributed to the most recent push. That
# attribution is best-effort; concurrent calls usually target different
# servers.
_INFLIGHT_ATTR = "_pallas_relay_inflight"
def push_inflight(session: Any, parent_call_id: str, peer_name: str) -> None:
    """Record an in-flight peer call so the relay can tag its chunks.

    ``session`` is the ``MCPAgentClientSession`` that fast-agent's
    aggregator opened to the peer. The in-flight deque is attached to the
    session lazily, on the first push, so sessions that never carry a peer
    call are left untouched.

    Parameters
    ----------
    session:
        Peer session object; it must accept arbitrary attribute assignment
        because the stack is stored on it via ``setattr``.
    parent_call_id:
        The parent LLM's ``tool_use_id`` for this peer invocation.
    peer_name:
        Server name of the peer being called; the relay appends it to
        ``via_agent`` on chunks attributed to this call.
    """
    # ``getattr`` returns None until the first push, hence the Optional type.
    stack: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if stack is None:
        stack = deque()
        setattr(session, _INFLIGHT_ATTR, stack)
    stack.append((parent_call_id, peer_name))
def pop_inflight(session: Any, parent_call_id: str) -> None:
    """Drop the newest in-flight entry whose id equals ``parent_call_id``.

    Searching by id, rather than blindly popping the top of the stack,
    keeps the stack consistent when fast-agent's parallel tool calls finish
    out of order. A session without a stack, an empty stack, or an id that
    is not present is silently ignored.
    """
    inflight: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    if not inflight:
        return
    # The newest entries live at the right end, so scan right-to-left and
    # stop at the first hit.
    for idx in reversed(range(len(inflight))):
        if inflight[idx][0] == parent_call_id:
            del inflight[idx]
            break
def current_inflight(session: Any) -> Tuple[str, str] | None:
    """Return the newest ``(parent_call_id, peer_name)`` pair, if any.

    ``None`` means no peer call is in flight on this session. A chunk
    arriving in that window came from server-side bookkeeping (e.g. the
    peer's own startup chatter), which shouldn't be attributed to any
    specific parent tool_call.
    """
    inflight: Deque[Tuple[str, str]] | None = getattr(session, _INFLIGHT_ATTR, None)
    # A missing stack and an empty stack both mean "nothing in flight".
    return inflight[-1] if inflight else None
def _coerce_via_agent(raw: Any) -> list[str] | None:
    """Normalize an incoming ``via_agent`` value into a fresh list of names.

    Returns ``[]`` when the field is missing or falsy, a new list when it is
    a list/tuple made only of strings, and ``None`` when it is malformed —
    e.g. a bare string (which ``list()`` would explode into characters) or a
    non-iterable (which ``list()`` would reject with ``TypeError``).
    """
    if not raw:
        return []
    if isinstance(raw, (list, tuple)) and all(isinstance(name, str) for name in raw):
        return list(raw)
    return None


def make_logging_callback(
    *,
    upstream_send: Callable[[dict[str, Any]], Awaitable[None]],
    session: Any,
    max_depth: int = DEFAULT_MAX_RELAY_DEPTH,
) -> Callable[[LoggingMessageNotificationParams], Awaitable[None]]:
    """Build a ``logging_callback`` that relays peer assistant-stream chunks.

    Parameters
    ----------
    upstream_send:
        Async callable that ships a payload upward via the parent's MCP
        session (typically wraps ``ctx.session.send_log_message`` from
        the parent's per-request ``MCPContext``). Exceptions it raises are
        caught and logged as warnings; they never propagate out of the
        callback.
    session:
        The peer ``MCPAgentClientSession`` this callback is attached to.
        Used to look up the in-flight ``(parent_call_id, peer_name)`` for
        attribution.
    max_depth:
        Maximum length of ``via_agent`` on a relayed chunk. A chunk whose
        incoming ``via_agent`` already has ``max_depth`` entries would
        exceed the cap once this peer is appended, so it is dropped with a
        debug log.

    Returns
    -------
    An async callback accepting ``LoggingMessageNotificationParams``.
    Notifications that cannot be relayed (other logger, unknown or
    non-string ``kind``, pre-v2 schema, no in-flight call, malformed
    ``via_agent``, depth cap reached) are dropped instead of raising.
    """

    async def _callback(params: LoggingMessageNotificationParams) -> None:
        # 1) Filter — only Pallas assistant-stream notifications get relayed;
        #    every other log message routes through the existing trace
        #    handlers and isn't our concern.
        if params.logger != _ASSISTANT_STREAM_LOGGER:
            return
        data = params.data
        if not isinstance(data, dict):
            return
        kind = data.get("kind")
        # Type-check before the set lookup: an unhashable ``kind`` (e.g. a
        # list) would otherwise raise TypeError from ``in _KINDS``.
        if not isinstance(kind, str) or kind not in _KINDS:
            return
        schema = data.get("schema_version")
        if not isinstance(schema, int) or schema < 2:
            # v1 chunks don't carry the enriched tool_calls / results
            # shape the consumer needs to render nested iterations.
            return

        # 2) Look up the in-flight tool call this chunk belongs to. No
        #    in-flight call → drop (the chunk arrived during peer
        #    bookkeeping outside any tool_call, which we can't attribute).
        inflight = current_inflight(session)
        if inflight is None:
            return
        parent_call_id, peer_name = inflight

        # 3) Validate ``via_agent`` and enforce the depth cap. It may
        #    already exist if the peer was itself relaying chunks from a
        #    deeper sub-agent.
        raw_via = data.get("via_agent")
        existing_via = _coerce_via_agent(raw_via)
        if existing_via is None:
            logger.debug(
                "subagent_relay_malformed_via_agent",
                extra={
                    "via_agent_type": type(raw_via).__name__,
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                },
            )
            return
        if len(existing_via) >= max_depth:
            logger.debug(
                "subagent_relay_depth_capped",
                extra={
                    "depth": len(existing_via),
                    "max_depth": max_depth,
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                },
            )
            return

        # 4) Decorate the payload and re-emit upstream. Work on a shallow
        #    copy so we don't change a dict the SDK might still reference.
        relayed = dict(data)
        relayed["via_agent"] = [*existing_via, peer_name]
        relayed["parent_call_id"] = parent_call_id
        relayed["schema_version"] = RELAY_SCHEMA_VERSION
        try:
            await upstream_send(relayed)
        except Exception:
            # Best-effort relay: a failed upstream send must not break the
            # peer session's notification handling.
            logger.warning(
                "subagent_relay_upstream_send_failed",
                exc_info=True,
                extra={
                    "peer": peer_name,
                    "parent_call_id": parent_call_id,
                    "kind": relayed.get("kind"),
                },
            )

    return _callback
# Public surface of the relay; every underscore-prefixed name is internal.
__all__ = [
    "RELAY_SCHEMA_VERSION", "DEFAULT_MAX_RELAY_DEPTH",
    "current_inflight", "make_logging_callback",
    "pop_inflight", "push_inflight",
]