"""Tests for ``pallas.assistant_stream``. Drives the ``after_llm_call`` / ``before_tool_call`` / ``after_tool_call`` hooks with handcrafted ``PromptMessageExtended`` objects and asserts the resulting MCP ``send_log_message`` payload shape. No fast-agent runtime is involved — the hooks are pure async functions and the MCP context is faked. Tests use ``asyncio.run`` directly to match the convention in ``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no pytest-asyncio dependency). """ from __future__ import annotations import asyncio from typing import Any from fast_agent.agents.tool_runner import ToolRunnerHooks from fast_agent.types import PromptMessageExtended from fast_agent.types.llm_stop_reason import LlmStopReason from mcp.types import ( CallToolRequest, CallToolRequestParams, CallToolResult, ImageContent, TextContent, ) from pallas.assistant_stream import ( KIND, KIND_RESULTS, LOGGER_NAME, SCHEMA_VERSION, AssistantChunkEmitter, install_for_request, ) def _run(coro): return asyncio.run(coro) # ── Fakes ──────────────────────────────────────────────────────────────────── class _FakeSession: """Records every call to ``send_log_message`` for later assertion.""" def __init__(self) -> None: self.calls: list[dict[str, Any]] = [] self.fail_with: Exception | None = None async def send_log_message( self, *, level: str, data: Any, logger: str | None = None, related_request_id: Any = None, ) -> None: if self.fail_with is not None: raise self.fail_with self.calls.append( { "level": level, "data": data, "logger": logger, "related_request_id": related_request_id, } ) class _FakeContext: def __init__(self, request_id: str = "req-1") -> None: self.session = _FakeSession() self.request_id = request_id class _FakeAgent: """Minimal stand-in for a fast-agent agent — only carries hooks.""" def __init__(self) -> None: self.tool_runner_hooks: ToolRunnerHooks | None = None def _tool_call(name: str, arguments: dict | None = None) -> CallToolRequest: return CallToolRequest( method="tools/call", params=CallToolRequestParams(name=name, arguments=arguments or {}), ) def _tool_result( text: str | None = "ok", *, is_error: bool = False ) -> CallToolResult: content: list = [] if text is not None: content.append(TextContent(type="text", text=text)) return CallToolResult(content=content, isError=is_error) # ── Tests ──────────────────────────────────────────────────────────────────── def test_emit_text_only_iteration() -> None: """A pure-text assistant turn produces one log message with one text block.""" ctx = _FakeContext(request_id="req-text") emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") msg = PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="Fair. You can't size value...")], stop_reason=LlmStopReason.END_TURN, ) _run(emitter._emit_assistant_chunk(msg)) assert len(ctx.session.calls) == 1 call = ctx.session.calls[0] assert call["level"] == "info" assert call["logger"] == LOGGER_NAME assert call["related_request_id"] == "req-text" data = call["data"] assert data["kind"] == KIND assert data["schema_version"] == SCHEMA_VERSION assert data["agent"] == "alan" assert data["conversation_id"] == "conv-1" assert data["iteration"] == 1 assert data["stop_reason"] == "endTurn" assert data["content"] == [{"type": "text", "text": "Fair. You can't size value..."}] assert data["tool_calls"] == [] def test_emit_text_with_tool_call_iteration_carries_args_preview() -> None: """An iteration that requests a tool ships name, server prefix, and args preview.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") msg = PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="Logging this…")], tool_calls={ "toolu_1": _tool_call( "argos-search_web", {"query": "ducati v2 vs v4 reliability"}, ), }, stop_reason=LlmStopReason.TOOL_USE, ) _run(emitter._emit_assistant_chunk(msg)) assert len(ctx.session.calls) == 1 data = ctx.session.calls[0]["data"] assert data["stop_reason"] == "toolUse" assert data["content"] == [{"type": "text", "text": "Logging this…"}] assert data["tool_calls"] == [ { "id": "toolu_1", "name": "argos-search_web", "server": "argos", "arguments_preview": "ducati v2 vs v4 reliability", } ] def test_emit_skips_completely_empty_iteration() -> None: """A turn with no text blocks and no tool calls emits nothing.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") msg = PromptMessageExtended(role="assistant", content=[]) _run(emitter._emit_assistant_chunk(msg)) assert ctx.session.calls == [] # Iteration counter still bumps so subsequent chunks aren't mis-numbered. assert emitter._iteration == 1 def test_emit_image_block_passes_through_with_mime_type_renamed() -> None: """ImageContent blocks are serialized with ``mime_type`` (snake-case).""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None) msg = PromptMessageExtended( role="assistant", content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")], stop_reason=LlmStopReason.END_TURN, ) _run(emitter._emit_assistant_chunk(msg)) data = ctx.session.calls[0]["data"] assert data["content"] == [ {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"} ] assert data["conversation_id"] is None def test_emit_iterations_are_numbered_in_order() -> None: """A multi-iteration loop produces sequentially numbered chunks.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") async def drive() -> None: await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="first")], stop_reason=LlmStopReason.TOOL_USE, ) ) await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="second")], stop_reason=LlmStopReason.TOOL_USE, ) ) await emitter._emit_assistant_chunk( PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="third — done")], stop_reason=LlmStopReason.END_TURN, ) ) _run(drive()) assert [c["data"]["iteration"] for c in ctx.session.calls] == [1, 2, 3] assert [c["data"]["stop_reason"] for c in ctx.session.calls] == [ "toolUse", "toolUse", "endTurn", ] def test_emit_swallows_session_failure() -> None: """If ``send_log_message`` raises, the hook does not propagate.""" ctx = _FakeContext() ctx.session.fail_with = RuntimeError("transport closed") emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") msg = PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="hi")], stop_reason=LlmStopReason.END_TURN, ) _run(emitter._emit_assistant_chunk(msg)) assert ctx.session.calls == [] def test_emit_tool_results_pairs_call_id_with_iteration() -> None: """``after_tool_call`` ships a results payload keyed to the iteration that called.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1") # Iteration 1: text + tool request. iter1 = PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="Searching…")], tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})}, stop_reason=LlmStopReason.TOOL_USE, ) async def drive() -> None: # Simulate the runner: after_llm_call → before_tool_call → after_tool_call. await emitter._emit_assistant_chunk(iter1) await emitter.as_before_tool_call_hook()(None, iter1) # The "user" message produced by the tool runner carries tool_results. tool_msg = PromptMessageExtended( role="user", content=[], tool_results={"toolu_1": _tool_result("12 results found")}, ) await emitter.as_after_tool_call_hook()(None, tool_msg) _run(drive()) assert len(ctx.session.calls) == 2 chunk = ctx.session.calls[0]["data"] results = ctx.session.calls[1]["data"] assert chunk["kind"] == KIND assert chunk["iteration"] == 1 assert results["kind"] == KIND_RESULTS assert results["iteration"] == 1 assert results["agent"] == "alan" assert results["conversation_id"] == "conv-1" assert len(results["results"]) == 1 entry = results["results"][0] assert entry["id"] == "toolu_1" assert entry["ok"] is True assert entry["result_preview"] == "12 results found" assert isinstance(entry["duration_ms"], int) assert entry["duration_ms"] >= 0 def test_emit_tool_results_marks_error() -> None: """A failing tool call surfaces ``ok: False`` in the results entry.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c") iter1 = PromptMessageExtended( role="assistant", content=[], tool_calls={"toolu_1": _tool_call("argos-search_web", {"query": "foo"})}, stop_reason=LlmStopReason.TOOL_USE, ) async def drive() -> None: await emitter._emit_assistant_chunk(iter1) await emitter.as_before_tool_call_hook()(None, iter1) tool_msg = PromptMessageExtended( role="user", content=[], tool_results={ "toolu_1": _tool_result("connection refused", is_error=True) }, ) await emitter.as_after_tool_call_hook()(None, tool_msg) _run(drive()) # iter1 is a pure-tool turn (no text + has tool_calls) → still emits an # assistant_chunk because tool_calls is non-empty. assert len(ctx.session.calls) == 2 results = ctx.session.calls[1]["data"] assert results["results"][0]["ok"] is False assert results["results"][0]["result_preview"] == "connection refused" def test_emit_tool_results_empty_when_no_results() -> None: """An ``after_tool_call`` carrying no tool_results emits nothing.""" ctx = _FakeContext() emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="c") msg = PromptMessageExtended(role="user", content=[], tool_results=None) _run(emitter.as_after_tool_call_hook()(None, msg)) assert ctx.session.calls == [] def test_install_for_request_merges_with_existing_after_llm_call() -> None: """A pre-existing ``after_llm_call`` hook is composed, not replaced.""" ctx = _FakeContext() agent = _FakeAgent() seen: list[str] = [] async def base_after(_runner: Any, message: PromptMessageExtended) -> None: seen.append("base") agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after) restore = install_for_request( agent, ctx=ctx, agent_name="alan", conversation_id="conv-1" ) msg = PromptMessageExtended( role="assistant", content=[TextContent(type="text", text="hi")], stop_reason=LlmStopReason.END_TURN, ) _run(agent.tool_runner_hooks.after_llm_call(None, msg)) assert seen == ["base"] assert len(ctx.session.calls) == 1 assert ctx.session.calls[0]["data"]["content"] == [ {"type": "text", "text": "hi"} ] # All four hook slots used by the emitter are now bound; the restore # call puts the original hooks back exactly. assert agent.tool_runner_hooks.before_tool_call is not None assert agent.tool_runner_hooks.after_tool_call is not None restore() assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks.after_llm_call is base_after assert agent.tool_runner_hooks.before_tool_call is None assert agent.tool_runner_hooks.after_tool_call is None def test_install_for_request_with_no_existing_hooks() -> None: """When the agent has no prior hooks, ours is installed cleanly.""" ctx = _FakeContext() agent = _FakeAgent() assert agent.tool_runner_hooks is None restore = install_for_request( agent, ctx=ctx, agent_name="alan", conversation_id=None ) assert agent.tool_runner_hooks is not None assert agent.tool_runner_hooks.after_llm_call is not None assert agent.tool_runner_hooks.before_tool_call is not None assert agent.tool_runner_hooks.after_tool_call is not None restore() assert agent.tool_runner_hooks is None