diff --git a/pallas/multimodal_server.py b/pallas/multimodal_server.py
index a760ec7..8cd5d7c 100644
--- a/pallas/multimodal_server.py
+++ b/pallas/multimodal_server.py
@@ -28,6 +28,7 @@ from fast_agent.core.logging.logger import get_logger
 from fast_agent.mcp.server import AgentMCPServer
 from fast_agent.types import PromptMessageExtended, RequestParams
 
+from pallas.assistant_stream import install_for_request as _install_assistant_stream
 from pallas.progress import EnrichedMCPToolProgressManager
 from pallas import metrics as _pallas_metrics
 from fastmcp import Context as MCPContext
@@ -220,6 +221,20 @@ class MultimodalAgentMCPServer(AgentMCPServer):
             agent_context = getattr(agent, "context", None)
             metrics_start = time.perf_counter()
             metrics_outcome = "ok"
+
+            # Install per-request after_llm_call hook that ships every
+            # intermediate assistant turn over MCP as a notifications/message.
+            # Without this, only the final ``agent.send()`` return value
+            # crosses the MCP boundary — substantive assistant text emitted
+            # in earlier loop iterations stays trapped inside fast-agent's
+            # ``message_history`` and the user sees a spinner that ends with
+            # a thin wrap-up sentence.
+            restore_hooks = _install_assistant_stream(
+                agent,
+                ctx=ctx,
+                agent_name=agent_name,
+                conversation_id=conversation_id,
+            )
             try:
                 # Seed the freshly-created instance's message_history from the
                 # caller-supplied history so the agent sees the full
@@ -313,6 +328,19 @@ class MultimodalAgentMCPServer(AgentMCPServer):
                 _pallas_metrics.send_message_total.labels(
                     agent=agent_name, outcome=metrics_outcome
                 ).inc()
+                # Restore the agent's prior tool_runner_hooks before the
+                # instance is released — defensive against any future
+                # shared-instance mode where leaking per-request hooks
+                # across requests would mis-attribute notifications.
+                try:
+                    restore_hooks()
+                except Exception as exc:
+                    # Best-effort cleanup: a failed restore must not mask
+                    # the request outcome or skip the release below, but
+                    # it should leave a trace instead of vanishing.
+                    get_logger(__name__).warning(
+                        f"assistant-stream hook restore failed: {exc!r}"
+                    )
                 await self._release_instance(ctx, instance)
 
 
diff --git a/tests/test_assistant_stream.py b/tests/test_assistant_stream.py
new file mode 100644
index 0000000..43119f4
--- /dev/null
+++ b/tests/test_assistant_stream.py
@@ -0,0 +1,302 @@
+"""Tests for ``pallas.assistant_stream``.
+
+Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended``
+objects and asserts the resulting MCP ``send_log_message`` payload shape.
+No fast-agent runtime is involved — the hook is a pure async function and
+the MCP context is faked.
+
+Tests use ``asyncio.run`` directly to match the convention in
+``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
+pytest-asyncio dependency).
+"""
+from __future__ import annotations
+
+import asyncio
+from collections.abc import Coroutine
+from typing import Any
+
+from fast_agent.agents.tool_runner import ToolRunnerHooks
+from fast_agent.types import PromptMessageExtended
+from fast_agent.types.llm_stop_reason import LlmStopReason
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    ImageContent,
+    TextContent,
+)
+
+from pallas.assistant_stream import (
+    KIND,
+    LOGGER_NAME,
+    SCHEMA_VERSION,
+    AssistantChunkEmitter,
+    install_for_request,
+)
+
+
+def _run(coro: Coroutine[Any, Any, Any]) -> Any:
+    """Run *coro* on a fresh event loop and return its result."""
+    return asyncio.run(coro)
+
+
+# ── Fakes ────────────────────────────────────────────────────────────────────
+
+
+class _FakeSession:
+    """Records every call to ``send_log_message`` for later assertion."""
+
+    def __init__(self) -> None:
+        self.calls: list[dict[str, Any]] = []
+        self.fail_with: Exception | None = None
+
+    async def send_log_message(
+        self,
+        *,
+        level: str,
+        data: Any,
+        logger: str | None = None,
+        related_request_id: Any = None,
+    ) -> None:
+        if self.fail_with is not None:
+            raise self.fail_with
+        self.calls.append(
+            {
+                "level": level,
+                "data": data,
+                "logger": logger,
+                "related_request_id": related_request_id,
+            }
+        )
+
+
+class _FakeContext:
+    """Stand-in for the MCP context: a fake ``session`` plus a ``request_id``."""
+
+    def __init__(self, request_id: str = "req-1") -> None:
+        self.session = _FakeSession()
+        self.request_id = request_id
+
+
+class _FakeAgent:
+    """Minimal stand-in for a fast-agent agent — only carries hooks."""
+
+    def __init__(self) -> None:
+        self.tool_runner_hooks: ToolRunnerHooks | None = None
+
+
+def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest:
+    """Build a ``tools/call`` request for tool *name*.
+
+    ``call_id`` is not stored on the request; callers use the same id as the
+    key of the ``tool_calls`` mapping. It is accepted here only so call
+    sites read naturally.
+    """
+    return CallToolRequest(
+        method="tools/call",
+        params=CallToolRequestParams(name=name, arguments=arguments or {}),
+    )
+
+
+# ── Tests ────────────────────────────────────────────────────────────────────
+
+
+def test_emit_text_only_iteration() -> None:
+    """A pure-text assistant turn produces one log message with one text block."""
+    ctx = _FakeContext(request_id="req-text")
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
+
+    msg = PromptMessageExtended(
+        role="assistant",
+        content=[TextContent(type="text", text="Fair. You can't size value...")],
+        stop_reason=LlmStopReason.END_TURN,
+    )
+
+    _run(emitter._emit(msg))
+
+    assert len(ctx.session.calls) == 1
+    call = ctx.session.calls[0]
+    assert call["level"] == "info"
+    assert call["logger"] == LOGGER_NAME
+    assert call["related_request_id"] == "req-text"
+
+    data = call["data"]
+    assert data["kind"] == KIND
+    assert data["schema_version"] == SCHEMA_VERSION
+    assert data["agent"] == "alan"
+    assert data["conversation_id"] == "conv-1"
+    assert data["iteration"] == 1
+    assert data["stop_reason"] == "endTurn"
+    assert data["content"] == [{"type": "text", "text": "Fair. You can't size value..."}]
+    assert data["tool_calls"] == []
+
+
+def test_emit_text_with_tool_call_iteration() -> None:
+    """An assistant turn that emits text and then calls a tool ships both."""
+    ctx = _FakeContext()
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
+
+    msg = PromptMessageExtended(
+        role="assistant",
+        content=[TextContent(type="text", text="Logging this…")],
+        tool_calls={
+            "toolu_1": _tool_call("toolu_1", "time__get_current_time"),
+        },
+        stop_reason=LlmStopReason.TOOL_USE,
+    )
+
+    _run(emitter._emit(msg))
+
+    assert len(ctx.session.calls) == 1
+    data = ctx.session.calls[0]["data"]
+    assert data["stop_reason"] == "toolUse"
+    assert data["content"] == [{"type": "text", "text": "Logging this…"}]
+    assert data["tool_calls"] == [{"id": "toolu_1", "name": "time__get_current_time"}]
+
+
+def test_emit_skips_completely_empty_iteration() -> None:
+    """A turn with no text blocks and no tool calls emits nothing.
+
+    Tool-call lifecycle is already covered by notifications/progress. An
+    empty assistant_chunk would just be noise on the wire and a no-op
+    update on the live bubble.
+    """
+    ctx = _FakeContext()
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
+
+    msg = PromptMessageExtended(role="assistant", content=[])
+
+    _run(emitter._emit(msg))
+
+    assert ctx.session.calls == []
+    # Iteration counter still bumps so subsequent chunks aren't mis-numbered.
+    assert emitter._iteration == 1
+
+
+def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
+    """ImageContent blocks are serialized with ``mime_type`` (snake-case)."""
+    ctx = _FakeContext()
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None)
+
+    msg = PromptMessageExtended(
+        role="assistant",
+        content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")],
+        stop_reason=LlmStopReason.END_TURN,
+    )
+
+    _run(emitter._emit(msg))
+
+    assert len(ctx.session.calls) == 1
+    data = ctx.session.calls[0]["data"]
+    assert data["content"] == [
+        {"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
+    ]
+    assert data["conversation_id"] is None
+
+
+def test_emit_iterations_are_numbered_in_order() -> None:
+    """A multi-iteration loop produces sequentially numbered chunks."""
+    ctx = _FakeContext()
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
+
+    turns = [
+        ("first", LlmStopReason.TOOL_USE),
+        ("second", LlmStopReason.TOOL_USE),
+        ("third — done", LlmStopReason.END_TURN),
+    ]
+
+    async def drive() -> None:
+        for text, stop_reason in turns:
+            await emitter._emit(
+                PromptMessageExtended(
+                    role="assistant",
+                    content=[TextContent(type="text", text=text)],
+                    stop_reason=stop_reason,
+                )
+            )
+
+    _run(drive())
+
+    assert [c["data"]["iteration"] for c in ctx.session.calls] == [1, 2, 3]
+    assert [c["data"]["stop_reason"] for c in ctx.session.calls] == [
+        "toolUse",
+        "toolUse",
+        "endTurn",
+    ]
+
+
+def test_emit_swallows_session_failure() -> None:
+    """If ``send_log_message`` raises, the hook does not propagate."""
+    ctx = _FakeContext()
+    ctx.session.fail_with = RuntimeError("transport closed")
+    emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
+
+    msg = PromptMessageExtended(
+        role="assistant",
+        content=[TextContent(type="text", text="hi")],
+        stop_reason=LlmStopReason.END_TURN,
+    )
+
+    # Must not raise.
+    _run(emitter._emit(msg))
+    # And no successful calls were recorded (fail_with raised before append).
+    assert ctx.session.calls == []
+
+
+def test_install_for_request_merges_with_existing_after_llm_call() -> None:
+    """A pre-existing ``after_llm_call`` hook is composed, not replaced."""
+    ctx = _FakeContext()
+    agent = _FakeAgent()
+
+    seen: list[str] = []
+
+    async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
+        seen.append("base")
+
+    agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)
+
+    restore = install_for_request(
+        agent, ctx=ctx, agent_name="alan", conversation_id="conv-1"
+    )
+
+    msg = PromptMessageExtended(
+        role="assistant",
+        content=[TextContent(type="text", text="hi")],
+        stop_reason=LlmStopReason.END_TURN,
+    )
+
+    _run(agent.tool_runner_hooks.after_llm_call(None, msg))
+
+    # Base ran, and the assistant-stream emitter shipped the chunk.
+    assert seen == ["base"]
+    assert len(ctx.session.calls) == 1
+    assert ctx.session.calls[0]["data"]["content"] == [
+        {"type": "text", "text": "hi"}
+    ]
+
+    # Other hook slots stay untouched.
+    assert agent.tool_runner_hooks.before_llm_call is None
+    assert agent.tool_runner_hooks.before_tool_call is None
+    assert agent.tool_runner_hooks.after_tool_call is None
+    assert agent.tool_runner_hooks.after_turn_complete is None
+
+    # Restore puts the original hooks back exactly.
+    restore()
+    assert agent.tool_runner_hooks is not None
+    assert agent.tool_runner_hooks.after_llm_call is base_after
+
+
+def test_install_for_request_with_no_existing_hooks() -> None:
+    """When the agent has no prior hooks, ours is installed cleanly."""
+    ctx = _FakeContext()
+    agent = _FakeAgent()
+    assert agent.tool_runner_hooks is None
+
+    restore = install_for_request(
+        agent, ctx=ctx, agent_name="alan", conversation_id=None
+    )
+
+    assert agent.tool_runner_hooks is not None
+    assert agent.tool_runner_hooks.after_llm_call is not None
+
+    restore()
+    assert agent.tool_runner_hooks is None