Compare commits

...

2 Commits

Author SHA1 Message Date
63ced9ba2b feat: stream intermediate assistant turns over MCP
Install a per-request `after_llm_call` hook that emits each intermediate
assistant turn as an MCP `notifications/message`, so users see
substantive text from earlier loop iterations instead of only the final
`agent.send()` return value.

Add tests covering the hook's payload shape, error handling, and
lifecycle via `install_for_request`.
2026-05-28 06:09:41 -04:00
8a5046fef0 feat(pallas): stream mid-turn assistant chunks over MCP
Add `AssistantChunkEmitter` that hooks into fast-agent's
`ToolRunnerHooks.after_llm_call` to emit one
`notifications/message` per LLM iteration, carrying structured
content blocks as JSON via the existing StreamableHTTP transport.

This exposes intermediate assistant messages (substantive replies
produced before tool calls) that would otherwise be hidden inside
fast-agent's message_history and never cross the MCP boundary,
letting Daedalus update its live bubble during multi-iteration
tool loops instead of only seeing the final wrap-up text.
2026-05-28 06:09:03 -04:00
3 changed files with 567 additions and 0 deletions

246
pallas/assistant_stream.py Normal file
View File

@@ -0,0 +1,246 @@
"""
Mid-turn assistant chunk streaming over MCP.
The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of
turn. When fast-agent runs a multi-iteration tool loop, every intermediate
assistant message — including substantive natural-language replies the LLM
emits before deciding to call a tool — stays inside fast-agent's
``message_history`` and never crosses the MCP boundary. The caller (Daedalus)
sees only the final ``last_text()``, which is often a thin wrap-up sentence
while the substantive answer was produced two iterations earlier.
This module fixes that by emitting one MCP ``notifications/message``
(``LoggingMessageNotification``) per assistant turn, carrying the structured
content blocks as opaque JSON in the ``data`` field. The transport is
already plumbed through StreamableHTTP — the SDK delivers it to the client's
``logging_callback`` — so no SDK changes are needed on either side.
The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``,
which runs once per LLM iteration in the tool loop. Pallas merges this hook
with whatever hooks the agent (or fast-agent itself) already has via the
existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py``
``_install_assistant_stream_hook`` for the merge-and-restore wiring.
"""
from __future__ import annotations
import logging
from typing import Any, TYPE_CHECKING
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from mcp.types import ImageContent, TextContent
if TYPE_CHECKING:
from fastmcp import Context as MCPContext
# Server-side logger for this module. It is only used to record warnings when
# a chunk cannot be built or sent; it is unrelated to the MCP ``logger``
# string below.
logger = logging.getLogger(__name__)
# Value passed as the ``logger`` argument of ``send_log_message``. The client
# uses it to filter our notifications from any other log messages the server
# might emit. Stays narrow on both sides.
LOGGER_NAME = "pallas.assistant_stream"
# Emitted as the ``kind`` key of every payload: a discriminator inside
# ``data`` so this transport stays multiplexable later.
KIND = "assistant_chunk"
# Emitted as the ``schema_version`` key of every payload. Bumped if the
# on-wire payload shape changes incompatibly. Daedalus reads this and ignores
# chunks whose schema_version it doesn't recognise.
SCHEMA_VERSION = 1
class AssistantChunkEmitter:
    """Send each assistant turn of one request to the MCP client.

    Create one emitter per ``send_message`` MCP call. Every turn handed to
    the emitter bumps a 1-based iteration counter that is included in the
    payload, so consumers can order chunks within a single turn (and ignore
    stale ones from an aborted retry).
    """

    def __init__(
        self,
        ctx: "MCPContext",
        *,
        agent_name: str,
        conversation_id: str | None = None,
    ) -> None:
        """Remember the request context and the labels stamped on each chunk."""
        self._ctx = ctx
        self._agent_name = agent_name
        self._conversation_id = conversation_id
        # Incremented at the start of every ``_emit`` call; first chunk is 1.
        self._iteration = 0

    def as_after_llm_call_hook(self):
        """Build the coroutine function for ``ToolRunnerHooks.after_llm_call``.

        The returned hook ignores its runner argument and forwards the
        assistant message to ``_emit``.
        """

        async def _forward(_runner: Any, message: PromptMessageExtended) -> None:
            await self._emit(message)

        return _forward

    def _log_extra(self) -> dict[str, Any]:
        """Return the structured fields attached to this emitter's warnings."""
        return {
            "agent": self._agent_name,
            "conversation_id": self._conversation_id,
            "iteration": self._iteration,
        }

    async def _emit(self, message: PromptMessageExtended) -> None:
        """Turn one assistant message into a log notification and send it.

        This is best-effort: a failure while building or sending the payload
        is logged as a warning and swallowed so it can never break the agent
        turn. A dropped chunk only means the live bubble missed one update;
        the canonical final message still arrives via the tools/call
        response.
        """
        # Count the turn up front so numbering stays stable even when the
        # chunk is skipped or fails below.
        self._iteration += 1

        try:
            payload = self._build_payload(message)
        except Exception:
            logger.warning(
                "assistant_stream payload build failed",
                exc_info=True,
                extra=self._log_extra(),
            )
            return

        # Nothing to report: the turn produced neither text/image content nor
        # tool calls, so a chunk would only be noise on the wire and the live
        # bubble. A turn that only requests tools is still sent, because its
        # ``tool_calls`` list is non-empty.
        if not (payload["content"] or payload["tool_calls"]):
            return

        try:
            transport = self._ctx.session
            # Looked up inside the ``try`` so a context without a usable
            # ``request_id`` cannot escape as an exception.
            request_id = getattr(self._ctx, "request_id", None)
            await transport.send_log_message(
                level="info",
                data=payload,
                logger=LOGGER_NAME,
                related_request_id=request_id,
            )
        except Exception:
            logger.warning(
                "assistant_stream send_log_message failed",
                exc_info=True,
                extra=self._log_extra(),
            )

    def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]:
        """Convert an assistant message into the JSON-friendly chunk payload.

        Returns a dict with the keys ``kind``, ``schema_version``, ``agent``,
        ``conversation_id``, ``iteration``, ``stop_reason``, ``content`` and
        ``tool_calls``.
        """
        blocks: list[dict[str, Any]] = []
        for item in message.content or []:
            if isinstance(item, TextContent) and item.text:
                blocks.append({"type": "text", "text": item.text})
                continue
            if isinstance(item, ImageContent):
                # Images on assistant turns are unusual but legal. The raw
                # base64 is forwarded; note the key is ``mime_type`` here
                # rather than the model's ``mimeType``.
                blocks.append(
                    {
                        "type": "image",
                        "data": item.data,
                        "mime_type": item.mimeType,
                    }
                )
            # Any other block type (and text blocks with empty text) is
            # intentionally left out of the live stream.

        # Only the call id and the tool name are forwarded, not the arguments.
        calls: list[dict[str, Any]] = [
            {
                "id": call_id,
                "name": getattr(getattr(call, "params", None), "name", None),
            }
            for call_id, call in (message.tool_calls or {}).items()
        ]

        reason = message.stop_reason
        # Prefer the enum's ``.value`` (e.g. "toolUse", "endTurn"); fall back
        # to ``str()`` for anything that has no ``.value``.
        wire_reason = (
            None if reason is None else getattr(reason, "value", str(reason))
        )

        return {
            "kind": KIND,
            "schema_version": SCHEMA_VERSION,
            "agent": self._agent_name,
            "conversation_id": self._conversation_id,
            "iteration": self._iteration,
            "stop_reason": wire_reason,
            "content": blocks,
            "tool_calls": calls,
        }
def install_for_request(
    agent: Any,
    *,
    ctx: "MCPContext",
    agent_name: str,
    conversation_id: str | None,
):
    """Attach the assistant-stream hook to ``agent`` for one request.

    A fresh ``AssistantChunkEmitter`` is created and its hook is composed
    with whatever ``tool_runner_hooks`` the agent already carries (``None``
    if the attribute is missing).

    Returns a zero-argument callable that puts the agent's previous
    ``tool_runner_hooks`` value back. Call it in a ``finally`` so the
    per-request hook cannot leak into a later request if the agent instance
    is ever reused.
    """
    stream_hook = AssistantChunkEmitter(
        ctx, agent_name=agent_name, conversation_id=conversation_id
    ).as_after_llm_call_hook()
    ours = ToolRunnerHooks(after_llm_call=stream_hook)

    # Captured before assignment so ``restore`` can reinstate it verbatim.
    original_hooks = getattr(agent, "tool_runner_hooks", None)
    agent.tool_runner_hooks = _merge_after_llm_call(original_hooks, ours)

    def restore() -> None:
        agent.tool_runner_hooks = original_hooks

    return restore
def _merge_after_llm_call(
base: ToolRunnerHooks | None, extra: ToolRunnerHooks
) -> ToolRunnerHooks:
"""Compose ``extra.after_llm_call`` after ``base.after_llm_call``.
Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
so we don't depend on its signature. Only the ``after_llm_call`` slot
actually gets composed here — the others are passed through unchanged
from ``base`` since ``extra`` only ever sets ``after_llm_call``.
"""
if base is None:
return extra
base_hook = base.after_llm_call
extra_hook = extra.after_llm_call
if base_hook is None:
merged_after = extra_hook
elif extra_hook is None: # pragma: no cover - extra always sets it
merged_after = base_hook
else:
async def merged(runner: Any, message: PromptMessageExtended) -> None:
await base_hook(runner, message)
await extra_hook(runner, message)
merged_after = merged
return ToolRunnerHooks(
before_llm_call=base.before_llm_call,
after_llm_call=merged_after,
before_tool_call=base.before_tool_call,
after_tool_call=base.after_tool_call,
after_turn_complete=base.after_turn_complete,
)

View File

@@ -28,6 +28,7 @@ from fast_agent.core.logging.logger import get_logger
from fast_agent.mcp.server import AgentMCPServer from fast_agent.mcp.server import AgentMCPServer
from fast_agent.types import PromptMessageExtended, RequestParams from fast_agent.types import PromptMessageExtended, RequestParams
from pallas.assistant_stream import install_for_request as _install_assistant_stream
from pallas.progress import EnrichedMCPToolProgressManager from pallas.progress import EnrichedMCPToolProgressManager
from pallas import metrics as _pallas_metrics from pallas import metrics as _pallas_metrics
from fastmcp import Context as MCPContext from fastmcp import Context as MCPContext
@@ -220,6 +221,20 @@ class MultimodalAgentMCPServer(AgentMCPServer):
agent_context = getattr(agent, "context", None) agent_context = getattr(agent, "context", None)
metrics_start = time.perf_counter() metrics_start = time.perf_counter()
metrics_outcome = "ok" metrics_outcome = "ok"
# Install per-request after_llm_call hook that ships every
# intermediate assistant turn over MCP as a notifications/message.
# Without this, only the final ``agent.send()`` return value
# crosses the MCP boundary — substantive assistant text emitted
# in earlier loop iterations stays trapped inside fast-agent's
# ``message_history`` and the user sees a spinner that ends with
# a thin wrap-up sentence.
restore_hooks = _install_assistant_stream(
agent,
ctx=ctx,
agent_name=agent_name,
conversation_id=conversation_id,
)
try: try:
# Seed the freshly-created instance's message_history from the # Seed the freshly-created instance's message_history from the
# caller-supplied history so the agent sees the full # caller-supplied history so the agent sees the full
@@ -313,6 +328,14 @@ class MultimodalAgentMCPServer(AgentMCPServer):
_pallas_metrics.send_message_total.labels( _pallas_metrics.send_message_total.labels(
agent=agent_name, outcome=metrics_outcome agent=agent_name, outcome=metrics_outcome
).inc() ).inc()
# Restore the agent's prior tool_runner_hooks before the
# instance is released — defensive against any future
# shared-instance mode where leaking per-request hooks
# across requests would mis-attribute notifications.
try:
restore_hooks()
except Exception:
pass
await self._release_instance(ctx, instance) await self._release_instance(ctx, instance)

View File

@@ -0,0 +1,298 @@
"""Tests for ``pallas.assistant_stream``.
Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended``
objects and asserts the resulting MCP ``send_log_message`` payload shape.
No fast-agent runtime is involved — the hook is a pure async function and
the MCP context is faked.
Tests use ``asyncio.run`` directly to match the convention in
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
pytest-asyncio dependency).
"""
from __future__ import annotations
import asyncio
from typing import Any
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
ImageContent,
TextContent,
)
from pallas.assistant_stream import (
KIND,
LOGGER_NAME,
SCHEMA_VERSION,
AssistantChunkEmitter,
install_for_request,
)
def _run(coro):
return asyncio.run(coro)
# ── Fakes ────────────────────────────────────────────────────────────────────
class _FakeSession:
"""Records every call to ``send_log_message`` for later assertion."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
self.fail_with: Exception | None = None
async def send_log_message(
self,
*,
level: str,
data: Any,
logger: str | None = None,
related_request_id: Any = None,
) -> None:
if self.fail_with is not None:
raise self.fail_with
self.calls.append(
{
"level": level,
"data": data,
"logger": logger,
"related_request_id": related_request_id,
}
)
class _FakeContext:
    """Stand-in MCP context exposing only ``session`` and ``request_id``."""

    def __init__(self, request_id: str = "req-1") -> None:
        self.request_id = request_id
        # Each context gets its own recording session.
        self.session = _FakeSession()
class _FakeAgent:
"""Minimal stand-in for a fast-agent agent — only carries hooks."""
def __init__(self) -> None:
self.tool_runner_hooks: ToolRunnerHooks | None = None
def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest:
    """Build a ``tools/call`` request for tool ``name``.

    ``arguments`` defaults to an empty dict. ``call_id`` is accepted for
    readability at call sites but is not stored on the request; callers use
    it as the key of the ``tool_calls`` mapping instead.
    """
    request_params = CallToolRequestParams(name=name, arguments=arguments or {})
    return CallToolRequest(method="tools/call", params=request_params)
# ── Tests ────────────────────────────────────────────────────────────────────
def test_emit_text_only_iteration() -> None:
    """A text-only assistant turn yields exactly one log message with one text block."""
    fake_ctx = _FakeContext(request_id="req-text")
    emitter = AssistantChunkEmitter(
        fake_ctx, agent_name="alan", conversation_id="conv-1"
    )
    turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Fair. You can't size value...")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit(turn))

    assert len(fake_ctx.session.calls) == 1
    sent = fake_ctx.session.calls[0]

    # Envelope: how the notification was addressed.
    assert sent["level"] == "info"
    assert sent["logger"] == LOGGER_NAME
    assert sent["related_request_id"] == "req-text"

    # Payload: checked field by field, in the same order as before.
    expected_fields = {
        "kind": KIND,
        "schema_version": SCHEMA_VERSION,
        "agent": "alan",
        "conversation_id": "conv-1",
        "iteration": 1,
        "stop_reason": "endTurn",
        "content": [{"type": "text", "text": "Fair. You can't size value..."}],
        "tool_calls": [],
    }
    payload = sent["data"]
    for field, wanted in expected_fields.items():
        assert payload[field] == wanted
def test_emit_text_with_tool_call_iteration() -> None:
    """Text and a tool call produced in the same turn both appear in the chunk."""
    fake_ctx = _FakeContext()
    emitter = AssistantChunkEmitter(
        fake_ctx, agent_name="alan", conversation_id="conv-1"
    )
    pending_calls = {"toolu_1": _tool_call("toolu_1", "time__get_current_time")}
    turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Logging this…")],
        tool_calls=pending_calls,
        stop_reason=LlmStopReason.TOOL_USE,
    )

    _run(emitter._emit(turn))

    recorded = fake_ctx.session.calls
    assert len(recorded) == 1
    payload = recorded[0]["data"]
    assert payload["stop_reason"] == "toolUse"
    assert payload["content"] == [{"type": "text", "text": "Logging this…"}]
    assert payload["tool_calls"] == [
        {"id": "toolu_1", "name": "time__get_current_time"}
    ]
def test_emit_skips_completely_empty_iteration() -> None:
    """An assistant turn with no content and no tool calls sends nothing.

    Such a chunk would carry no information for the client, so the emitter
    is expected to stay silent while still counting the turn.
    """
    fake_ctx = _FakeContext()
    emitter = AssistantChunkEmitter(
        fake_ctx, agent_name="alan", conversation_id="conv-1"
    )
    empty_turn = PromptMessageExtended(role="assistant", content=[])

    _run(emitter._emit(empty_turn))

    assert fake_ctx.session.calls == []
    # The counter advances anyway, keeping later chunks correctly numbered.
    assert emitter._iteration == 1
def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
    """Image blocks are forwarded with the snake-case ``mime_type`` key."""
    fake_ctx = _FakeContext()
    emitter = AssistantChunkEmitter(fake_ctx, agent_name="alan", conversation_id=None)
    picture = ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")
    turn = PromptMessageExtended(
        role="assistant",
        content=[picture],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit(turn))

    payload = fake_ctx.session.calls[0]["data"]
    expected_content = [
        {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
    ]
    assert payload["content"] == expected_content
    # A missing conversation id is passed through as ``None``.
    assert payload["conversation_id"] is None
def test_emit_iterations_are_numbered_in_order() -> None:
    """Consecutive turns on one emitter are numbered 1, 2, 3 in emit order."""
    fake_ctx = _FakeContext()
    emitter = AssistantChunkEmitter(
        fake_ctx, agent_name="alan", conversation_id="conv-1"
    )
    # (text, stop reason) for each turn of the simulated tool loop.
    script = [
        ("first", LlmStopReason.TOOL_USE),
        ("second", LlmStopReason.TOOL_USE),
        ("third — done", LlmStopReason.END_TURN),
    ]

    async def drive() -> None:
        # All turns share one event loop, like a real multi-iteration loop.
        for text, reason in script:
            await emitter._emit(
                PromptMessageExtended(
                    role="assistant",
                    content=[TextContent(type="text", text=text)],
                    stop_reason=reason,
                )
            )

    _run(drive())

    payloads = [entry["data"] for entry in fake_ctx.session.calls]
    assert [p["iteration"] for p in payloads] == [1, 2, 3]
    assert [p["stop_reason"] for p in payloads] == [
        "toolUse",
        "toolUse",
        "endTurn",
    ]
def test_emit_swallows_session_failure() -> None:
    """An exception from ``send_log_message`` stays inside the hook."""
    fake_ctx = _FakeContext()
    fake_ctx.session.fail_with = RuntimeError("transport closed")
    emitter = AssistantChunkEmitter(
        fake_ctx, agent_name="alan", conversation_id="conv-1"
    )
    turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    # The test passes only if this call returns instead of raising.
    _run(emitter._emit(turn))

    # The fake raises before recording, so nothing may have been stored.
    assert fake_ctx.session.calls == []
def test_install_for_request_merges_with_existing_after_llm_call() -> None:
    """Installing keeps an existing ``after_llm_call`` hook and adds ours after it."""
    fake_ctx = _FakeContext()
    agent = _FakeAgent()
    invoked: list[str] = []

    async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
        invoked.append("base")

    agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)

    undo = install_for_request(
        agent, ctx=fake_ctx, agent_name="alan", conversation_id="conv-1"
    )

    turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )
    _run(agent.tool_runner_hooks.after_llm_call(None, turn))

    # The pre-existing hook fired and the emitter delivered the chunk.
    assert invoked == ["base"]
    assert len(fake_ctx.session.calls) == 1
    assert fake_ctx.session.calls[0]["data"]["content"] == [
        {"type": "text", "text": "hi"}
    ]

    # Slots the emitter does not own remain empty.
    installed = agent.tool_runner_hooks
    assert installed.before_llm_call is None
    assert installed.before_tool_call is None
    assert installed.after_tool_call is None
    assert installed.after_turn_complete is None

    # Undoing the install reinstates the original hook object.
    undo()
    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is base_after
def test_install_for_request_with_no_existing_hooks() -> None:
    """With no prior hooks, install adds ours and restore returns to ``None``."""
    fake_ctx = _FakeContext()
    agent = _FakeAgent()
    assert agent.tool_runner_hooks is None

    undo = install_for_request(
        agent, ctx=fake_ctx, agent_name="alan", conversation_id=None
    )

    installed = agent.tool_runner_hooks
    assert installed is not None
    assert installed.after_llm_call is not None

    undo()
    assert agent.tool_runner_hooks is None