# pallas/tests/test_assistant_stream.py
"""Tests for ``pallas.assistant_stream``.
Drives the ``after_llm_call`` / ``before_tool_call`` / ``after_tool_call``
hooks with handcrafted ``PromptMessageExtended`` objects and asserts the
resulting MCP ``send_log_message`` payload shape. No fast-agent runtime is
involved — the hooks are pure async functions and the MCP context is faked.
Tests use ``asyncio.run`` directly to match the convention in
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
pytest-asyncio dependency).
"""
from __future__ import annotations
import asyncio
from typing import Any
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.types import PromptMessageExtended
from fast_agent.types.llm_stop_reason import LlmStopReason
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
CallToolResult,
ImageContent,
TextContent,
)
from pallas.assistant_stream import (
KIND,
KIND_RESULTS,
LOGGER_NAME,
SCHEMA_VERSION,
AssistantChunkEmitter,
install_for_request,
)
def _run(coro):
return asyncio.run(coro)
# ── Fakes ────────────────────────────────────────────────────────────────────
class _FakeSession:
"""Records every call to ``send_log_message`` for later assertion."""
def __init__(self) -> None:
self.calls: list[dict[str, Any]] = []
self.fail_with: Exception | None = None
async def send_log_message(
self,
*,
level: str,
data: Any,
logger: str | None = None,
related_request_id: Any = None,
) -> None:
if self.fail_with is not None:
raise self.fail_with
self.calls.append(
{
"level": level,
"data": data,
"logger": logger,
"related_request_id": related_request_id,
}
)
class _FakeContext:
    """Minimal MCP request context: a recording session plus a request id."""

    def __init__(self, request_id: str = "req-1") -> None:
        # Tests expect this to be echoed back as ``related_request_id``.
        self.request_id = request_id
        # A fresh recorder per context, so tests never share captured calls.
        self.session = _FakeSession()
class _FakeAgent:
"""Minimal stand-in for a fast-agent agent — only carries hooks."""
def __init__(self) -> None:
self.tool_runner_hooks: ToolRunnerHooks | None = None
def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest:
    """Build a ``tools/call`` request for *name*; falsy *arguments* become ``{}``."""
    params = CallToolRequestParams(
        name=name,
        arguments=arguments if arguments else {},
    )
    return CallToolRequest(method="tools/call", params=params)
def _tool_result(
    text: str | None = "ok", *, is_error: bool = False
) -> CallToolResult:
    """Build a tool result with one text block, or no content when *text* is None."""
    blocks: list = [] if text is None else [TextContent(type="text", text=text)]
    return CallToolResult(content=blocks, isError=is_error)
# ── Tests ────────────────────────────────────────────────────────────────────
def test_emit_text_only_iteration() -> None:
    """A text-only assistant turn is published as exactly one info log message."""
    ctx = _FakeContext(request_id="req-text")
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    reply = "Fair. You can't size value..."
    turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text=reply)],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit_assistant_chunk(turn))

    assert len(ctx.session.calls) == 1
    sent = ctx.session.calls[0]
    # Envelope: the keyword arguments handed to send_log_message.
    assert sent["level"] == "info"
    assert sent["logger"] == LOGGER_NAME
    assert sent["related_request_id"] == "req-text"
    # Payload: the structured chunk itself.
    expected_fields = {
        "kind": KIND,
        "schema_version": SCHEMA_VERSION,
        "agent": "alan",
        "conversation_id": "conv-1",
        "iteration": 1,
        "stop_reason": "endTurn",
        "content": [{"type": "text", "text": reply}],
        "tool_calls": [],
    }
    payload = sent["data"]
    for field, value in expected_fields.items():
        assert payload[field] == value, f"mismatch on {field!r}"
def test_emit_text_with_tool_call_iteration_carries_args_preview() -> None:
    """A tool-requesting turn reports the tool name, server prefix, and argument preview."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    message = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Logging this…")],
        tool_calls={
            "toolu_1": _tool_call(
                "argos-search_web", {"query": "ducati v2 vs v4 reliability"}
            ),
        },
        stop_reason=LlmStopReason.TOOL_USE,
    )

    _run(emitter._emit_assistant_chunk(message))

    assert len(ctx.session.calls) == 1
    payload = ctx.session.calls[0]["data"]
    assert payload["stop_reason"] == "toolUse"
    assert payload["content"] == [{"type": "text", "text": "Logging this…"}]
    # "argos" is expected as the server: the prefix of the namespaced tool
    # name. With a single argument, the preview is just its value.
    expected_call = {
        "id": "toolu_1",
        "name": "argos-search_web",
        "server": "argos",
        "arguments_preview": "ducati v2 vs v4 reliability",
    }
    assert payload["tool_calls"] == [expected_call]
def test_emit_skips_completely_empty_iteration() -> None:
    """No log message is sent for a turn with neither text nor tool calls."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    blank_turn = PromptMessageExtended(role="assistant", content=[])

    _run(emitter._emit_assistant_chunk(blank_turn))

    assert not ctx.session.calls
    # The counter still advances, so later chunks keep their correct numbers.
    assert emitter._iteration == 1
def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
    """ImageContent blocks are serialized with ``mime_type`` (snake-case).

    Also checks that ``conversation_id=None`` is forwarded unchanged.
    """
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None)
    msg = PromptMessageExtended(
        role="assistant",
        content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")],
        stop_reason=LlmStopReason.END_TURN,
    )

    _run(emitter._emit_assistant_chunk(msg))

    # Pin the count first. A missing emission then fails as an assertion rather
    # than an IndexError, and a duplicated emission is caught at all.
    assert len(ctx.session.calls) == 1
    data = ctx.session.calls[0]["data"]
    assert data["content"] == [
        {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
    ]
    assert data["tool_calls"] == []
    assert data["conversation_id"] is None
def test_emit_iterations_are_numbered_in_order() -> None:
    """Successive chunks from one emitter carry iteration numbers 1, 2, 3."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    # (text, stop reason) for each simulated LLM iteration, in order.
    turns = [
        ("first", LlmStopReason.TOOL_USE),
        ("second", LlmStopReason.TOOL_USE),
        ("third — done", LlmStopReason.END_TURN),
    ]

    async def drive() -> None:
        for text, stop in turns:
            await emitter._emit_assistant_chunk(
                PromptMessageExtended(
                    role="assistant",
                    content=[TextContent(type="text", text=text)],
                    stop_reason=stop,
                )
            )

    _run(drive())

    payloads = [call["data"] for call in ctx.session.calls]
    assert [p["iteration"] for p in payloads] == [1, 2, 3]
    assert [p["stop_reason"] for p in payloads] == [
        "toolUse",
        "toolUse",
        "endTurn",
    ]
def test_emit_swallows_session_failure() -> None:
    """If ``send_log_message`` raises, the hook does not propagate.

    The fake session raises *before* recording, so ``calls`` stays empty no
    matter what the emitter does. To prove the emitter really tried to send
    (and swallowed the error) rather than silently skipping the chunk, the
    session's method is wrapped with an attempt counter.
    """
    ctx = _FakeContext()
    session = ctx.session
    session.fail_with = RuntimeError("transport closed")
    attempts = 0
    raising_send = session.send_log_message

    async def counting_send(*args: Any, **kwargs: Any) -> None:
        nonlocal attempts
        attempts += 1
        await raising_send(*args, **kwargs)

    # Install before building the emitter so any captured reference sees it.
    session.send_log_message = counting_send  # type: ignore[method-assign]
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )

    # If the hook let the RuntimeError escape, _run would re-raise it here
    # and fail the test.
    _run(emitter._emit_assistant_chunk(msg))

    assert attempts, "emitter never attempted to send the chunk"
    assert ctx.session.calls == []
def test_emit_tool_results_pairs_call_id_with_iteration() -> None:
    """Results sent from ``after_tool_call`` are tagged with the requesting iteration."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
    # Iteration 1: some text plus a single tool request.
    request_turn = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="Searching…")],
        tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
        stop_reason=LlmStopReason.TOOL_USE,
    )

    async def simulate_runner() -> None:
        # Mirror the runner's ordering:
        # after_llm_call, then before_tool_call, then after_tool_call.
        await emitter._emit_assistant_chunk(request_turn)
        await emitter.as_before_tool_call_hook()(None, request_turn)
        # Tool output comes back as a "user" message carrying tool_results.
        results_turn = PromptMessageExtended(
            role="user",
            content=[],
            tool_results={"toolu_1": _tool_result("12 results found")},
        )
        await emitter.as_after_tool_call_hook()(None, results_turn)

    _run(simulate_runner())

    assert len(ctx.session.calls) == 2
    chunk, results = (call["data"] for call in ctx.session.calls)
    assert chunk["kind"] == KIND
    assert chunk["iteration"] == 1
    assert results["kind"] == KIND_RESULTS
    assert results["iteration"] == 1
    assert results["agent"] == "alan"
    assert results["conversation_id"] == "conv-1"
    assert len(results["results"]) == 1
    entry = results["results"][0]
    assert entry["id"] == "toolu_1"
    assert entry["ok"] is True
    assert entry["result_preview"] == "12 results found"
    duration = entry["duration_ms"]
    assert isinstance(duration, int)
    assert duration >= 0
def test_emit_tool_results_marks_error() -> None:
    """A failing tool call surfaces ``ok: False`` in the results entry.

    Also pins that a pure-tool turn (no text, one tool call) still produces
    an ``assistant_chunk`` ahead of the results payload.
    """
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")
    iter1 = PromptMessageExtended(
        role="assistant",
        content=[],
        tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})},
        stop_reason=LlmStopReason.TOOL_USE,
    )

    async def drive() -> None:
        await emitter._emit_assistant_chunk(iter1)
        await emitter.as_before_tool_call_hook()(None, iter1)
        tool_msg = PromptMessageExtended(
            role="user",
            content=[],
            tool_results={
                "toolu_1": _tool_result("connection refused", is_error=True)
            },
        )
        await emitter.as_after_tool_call_hook()(None, tool_msg)

    _run(drive())

    assert len(ctx.session.calls) == 2
    chunk = ctx.session.calls[0]["data"]
    results = ctx.session.calls[1]["data"]
    # iter1 has no text but a non-empty tool_calls, so it still emits an
    # assistant_chunk. Assert that explicitly instead of only counting calls.
    assert chunk["kind"] == KIND
    assert chunk["content"] == []
    assert [call["id"] for call in chunk["tool_calls"]] == ["toolu_1"]
    assert results["kind"] == KIND_RESULTS
    assert results["results"][0]["id"] == "toolu_1"
    assert results["results"][0]["ok"] is False
    assert results["results"][0]["result_preview"] == "connection refused"
def test_emit_tool_results_empty_when_no_results() -> None:
    """Nothing is sent when ``after_tool_call`` gets a message without tool_results."""
    ctx = _FakeContext()
    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c")
    no_results = PromptMessageExtended(role="user", content=[], tool_results=None)
    after_tool = emitter.as_after_tool_call_hook()

    _run(after_tool(None, no_results))

    assert not ctx.session.calls
def test_install_for_request_merges_with_existing_after_llm_call() -> None:
    """A pre-existing ``after_llm_call`` hook is composed, not replaced.

    After installation, invoking ``after_llm_call`` must run the agent's own
    hook and emit the chunk. ``restore()`` must then put the original hooks
    back exactly.
    """
    ctx = _FakeContext()
    agent = _FakeAgent()
    seen: list[str] = []

    async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
        seen.append("base")

    agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)
    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id="conv-1"
    )
    hooks = agent.tool_runner_hooks
    assert hooks is not None
    # The slot now holds a composed wrapper rather than the original callable.
    assert hooks.after_llm_call is not base_after

    msg = PromptMessageExtended(
        role="assistant",
        content=[TextContent(type="text", text="hi")],
        stop_reason=LlmStopReason.END_TURN,
    )
    _run(hooks.after_llm_call(None, msg))

    assert seen == ["base"]
    assert len(ctx.session.calls) == 1
    assert ctx.session.calls[0]["data"]["content"] == [
        {"type": "text", "text": "hi"}
    ]
    # The emitter's tool-call hooks are bound as well (they were unset on the
    # original ToolRunnerHooks).
    assert hooks.before_tool_call is not None
    assert hooks.after_tool_call is not None

    restore()
    assert agent.tool_runner_hooks is not None
    assert agent.tool_runner_hooks.after_llm_call is base_after
    assert agent.tool_runner_hooks.before_tool_call is None
    assert agent.tool_runner_hooks.after_tool_call is None
def test_install_for_request_with_no_existing_hooks() -> None:
    """An agent with no hooks gets a fresh set, and ``restore()`` resets it to None."""
    ctx = _FakeContext()
    agent = _FakeAgent()
    assert agent.tool_runner_hooks is None

    restore = install_for_request(
        agent, ctx=ctx, agent_name="alan", conversation_id=None
    )

    hooks = agent.tool_runner_hooks
    assert hooks is not None
    for slot in ("after_llm_call", "before_tool_call", "after_tool_call"):
        assert getattr(hooks, slot) is not None, slot

    restore()
    assert agent.tool_runner_hooks is None