feat: stream intermediate assistant turns over MCP
Install a per-request `after_llm_call` hook that emits each intermediate assistant turn as an MCP `notifications/message`, so users see substantive text from earlier loop iterations instead of only the final `agent.send()` return value. Add tests covering the hook's payload shape, error handling, and lifecycle via `install_for_request`.
This commit is contained in:
298
tests/test_assistant_stream.py
Normal file
298
tests/test_assistant_stream.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Tests for ``pallas.assistant_stream``.

Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended``
objects and asserts the resulting MCP ``send_log_message`` payload shape.
No fast-agent runtime is involved — the hook is a pure async function and
the MCP context is faked.

Tests use ``asyncio.run`` directly to match the convention in
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
pytest-asyncio dependency).
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from fast_agent.agents.tool_runner import ToolRunnerHooks
|
||||
from fast_agent.types import PromptMessageExtended
|
||||
from fast_agent.types.llm_stop_reason import LlmStopReason
|
||||
from mcp.types import (
|
||||
CallToolRequest,
|
||||
CallToolRequestParams,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
from pallas.assistant_stream import (
|
||||
KIND,
|
||||
LOGGER_NAME,
|
||||
SCHEMA_VERSION,
|
||||
AssistantChunkEmitter,
|
||||
install_for_request,
|
||||
)
|
||||
|
||||
|
||||
def _run(coro):
    """Run *coro* to completion on a fresh event loop and return its result.

    Thin wrapper over ``asyncio.run`` so the synchronous test functions in
    this module can drive the async hook without a pytest async plugin.
    """
    return asyncio.run(coro)
|
||||
|
||||
|
||||
# ── Fakes ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeSession:
    """Records every call to ``send_log_message`` for later assertion."""

    def __init__(self) -> None:
        # One dict per successful call, in call order.
        self.calls: list[dict[str, Any]] = []
        # When set, ``send_log_message`` raises this instead of recording.
        self.fail_with: Exception | None = None

    async def send_log_message(
        self,
        *,
        level: str,
        data: Any,
        logger: str | None = None,
        related_request_id: Any = None,
    ) -> None:
        """Record the keyword arguments, or raise ``fail_with`` if it is set.

        The failure check happens before the append, so a failing call
        leaves ``calls`` unchanged.
        """
        if self.fail_with is not None:
            raise self.fail_with
        self.calls.append(
            {
                "level": level,
                "data": data,
                "logger": logger,
                "related_request_id": related_request_id,
            }
        )
|
||||
|
||||
|
||||
class _FakeContext:
    """Stand-in for the MCP request context handed to the emitter.

    Exposes only the two attributes set here: a recording ``session`` and
    the ``request_id`` of the request being served.
    """

    def __init__(self, request_id: str = "req-1") -> None:
        self.session = _FakeSession()
        self.request_id = request_id
|
||||
|
||||
|
||||
class _FakeAgent:
    """Minimal stand-in for a fast-agent agent — only carries hooks."""

    def __init__(self) -> None:
        # Starts with no hooks installed; tests assign or install them.
        self.tool_runner_hooks: ToolRunnerHooks | None = None
|
||||
|
||||
|
||||
def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest:
    """Build a ``tools/call`` request for tool *name*.

    ``call_id`` is not stored on the request; it is accepted so call sites
    can show the id next to the mapping key they file the request under.
    A missing or empty ``arguments`` becomes ``{}``.
    """
    return CallToolRequest(
        method="tools/call",
        params=CallToolRequestParams(name=name, arguments=arguments or {}),
    )
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_emit_text_only_iteration() -> None:
    """A pure-text assistant turn produces one log message with one text block."""
    ctx = _FakeContext(request_id="req-text")
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Fair. You can't size value...")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit(msg))

    # Envelope: one message, tagged with our logger and the originating request.
    assert len(ctx.session.calls) == 1
    call = ctx.session.calls[0]
    assert call["level"] == "info"
    assert call["logger"] == LOGGER_NAME
    assert call["related_request_id"] == "req-text"

    # Payload: identifying metadata plus the serialized content blocks.
    data = call["data"]
    assert data["kind"] == KIND
    assert data["schema_version"] == SCHEMA_VERSION
    assert data["agent"] == "alan"
    assert data["conversation_id"] == "conv-1"
    assert data["iteration"] == 1
    assert data["stop_reason"] == "endTurn"
    assert data["content"] == [{"type": "text", "text": "Fair. You can't size value..."}]
    assert data["tool_calls"] == []
|
||||
|
||||
|
||||
def test_emit_text_with_tool_call_iteration() -> None:
    """An assistant turn that emits text and then calls a tool ships both."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Logging this…")],
        tool_calls={
            "toolu_1": _tool_call("toolu_1", "time__get_current_time"),
        },
        stop_reason=LlmStopReason.TOOL_USE,
    )

    _run(emitter._emit(msg))

    assert len(ctx.session.calls) == 1
    data = ctx.session.calls[0]["data"]
    assert data["stop_reason"] == "toolUse"
    assert data["content"] == [{"type": "text", "text": "Logging this…"}]
    # Tool calls are summarized as id + name pairs.
    assert data["tool_calls"] == [{"id": "toolu_1", "name": "time__get_current_time"}]
|
||||
|
||||
|
||||
def test_emit_skips_completely_empty_iteration() -> None:
    """A turn with no text blocks and no tool calls emits nothing.

    Tool-call lifecycle is already covered by notifications/progress. An
    empty assistant_chunk would just be noise on the wire and a no-op
    update on the live bubble.
    """
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(role="assistant", content=[])

    _run(emitter._emit(msg))

    assert ctx.session.calls == []
    # Iteration counter still bumps so subsequent chunks aren't mis-numbered.
    assert emitter._iteration == 1
|
||||
|
||||
|
||||
def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
    """ImageContent blocks are serialized with ``mime_type`` (snake-case)."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None)

    msg = PromptMessageExtended(
        role="assistant",
        content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit(msg))

    # Check the count first so a missing emission fails as an assertion
    # rather than an IndexError on the lookup below.
    assert len(ctx.session.calls) == 1
    data = ctx.session.calls[0]["data"]
    assert data["content"] == [
        {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
    ]
    # A ``None`` conversation id is passed through as-is.
    assert data["conversation_id"] is None
|
||||
|
||||
|
||||
def test_emit_iterations_are_numbered_in_order() -> None:
    """A multi-iteration loop produces sequentially numbered chunks."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    async def drive() -> None:
        # Three turns on the same emitter, all inside a single event loop.
        await emitter._emit(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="first")],
                stop_reason=LlmStopReason.TOOL_USE,
            )
        )
        await emitter._emit(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="second")],
                stop_reason=LlmStopReason.TOOL_USE,
            )
        )
        await emitter._emit(
            PromptMessageExtended(
                role="assistant",
                content=[TextContent(type="text", text="third — done")],
                stop_reason=LlmStopReason.END_TURN,
            )
        )

    _run(drive())

    assert [c["data"]["iteration"] for c in ctx.session.calls] == [1, 2, 3]
    assert [c["data"]["stop_reason"] for c in ctx.session.calls] == [
        "toolUse",
        "toolUse",
        "endTurn",
    ]
|
||||
|
||||
|
||||
def test_emit_swallows_session_failure() -> None:
    """If ``send_log_message`` raises, the hook does not propagate."""
    ctx = _FakeContext()
    ctx.session.fail_with = RuntimeError("transport closed")
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    # Must not raise.
    _run(emitter._emit(msg))
    # And no successful calls were recorded (fail_with raised before append).
    assert ctx.session.calls == []
|
||||
|
||||
|
||||
def test_install_for_request_merges_with_existing_after_llm_call() -> None:
    """A pre-existing ``after_llm_call`` hook is composed, not replaced."""
    ctx = _FakeContext()
    agent = _FakeAgent()

    seen: list[str] = []

    async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
        seen.append("base")

    agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)

    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id="conv-1"
    )

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(agent.tool_runner_hooks.after_llm_call(None, msg))

    # Base ran, and the assistant-stream emitter shipped the chunk.
    assert seen == ["base"]
    assert len(ctx.session.calls) == 1
    assert ctx.session.calls[0]["data"]["content"] == [
        {"type": "text", "text": "hi"}
    ]

    # Other hook slots stay untouched.
    assert agent.tool_runner_hooks.before_llm_call is None
    assert agent.tool_runner_hooks.before_tool_call is None
    assert agent.tool_runner_hooks.after_tool_call is None
    assert agent.tool_runner_hooks.after_turn_complete is None

    # Restore puts the original hooks back exactly.
    restore()
    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is base_after
|
||||
|
||||
|
||||
def test_install_for_request_with_no_existing_hooks() -> None:
    """When the agent has no prior hooks, ours is installed cleanly."""
    ctx = _FakeContext()
    agent = _FakeAgent()
    assert agent.tool_runner_hooks is None

    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id=None
    )

    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is not None

    # Restoring returns the agent to its hook-less starting state.
    restore()
    assert agent.tool_runner_hooks is None
|
||||
Reference in New Issue
Block a user