diff --git a/pallas/assistant_stream.py b/pallas/assistant_stream.py new file mode 100644 index 0000000..2361c36 --- /dev/null +++ b/pallas/assistant_stream.py @@ -0,0 +1,246 @@ +""" +Mid-turn assistant chunk streaming over MCP. + +The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of +turn. When fast-agent runs a multi-iteration tool loop, every intermediate +assistant message — including substantive natural-language replies the LLM +emits before deciding to call a tool — stays inside fast-agent's +``message_history`` and never crosses the MCP boundary. The caller (Daedalus) +sees only the final ``last_text()``, which is often a thin wrap-up sentence +while the substantive answer was produced two iterations earlier. + +This module fixes that by emitting one MCP ``notifications/message`` +(``LoggingMessageNotification``) per assistant turn, carrying the structured +content blocks as opaque JSON in the ``data`` field. The transport is +already plumbed through StreamableHTTP — the SDK delivers it to the client's +``logging_callback`` — so no SDK changes are needed on either side. + +The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``, +which runs once per LLM iteration in the tool loop. Pallas merges this hook +with whatever hooks the agent (or fast-agent itself) already has via the +existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py`` +``_install_assistant_stream_hook`` for the merge-and-restore wiring. +""" + +from __future__ import annotations + +import logging +from typing import Any, TYPE_CHECKING + +from fast_agent.agents.tool_runner import ToolRunnerHooks +from fast_agent.types import PromptMessageExtended +from mcp.types import ImageContent, TextContent + +if TYPE_CHECKING: + from fastmcp import Context as MCPContext + +logger = logging.getLogger(__name__) + +# Logger string the client uses to filter our notifications from any other +# log messages the server might emit. Stays narrow on both sides. +LOGGER_NAME = "pallas.assistant_stream" + +# Discriminator inside ``data`` so this transport stays multiplexable later. +KIND = "assistant_chunk" + +# Bumped if the on-wire payload shape changes incompatibly. Daedalus reads +# this and ignores chunks whose schema_version it doesn't recognise. +SCHEMA_VERSION = 1 + + +class AssistantChunkEmitter: + """Per-request emitter that ships each assistant turn over MCP. + + One instance per ``send_message`` MCP call. Tracks an iteration counter + so consumers can order chunks within a single turn (and ignore stale + ones from an aborted retry). + """ + + def __init__( + self, + ctx: "MCPContext", + *, + agent_name: str, + conversation_id: str | None = None, + ) -> None: + self._ctx = ctx + self._agent_name = agent_name + self._conversation_id = conversation_id + self._iteration = 0 + + def as_after_llm_call_hook(self): + """Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``.""" + + async def _hook(_runner: Any, message: PromptMessageExtended) -> None: + await self._emit(message) + + return _hook + + async def _emit(self, message: PromptMessageExtended) -> None: + """Serialize one assistant turn and ship it as a log notification. + + Failures here must never break the agent turn — wrap everything and + log a warning. The user-facing consequence of a dropped chunk is + "the live bubble didn't update once"; the canonical final message + still arrives via the tools/call response. + """ + self._iteration += 1 + try: + payload = self._build_payload(message) + except Exception: + logger.warning( + "assistant_stream payload build failed", + exc_info=True, + extra={ + "agent": self._agent_name, + "conversation_id": self._conversation_id, + "iteration": self._iteration, + }, + ) + return + + # Skip empty turns (e.g. a pure tool-only iteration with no text). + # The tool-call lifecycle is already surfaced via + # notifications/progress, so an empty assistant_chunk would just be + # noise on the wire and the live bubble. + if not payload["content"] and not payload["tool_calls"]: + return + + try: + session = self._ctx.session + related_request_id = getattr(self._ctx, "request_id", None) + await session.send_log_message( + level="info", + data=payload, + logger=LOGGER_NAME, + related_request_id=related_request_id, + ) + except Exception: + logger.warning( + "assistant_stream send_log_message failed", + exc_info=True, + extra={ + "agent": self._agent_name, + "conversation_id": self._conversation_id, + "iteration": self._iteration, + }, + ) + + def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]: + content_blocks: list[dict[str, Any]] = [] + for block in message.content or []: + if isinstance(block, TextContent) and block.text: + content_blocks.append({"type": "text", "text": block.text}) + elif isinstance(block, ImageContent): + # Images on assistant turns are unusual but legal. We send + # the raw base64 — same shape the final tools/call result + # already uses so the client can render it identically. + content_blocks.append( + { + "type": "image", + "data": block.data, + "mime_type": block.mimeType, + } + ) + # Other block types (resources, embedded resources, etc.) are + # intentionally omitted from the live stream — they're rare on + # assistant turns and the canonical final message carries them. + + tool_calls: list[dict[str, Any]] = [] + for call_id, call in (message.tool_calls or {}).items(): + params = getattr(call, "params", None) + tool_calls.append( + { + "id": call_id, + "name": getattr(params, "name", None), + } + ) + + stop_reason = message.stop_reason + if stop_reason is not None: + # LlmStopReason is a str-Enum; .value is the wire form fast-agent + # already uses ("toolUse", "endTurn", ...). + stop_reason_str = getattr(stop_reason, "value", str(stop_reason)) + else: + stop_reason_str = None + + return { + "kind": KIND, + "schema_version": SCHEMA_VERSION, + "agent": self._agent_name, + "conversation_id": self._conversation_id, + "iteration": self._iteration, + "stop_reason": stop_reason_str, + "content": content_blocks, + "tool_calls": tool_calls, + } + + +def install_for_request( + agent: Any, + *, + ctx: "MCPContext", + agent_name: str, + conversation_id: str | None, +): + """Install the assistant-stream hook on a request-scoped agent instance. + + Returns a no-arg callable that restores the agent's previous + ``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks + don't leak across requests if the underlying instance is ever reused. + + Pallas runs ``instance_scope="request"`` so in normal operation the + instance is disposed immediately after; the restore call is defensive + against config changes or future shared-instance modes. + """ + emitter = AssistantChunkEmitter( + ctx, agent_name=agent_name, conversation_id=conversation_id + ) + + extra = ToolRunnerHooks(after_llm_call=emitter.as_after_llm_call_hook()) + + previous = getattr(agent, "tool_runner_hooks", None) + merged = _merge_after_llm_call(previous, extra) + agent.tool_runner_hooks = merged + + def restore() -> None: + agent.tool_runner_hooks = previous + + return restore + + +def _merge_after_llm_call( + base: ToolRunnerHooks | None, extra: ToolRunnerHooks +) -> ToolRunnerHooks: + """Compose ``extra.after_llm_call`` after ``base.after_llm_call``. + + Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private) + so we don't depend on its signature. Only the ``after_llm_call`` slot + actually gets composed here — the others are passed through unchanged + from ``base`` since ``extra`` only ever sets ``after_llm_call``. + """ + if base is None: + return extra + + base_hook = base.after_llm_call + extra_hook = extra.after_llm_call + + if base_hook is None: + merged_after = extra_hook + elif extra_hook is None: # pragma: no cover - extra always sets it + merged_after = base_hook + else: + + async def merged(runner: Any, message: PromptMessageExtended) -> None: + await base_hook(runner, message) + await extra_hook(runner, message) + + merged_after = merged + + return ToolRunnerHooks( + before_llm_call=base.before_llm_call, + after_llm_call=merged_after, + before_tool_call=base.before_tool_call, + after_tool_call=base.after_tool_call, + after_turn_complete=base.after_turn_complete, + )