feat(pallas): stream mid-turn assistant chunks over MCP
Add `AssistantChunkEmitter` that hooks into fast-agent's `ToolRunnerHooks.after_llm_call` to emit one `notifications/message` per LLM iteration, carrying structured content blocks as JSON via the existing StreamableHTTP transport. This exposes intermediate assistant messages (substantive replies produced before tool calls) that would otherwise be hidden inside fast-agent's message_history and never cross the MCP boundary, letting Daedalus update its live bubble during multi-iteration tool loops instead of only seeing the final wrap-up text.
This commit is contained in:
246
pallas/assistant_stream.py
Normal file
246
pallas/assistant_stream.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
"""
|
||||||
|
Mid-turn assistant chunk streaming over MCP.
|
||||||
|
|
||||||
|
The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of
|
||||||
|
turn. When fast-agent runs a multi-iteration tool loop, every intermediate
|
||||||
|
assistant message — including substantive natural-language replies the LLM
|
||||||
|
emits before deciding to call a tool — stays inside fast-agent's
|
||||||
|
``message_history`` and never crosses the MCP boundary. The caller (Daedalus)
|
||||||
|
sees only the final ``last_text()``, which is often a thin wrap-up sentence
|
||||||
|
while the substantive answer was produced two iterations earlier.
|
||||||
|
|
||||||
|
This module fixes that by emitting one MCP ``notifications/message``
|
||||||
|
(``LoggingMessageNotification``) per assistant turn, carrying the structured
|
||||||
|
content blocks as opaque JSON in the ``data`` field. The transport is
|
||||||
|
already plumbed through StreamableHTTP — the SDK delivers it to the client's
|
||||||
|
``logging_callback`` — so no SDK changes are needed on either side.
|
||||||
|
|
||||||
|
The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``,
|
||||||
|
which runs once per LLM iteration in the tool loop. Pallas merges this hook
|
||||||
|
with whatever hooks the agent (or fast-agent itself) already has via the
|
||||||
|
existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py``
|
||||||
|
``_install_assistant_stream_hook`` for the merge-and-restore wiring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from fast_agent.agents.tool_runner import ToolRunnerHooks
|
||||||
|
from fast_agent.types import PromptMessageExtended
|
||||||
|
from mcp.types import ImageContent, TextContent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastmcp import Context as MCPContext
|
||||||
|
|
||||||
|
# Module-local logger used for server-side diagnostics about the stream
# itself (payload build or send failures).
logger: logging.Logger = logging.getLogger(__name__)

# Value passed as the ``logger`` field of each outgoing log notification.
# The client matches on this string to pick our notifications out of any
# other log messages the server might emit, so keep it narrow on both sides.
LOGGER_NAME: str = "pallas.assistant_stream"

# Discriminator stored under ``kind`` inside the notification ``data`` so the
# same transport can be multiplexed for other payload kinds later.
KIND: str = "assistant_chunk"

# On-wire payload version. Bump it whenever the payload shape changes
# incompatibly; Daedalus reads it and ignores chunks whose schema_version it
# doesn't recognise.
SCHEMA_VERSION: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantChunkEmitter:
    """Request-scoped publisher of intermediate assistant turns over MCP.

    A fresh instance is meant to be created for each ``send_message`` MCP
    call. Every hook invocation bumps an iteration counter that is stamped
    on the outgoing payload, so a consumer can order the chunks of a single
    turn (and ignore stale ones from an aborted retry).
    """

    def __init__(
        self,
        ctx: "MCPContext",
        *,
        agent_name: str,
        conversation_id: str | None = None,
    ) -> None:
        """Bind the emitter to one MCP request context.

        Args:
            ctx: MCP context whose ``session`` is used to send notifications.
            agent_name: Agent label copied into every payload and log record.
            conversation_id: Optional conversation label copied alongside it.
        """
        self._iteration = 0
        self._ctx = ctx
        self._agent_name = agent_name
        self._conversation_id = conversation_id

    def as_after_llm_call_hook(self):
        """Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``."""

        async def _on_llm_reply(_runner: Any, message: PromptMessageExtended) -> None:
            # The runner argument is accepted to satisfy the hook signature
            # but is not needed to serialize the message.
            await self._emit(message)

        return _on_llm_reply

    def _log_extra(self) -> dict[str, Any]:
        """Return the structured fields attached to this emitter's warnings."""
        return {
            "agent": self._agent_name,
            "conversation_id": self._conversation_id,
            "iteration": self._iteration,
        }

    async def _emit(self, message: PromptMessageExtended) -> None:
        """Serialize one assistant turn and send it as a log notification.

        Both stages (building the payload and sending it) are individually
        guarded: any ``Exception`` is logged as a warning and swallowed so a
        streaming failure never breaks the agent turn. A dropped chunk only
        means the live bubble misses one update; the canonical final message
        still arrives via the tools/call response.
        """
        # Counted before anything can fail, so the counter advances once per
        # hook invocation even when the chunk ends up not being sent.
        self._iteration += 1

        try:
            payload = self._build_payload(message)
        except Exception:
            logger.warning(
                "assistant_stream payload build failed",
                exc_info=True,
                extra=self._log_extra(),
            )
            return

        # Nothing to show: the turn produced neither a streamable content
        # block nor a tool call. Sending it would only add noise on the wire
        # and in the live bubble.
        if not (payload["content"] or payload["tool_calls"]):
            return

        try:
            mcp_session = self._ctx.session
            request_id = getattr(self._ctx, "request_id", None)
            await mcp_session.send_log_message(
                level="info",
                data=payload,
                logger=LOGGER_NAME,
                related_request_id=request_id,
            )
        except Exception:
            logger.warning(
                "assistant_stream send_log_message failed",
                exc_info=True,
                extra=self._log_extra(),
            )

    def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]:
        """Convert an assistant message into the on-wire chunk dictionary.

        Only non-empty text blocks and image blocks are copied into
        ``content``; every other block type is left out of the live stream
        on purpose, because such blocks are rare on assistant turns and the
        canonical final message carries them. Each tool call is reduced to
        its id and tool name.
        """
        blocks: list[dict[str, Any]] = []
        for item in message.content or []:
            if isinstance(item, TextContent) and item.text:
                blocks.append({"type": "text", "text": item.text})
                continue
            if isinstance(item, ImageContent):
                # Images on assistant turns are unusual but legal. The raw
                # base64 is forwarded in the same shape the final tools/call
                # result uses, so the client can render it identically.
                blocks.append(
                    {
                        "type": "image",
                        "data": item.data,
                        "mime_type": item.mimeType,
                    }
                )

        calls: list[dict[str, Any]] = [
            {
                "id": call_id,
                # Tolerate calls without ``params`` (or a ``name``): the tool
                # name then degrades to None rather than raising.
                "name": getattr(getattr(call, "params", None), "name", None),
            }
            for call_id, call in (message.tool_calls or {}).items()
        ]

        reason = message.stop_reason
        # LlmStopReason is a str-Enum; ``.value`` is the wire form fast-agent
        # already uses ("toolUse", "endTurn", ...). Anything without a
        # ``.value`` is stringified instead.
        reason_wire = None if reason is None else getattr(reason, "value", str(reason))

        return {
            "kind": KIND,
            "schema_version": SCHEMA_VERSION,
            "agent": self._agent_name,
            "conversation_id": self._conversation_id,
            "iteration": self._iteration,
            "stop_reason": reason_wire,
            "content": blocks,
            "tool_calls": calls,
        }
|
||||||
|
|
||||||
|
|
||||||
|
def install_for_request(
    agent: Any,
    *,
    ctx: "MCPContext",
    agent_name: str,
    conversation_id: str | None,
):
    """Attach the assistant-stream hook to a request-scoped agent instance.

    A new ``AssistantChunkEmitter`` is created for the request and its hook
    is composed with whatever ``tool_runner_hooks`` the agent currently has
    (treated as ``None`` when the attribute is missing).

    Args:
        agent: Agent object whose ``tool_runner_hooks`` attribute is replaced.
        ctx: MCP context handed to the emitter for sending notifications.
        agent_name: Agent label stamped on every emitted chunk.
        conversation_id: Conversation label stamped on every emitted chunk.

    Returns:
        A no-argument callable that puts the previously captured hooks back
        on the agent. Call it in a ``finally`` so per-request hooks don't
        leak across requests if the underlying instance is ever reused.

    Pallas runs ``instance_scope="request"``, so in normal operation the
    instance is disposed immediately afterwards; the restore call is
    defensive against config changes or future shared-instance modes.
    """
    stream_hooks = ToolRunnerHooks(
        after_llm_call=AssistantChunkEmitter(
            ctx,
            agent_name=agent_name,
            conversation_id=conversation_id,
        ).as_after_llm_call_hook()
    )

    # Capture the current hooks before overwriting them so they can be put
    # back verbatim later.
    hooks_before = getattr(agent, "tool_runner_hooks", None)
    agent.tool_runner_hooks = _merge_after_llm_call(hooks_before, stream_hooks)

    def _restore_hooks() -> None:
        agent.tool_runner_hooks = hooks_before

    return _restore_hooks
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_after_llm_call(
|
||||||
|
base: ToolRunnerHooks | None, extra: ToolRunnerHooks
|
||||||
|
) -> ToolRunnerHooks:
|
||||||
|
"""Compose ``extra.after_llm_call`` after ``base.after_llm_call``.
|
||||||
|
|
||||||
|
Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
|
||||||
|
so we don't depend on its signature. Only the ``after_llm_call`` slot
|
||||||
|
actually gets composed here — the others are passed through unchanged
|
||||||
|
from ``base`` since ``extra`` only ever sets ``after_llm_call``.
|
||||||
|
"""
|
||||||
|
if base is None:
|
||||||
|
return extra
|
||||||
|
|
||||||
|
base_hook = base.after_llm_call
|
||||||
|
extra_hook = extra.after_llm_call
|
||||||
|
|
||||||
|
if base_hook is None:
|
||||||
|
merged_after = extra_hook
|
||||||
|
elif extra_hook is None: # pragma: no cover - extra always sets it
|
||||||
|
merged_after = base_hook
|
||||||
|
else:
|
||||||
|
|
||||||
|
async def merged(runner: Any, message: PromptMessageExtended) -> None:
|
||||||
|
await base_hook(runner, message)
|
||||||
|
await extra_hook(runner, message)
|
||||||
|
|
||||||
|
merged_after = merged
|
||||||
|
|
||||||
|
return ToolRunnerHooks(
|
||||||
|
before_llm_call=base.before_llm_call,
|
||||||
|
after_llm_call=merged_after,
|
||||||
|
before_tool_call=base.before_tool_call,
|
||||||
|
after_tool_call=base.after_tool_call,
|
||||||
|
after_turn_complete=base.after_turn_complete,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user