feat(pallas): stream mid-turn assistant chunks over MCP

Add `AssistantChunkEmitter` that hooks into fast-agent's
`ToolRunnerHooks.after_llm_call` to emit one
`notifications/message` per LLM iteration, carrying structured
content blocks as JSON via the existing StreamableHTTP transport.

This exposes intermediate assistant messages (substantive replies
produced before tool calls) that would otherwise be hidden inside
fast-agent's message_history and never cross the MCP boundary,
letting Daedalus update its live bubble during multi-iteration
tool loops instead of only seeing the final wrap-up text.
This commit is contained in:
2026-05-28 06:09:03 -04:00
parent 440f7fb60c
commit 8a5046fef0

246
pallas/assistant_stream.py Normal file
View File

@@ -0,0 +1,246 @@
"""
Mid-turn assistant chunk streaming over MCP.
The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of
turn. When fast-agent runs a multi-iteration tool loop, every intermediate
assistant message — including substantive natural-language replies the LLM
emits before deciding to call a tool — stays inside fast-agent's
``message_history`` and never crosses the MCP boundary. The caller (Daedalus)
sees only the final ``last_text()``, which is often a thin wrap-up sentence
while the substantive answer was produced two iterations earlier.
This module fixes that by emitting one MCP ``notifications/message``
(``LoggingMessageNotification``) per assistant turn, carrying the structured
content blocks as opaque JSON in the ``data`` field. The transport is
already plumbed through StreamableHTTP — the SDK delivers it to the client's
``logging_callback`` — so no SDK changes are needed on either side.
The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``,
which runs once per LLM iteration in the tool loop. Pallas merges this hook
with whatever hooks the agent (or fast-agent itself) already has via the
existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py``
``_install_assistant_stream_hook`` for the merge-and-restore wiring.
"""
from __future__ import annotations
import logging
from typing import Any, TYPE_CHECKING
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from mcp.types import ImageContent, TextContent
if TYPE_CHECKING:
from fastmcp import Context as MCPContext
logger = logging.getLogger(__name__)
# Logger string the client uses to filter our notifications from any other
# log messages the server might emit. Stays narrow on both sides.
LOGGER_NAME = "pallas.assistant_stream"
# Discriminator inside ``data`` so this transport stays multiplexable later.
KIND = "assistant_chunk"
# Bumped if the on-wire payload shape changes incompatibly. Daedalus reads
# this and ignores chunks whose schema_version it doesn't recognise.
SCHEMA_VERSION = 1
class AssistantChunkEmitter:
"""Per-request emitter that ships each assistant turn over MCP.
One instance per ``send_message`` MCP call. Tracks an iteration counter
so consumers can order chunks within a single turn (and ignore stale
ones from an aborted retry).
"""
def __init__(
self,
ctx: "MCPContext",
*,
agent_name: str,
conversation_id: str | None = None,
) -> None:
self._ctx = ctx
self._agent_name = agent_name
self._conversation_id = conversation_id
self._iteration = 0
def as_after_llm_call_hook(self):
"""Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``."""
async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
await self._emit(message)
return _hook
async def _emit(self, message: PromptMessageExtended) -> None:
"""Serialize one assistant turn and ship it as a log notification.
Failures here must never break the agent turn — wrap everything and
log a warning. The user-facing consequence of a dropped chunk is
"the live bubble didn't update once"; the canonical final message
still arrives via the tools/call response.
"""
self._iteration += 1
try:
payload = self._build_payload(message)
except Exception:
logger.warning(
"assistant_stream payload build failed",
exc_info=True,
extra={
"agent": self._agent_name,
"conversation_id": self._conversation_id,
"iteration": self._iteration,
},
)
return
# Skip empty turns (e.g. a pure tool-only iteration with no text).
# The tool-call lifecycle is already surfaced via
# notifications/progress, so an empty assistant_chunk would just be
# noise on the wire and the live bubble.
if not payload["content"] and not payload["tool_calls"]:
return
try:
session = self._ctx.session
related_request_id = getattr(self._ctx, "request_id", None)
await session.send_log_message(
level="info",
data=payload,
logger=LOGGER_NAME,
related_request_id=related_request_id,
)
except Exception:
logger.warning(
"assistant_stream send_log_message failed",
exc_info=True,
extra={
"agent": self._agent_name,
"conversation_id": self._conversation_id,
"iteration": self._iteration,
},
)
def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]:
content_blocks: list[dict[str, Any]] = []
for block in message.content or []:
if isinstance(block, TextContent) and block.text:
content_blocks.append({"type": "text", "text": block.text})
elif isinstance(block, ImageContent):
# Images on assistant turns are unusual but legal. We send
# the raw base64 — same shape the final tools/call result
# already uses so the client can render it identically.
content_blocks.append(
{
"type": "image",
"data": block.data,
"mime_type": block.mimeType,
}
)
# Other block types (resources, embedded resources, etc.) are
# intentionally omitted from the live stream — they're rare on
# assistant turns and the canonical final message carries them.
tool_calls: list[dict[str, Any]] = []
for call_id, call in (message.tool_calls or {}).items():
params = getattr(call, "params", None)
tool_calls.append(
{
"id": call_id,
"name": getattr(params, "name", None),
}
)
stop_reason = message.stop_reason
if stop_reason is not None:
# LlmStopReason is a str-Enum; .value is the wire form fast-agent
# already uses ("toolUse", "endTurn", ...).
stop_reason_str = getattr(stop_reason, "value", str(stop_reason))
else:
stop_reason_str = None
return {
"kind": KIND,
"schema_version": SCHEMA_VERSION,
"agent": self._agent_name,
"conversation_id": self._conversation_id,
"iteration": self._iteration,
"stop_reason": stop_reason_str,
"content": content_blocks,
"tool_calls": tool_calls,
}
def install_for_request(
agent: Any,
*,
ctx: "MCPContext",
agent_name: str,
conversation_id: str | None,
):
"""Install the assistant-stream hook on a request-scoped agent instance.
Returns a no-arg callable that restores the agent's previous
``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks
don't leak across requests if the underlying instance is ever reused.
Pallas runs ``instance_scope="request"`` so in normal operation the
instance is disposed immediately after; the restore call is defensive
against config changes or future shared-instance modes.
"""
emitter = AssistantChunkEmitter(
ctx, agent_name=agent_name, conversation_id=conversation_id
)
extra = ToolRunnerHooks(after_llm_call=emitter.as_after_llm_call_hook())
previous = getattr(agent, "tool_runner_hooks", None)
merged = _merge_after_llm_call(previous, extra)
agent.tool_runner_hooks = merged
def restore() -> None:
agent.tool_runner_hooks = previous
return restore
def _merge_after_llm_call(
base: ToolRunnerHooks | None, extra: ToolRunnerHooks
) -> ToolRunnerHooks:
"""Compose ``extra.after_llm_call`` after ``base.after_llm_call``.
Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
so we don't depend on its signature. Only the ``after_llm_call`` slot
actually gets composed here — the others are passed through unchanged
from ``base`` since ``extra`` only ever sets ``after_llm_call``.
"""
if base is None:
return extra
base_hook = base.after_llm_call
extra_hook = extra.after_llm_call
if base_hook is None:
merged_after = extra_hook
elif extra_hook is None: # pragma: no cover - extra always sets it
merged_after = base_hook
else:
async def merged(runner: Any, message: PromptMessageExtended) -> None:
await base_hook(runner, message)
await extra_hook(runner, message)
merged_after = merged
return ToolRunnerHooks(
before_llm_call=base.before_llm_call,
after_llm_call=merged_after,
before_tool_call=base.before_tool_call,
after_tool_call=base.after_tool_call,
after_turn_complete=base.after_turn_complete,
)