Compare commits
2 Commits
440f7fb60c
...
63ced9ba2b
| Author | SHA1 | Date | |
|---|---|---|---|
| 63ced9ba2b | |||
| 8a5046fef0 |
246
pallas/assistant_stream.py
Normal file
246
pallas/assistant_stream.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
"""
|
||||||
|
Mid-turn assistant chunk streaming over MCP.
|
||||||
|
|
||||||
|
The MCP ``tools/call`` contract returns a single ``CallToolResult`` at end of
|
||||||
|
turn. When fast-agent runs a multi-iteration tool loop, every intermediate
|
||||||
|
assistant message — including substantive natural-language replies the LLM
|
||||||
|
emits before deciding to call a tool — stays inside fast-agent's
|
||||||
|
``message_history`` and never crosses the MCP boundary. The caller (Daedalus)
|
||||||
|
sees only the final ``last_text()``, which is often a thin wrap-up sentence
|
||||||
|
while the substantive answer was produced two iterations earlier.
|
||||||
|
|
||||||
|
This module fixes that by emitting one MCP ``notifications/message``
|
||||||
|
(``LoggingMessageNotification``) per assistant turn, carrying the structured
|
||||||
|
content blocks as opaque JSON in the ``data`` field. The transport is
|
||||||
|
already plumbed through StreamableHTTP — the SDK delivers it to the client's
|
||||||
|
``logging_callback`` — so no SDK changes are needed on either side.
|
||||||
|
|
||||||
|
The hook installs onto fast-agent's ``ToolRunnerHooks.after_llm_call``,
|
||||||
|
which runs once per LLM iteration in the tool loop. Pallas merges this hook
|
||||||
|
with whatever hooks the agent (or fast-agent itself) already has via the
|
||||||
|
existing ``_merge_tool_runner_hooks`` plumbing — see ``multimodal_server.py``
|
||||||
|
``_install_assistant_stream_hook`` for the merge-and-restore wiring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from fast_agent.agents.tool_runner import ToolRunnerHooks
|
||||||
|
from fast_agent.types import PromptMessageExtended
|
||||||
|
from mcp.types import ImageContent, TextContent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastmcp import Context as MCPContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Logger string the client uses to filter our notifications from any other
|
||||||
|
# log messages the server might emit. Stays narrow on both sides.
|
||||||
|
LOGGER_NAME = "pallas.assistant_stream"
|
||||||
|
|
||||||
|
# Discriminator inside ``data`` so this transport stays multiplexable later.
|
||||||
|
KIND = "assistant_chunk"
|
||||||
|
|
||||||
|
# Bumped if the on-wire payload shape changes incompatibly. Daedalus reads
|
||||||
|
# this and ignores chunks whose schema_version it doesn't recognise.
|
||||||
|
SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantChunkEmitter:
|
||||||
|
"""Per-request emitter that ships each assistant turn over MCP.
|
||||||
|
|
||||||
|
One instance per ``send_message`` MCP call. Tracks an iteration counter
|
||||||
|
so consumers can order chunks within a single turn (and ignore stale
|
||||||
|
ones from an aborted retry).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ctx: "MCPContext",
|
||||||
|
*,
|
||||||
|
agent_name: str,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._ctx = ctx
|
||||||
|
self._agent_name = agent_name
|
||||||
|
self._conversation_id = conversation_id
|
||||||
|
self._iteration = 0
|
||||||
|
|
||||||
|
def as_after_llm_call_hook(self):
|
||||||
|
"""Return an async callable suitable for ``ToolRunnerHooks.after_llm_call``."""
|
||||||
|
|
||||||
|
async def _hook(_runner: Any, message: PromptMessageExtended) -> None:
|
||||||
|
await self._emit(message)
|
||||||
|
|
||||||
|
return _hook
|
||||||
|
|
||||||
|
async def _emit(self, message: PromptMessageExtended) -> None:
|
||||||
|
"""Serialize one assistant turn and ship it as a log notification.
|
||||||
|
|
||||||
|
Failures here must never break the agent turn — wrap everything and
|
||||||
|
log a warning. The user-facing consequence of a dropped chunk is
|
||||||
|
"the live bubble didn't update once"; the canonical final message
|
||||||
|
still arrives via the tools/call response.
|
||||||
|
"""
|
||||||
|
self._iteration += 1
|
||||||
|
try:
|
||||||
|
payload = self._build_payload(message)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"assistant_stream payload build failed",
|
||||||
|
exc_info=True,
|
||||||
|
extra={
|
||||||
|
"agent": self._agent_name,
|
||||||
|
"conversation_id": self._conversation_id,
|
||||||
|
"iteration": self._iteration,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip empty turns (e.g. a pure tool-only iteration with no text).
|
||||||
|
# The tool-call lifecycle is already surfaced via
|
||||||
|
# notifications/progress, so an empty assistant_chunk would just be
|
||||||
|
# noise on the wire and the live bubble.
|
||||||
|
if not payload["content"] and not payload["tool_calls"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = self._ctx.session
|
||||||
|
related_request_id = getattr(self._ctx, "request_id", None)
|
||||||
|
await session.send_log_message(
|
||||||
|
level="info",
|
||||||
|
data=payload,
|
||||||
|
logger=LOGGER_NAME,
|
||||||
|
related_request_id=related_request_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"assistant_stream send_log_message failed",
|
||||||
|
exc_info=True,
|
||||||
|
extra={
|
||||||
|
"agent": self._agent_name,
|
||||||
|
"conversation_id": self._conversation_id,
|
||||||
|
"iteration": self._iteration,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_payload(self, message: PromptMessageExtended) -> dict[str, Any]:
|
||||||
|
content_blocks: list[dict[str, Any]] = []
|
||||||
|
for block in message.content or []:
|
||||||
|
if isinstance(block, TextContent) and block.text:
|
||||||
|
content_blocks.append({"type": "text", "text": block.text})
|
||||||
|
elif isinstance(block, ImageContent):
|
||||||
|
# Images on assistant turns are unusual but legal. We send
|
||||||
|
# the raw base64 — same shape the final tools/call result
|
||||||
|
# already uses so the client can render it identically.
|
||||||
|
content_blocks.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": block.data,
|
||||||
|
"mime_type": block.mimeType,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Other block types (resources, embedded resources, etc.) are
|
||||||
|
# intentionally omitted from the live stream — they're rare on
|
||||||
|
# assistant turns and the canonical final message carries them.
|
||||||
|
|
||||||
|
tool_calls: list[dict[str, Any]] = []
|
||||||
|
for call_id, call in (message.tool_calls or {}).items():
|
||||||
|
params = getattr(call, "params", None)
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": call_id,
|
||||||
|
"name": getattr(params, "name", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_reason = message.stop_reason
|
||||||
|
if stop_reason is not None:
|
||||||
|
# LlmStopReason is a str-Enum; .value is the wire form fast-agent
|
||||||
|
# already uses ("toolUse", "endTurn", ...).
|
||||||
|
stop_reason_str = getattr(stop_reason, "value", str(stop_reason))
|
||||||
|
else:
|
||||||
|
stop_reason_str = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"kind": KIND,
|
||||||
|
"schema_version": SCHEMA_VERSION,
|
||||||
|
"agent": self._agent_name,
|
||||||
|
"conversation_id": self._conversation_id,
|
||||||
|
"iteration": self._iteration,
|
||||||
|
"stop_reason": stop_reason_str,
|
||||||
|
"content": content_blocks,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def install_for_request(
|
||||||
|
agent: Any,
|
||||||
|
*,
|
||||||
|
ctx: "MCPContext",
|
||||||
|
agent_name: str,
|
||||||
|
conversation_id: str | None,
|
||||||
|
):
|
||||||
|
"""Install the assistant-stream hook on a request-scoped agent instance.
|
||||||
|
|
||||||
|
Returns a no-arg callable that restores the agent's previous
|
||||||
|
``tool_runner_hooks`` — call it in a ``finally`` so per-request hooks
|
||||||
|
don't leak across requests if the underlying instance is ever reused.
|
||||||
|
|
||||||
|
Pallas runs ``instance_scope="request"`` so in normal operation the
|
||||||
|
instance is disposed immediately after; the restore call is defensive
|
||||||
|
against config changes or future shared-instance modes.
|
||||||
|
"""
|
||||||
|
emitter = AssistantChunkEmitter(
|
||||||
|
ctx, agent_name=agent_name, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
extra = ToolRunnerHooks(after_llm_call=emitter.as_after_llm_call_hook())
|
||||||
|
|
||||||
|
previous = getattr(agent, "tool_runner_hooks", None)
|
||||||
|
merged = _merge_after_llm_call(previous, extra)
|
||||||
|
agent.tool_runner_hooks = merged
|
||||||
|
|
||||||
|
def restore() -> None:
|
||||||
|
agent.tool_runner_hooks = previous
|
||||||
|
|
||||||
|
return restore
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_after_llm_call(
|
||||||
|
base: ToolRunnerHooks | None, extra: ToolRunnerHooks
|
||||||
|
) -> ToolRunnerHooks:
|
||||||
|
"""Compose ``extra.after_llm_call`` after ``base.after_llm_call``.
|
||||||
|
|
||||||
|
Mirrors fast-agent's own ``ToolAgent._merge_tool_runner_hooks`` (private)
|
||||||
|
so we don't depend on its signature. Only the ``after_llm_call`` slot
|
||||||
|
actually gets composed here — the others are passed through unchanged
|
||||||
|
from ``base`` since ``extra`` only ever sets ``after_llm_call``.
|
||||||
|
"""
|
||||||
|
if base is None:
|
||||||
|
return extra
|
||||||
|
|
||||||
|
base_hook = base.after_llm_call
|
||||||
|
extra_hook = extra.after_llm_call
|
||||||
|
|
||||||
|
if base_hook is None:
|
||||||
|
merged_after = extra_hook
|
||||||
|
elif extra_hook is None: # pragma: no cover - extra always sets it
|
||||||
|
merged_after = base_hook
|
||||||
|
else:
|
||||||
|
|
||||||
|
async def merged(runner: Any, message: PromptMessageExtended) -> None:
|
||||||
|
await base_hook(runner, message)
|
||||||
|
await extra_hook(runner, message)
|
||||||
|
|
||||||
|
merged_after = merged
|
||||||
|
|
||||||
|
return ToolRunnerHooks(
|
||||||
|
before_llm_call=base.before_llm_call,
|
||||||
|
after_llm_call=merged_after,
|
||||||
|
before_tool_call=base.before_tool_call,
|
||||||
|
after_tool_call=base.after_tool_call,
|
||||||
|
after_turn_complete=base.after_turn_complete,
|
||||||
|
)
|
||||||
@@ -28,6 +28,7 @@ from fast_agent.core.logging.logger import get_logger
|
|||||||
from fast_agent.mcp.server import AgentMCPServer
|
from fast_agent.mcp.server import AgentMCPServer
|
||||||
from fast_agent.types import PromptMessageExtended, RequestParams
|
from fast_agent.types import PromptMessageExtended, RequestParams
|
||||||
|
|
||||||
|
from pallas.assistant_stream import install_for_request as _install_assistant_stream
|
||||||
from pallas.progress import EnrichedMCPToolProgressManager
|
from pallas.progress import EnrichedMCPToolProgressManager
|
||||||
from pallas import metrics as _pallas_metrics
|
from pallas import metrics as _pallas_metrics
|
||||||
from fastmcp import Context as MCPContext
|
from fastmcp import Context as MCPContext
|
||||||
@@ -220,6 +221,20 @@ class MultimodalAgentMCPServer(AgentMCPServer):
|
|||||||
agent_context = getattr(agent, "context", None)
|
agent_context = getattr(agent, "context", None)
|
||||||
metrics_start = time.perf_counter()
|
metrics_start = time.perf_counter()
|
||||||
metrics_outcome = "ok"
|
metrics_outcome = "ok"
|
||||||
|
|
||||||
|
# Install per-request after_llm_call hook that ships every
|
||||||
|
# intermediate assistant turn over MCP as a notifications/message.
|
||||||
|
# Without this, only the final ``agent.send()`` return value
|
||||||
|
# crosses the MCP boundary — substantive assistant text emitted
|
||||||
|
# in earlier loop iterations stays trapped inside fast-agent's
|
||||||
|
# ``message_history`` and the user sees a spinner that ends with
|
||||||
|
# a thin wrap-up sentence.
|
||||||
|
restore_hooks = _install_assistant_stream(
|
||||||
|
agent,
|
||||||
|
ctx=ctx,
|
||||||
|
agent_name=agent_name,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# Seed the freshly-created instance's message_history from the
|
# Seed the freshly-created instance's message_history from the
|
||||||
# caller-supplied history so the agent sees the full
|
# caller-supplied history so the agent sees the full
|
||||||
@@ -313,6 +328,14 @@ class MultimodalAgentMCPServer(AgentMCPServer):
|
|||||||
_pallas_metrics.send_message_total.labels(
|
_pallas_metrics.send_message_total.labels(
|
||||||
agent=agent_name, outcome=metrics_outcome
|
agent=agent_name, outcome=metrics_outcome
|
||||||
).inc()
|
).inc()
|
||||||
|
# Restore the agent's prior tool_runner_hooks before the
|
||||||
|
# instance is released — defensive against any future
|
||||||
|
# shared-instance mode where leaking per-request hooks
|
||||||
|
# across requests would mis-attribute notifications.
|
||||||
|
try:
|
||||||
|
restore_hooks()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
await self._release_instance(ctx, instance)
|
await self._release_instance(ctx, instance)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
298
tests/test_assistant_stream.py
Normal file
298
tests/test_assistant_stream.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""Tests for ``pallas.assistant_stream``.
|
||||||
|
|
||||||
|
Drives the ``after_llm_call`` hook with handcrafted ``PromptMessageExtended``
|
||||||
|
objects and asserts the resulting MCP ``send_log_message`` payload shape.
|
||||||
|
No fast-agent runtime is involved — the hook is a pure async function and
|
||||||
|
the MCP context is faked.
|
||||||
|
|
||||||
|
Tests use ``asyncio.run`` directly to match the convention in
|
||||||
|
``tests/test_health.py`` and ``tests/test_mantle_shims.py`` (pallas has no
|
||||||
|
pytest-asyncio dependency).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fast_agent.agents.tool_runner import ToolRunnerHooks
|
||||||
|
from fast_agent.types import PromptMessageExtended
|
||||||
|
from fast_agent.types.llm_stop_reason import LlmStopReason
|
||||||
|
from mcp.types import (
|
||||||
|
CallToolRequest,
|
||||||
|
CallToolRequestParams,
|
||||||
|
ImageContent,
|
||||||
|
TextContent,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pallas.assistant_stream import (
|
||||||
|
KIND,
|
||||||
|
LOGGER_NAME,
|
||||||
|
SCHEMA_VERSION,
|
||||||
|
AssistantChunkEmitter,
|
||||||
|
install_for_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fakes ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Records every call to ``send_log_message`` for later assertion."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict[str, Any]] = []
|
||||||
|
self.fail_with: Exception | None = None
|
||||||
|
|
||||||
|
async def send_log_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
level: str,
|
||||||
|
data: Any,
|
||||||
|
logger: str | None = None,
|
||||||
|
related_request_id: Any = None,
|
||||||
|
) -> None:
|
||||||
|
if self.fail_with is not None:
|
||||||
|
raise self.fail_with
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"level": level,
|
||||||
|
"data": data,
|
||||||
|
"logger": logger,
|
||||||
|
"related_request_id": related_request_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeContext:
|
||||||
|
def __init__(self, request_id: str = "req-1") -> None:
|
||||||
|
self.session = _FakeSession()
|
||||||
|
self.request_id = request_id
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgent:
|
||||||
|
"""Minimal stand-in for a fast-agent agent — only carries hooks."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.tool_runner_hooks: ToolRunnerHooks | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call(call_id: str, name: str, arguments: dict | None = None) -> CallToolRequest:
|
||||||
|
return CallToolRequest(
|
||||||
|
method="tools/call",
|
||||||
|
params=CallToolRequestParams(name=name, arguments=arguments or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_text_only_iteration() -> None:
|
||||||
|
"""A pure-text assistant turn produces one log message with one text block."""
|
||||||
|
ctx = _FakeContext(request_id="req-text")
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="Fair. You can't size value...")],
|
||||||
|
stop_reason=LlmStopReason.END_TURN,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run(emitter._emit(msg))
|
||||||
|
|
||||||
|
assert len(ctx.session.calls) == 1
|
||||||
|
call = ctx.session.calls[0]
|
||||||
|
assert call["level"] == "info"
|
||||||
|
assert call["logger"] == LOGGER_NAME
|
||||||
|
assert call["related_request_id"] == "req-text"
|
||||||
|
|
||||||
|
data = call["data"]
|
||||||
|
assert data["kind"] == KIND
|
||||||
|
assert data["schema_version"] == SCHEMA_VERSION
|
||||||
|
assert data["agent"] == "alan"
|
||||||
|
assert data["conversation_id"] == "conv-1"
|
||||||
|
assert data["iteration"] == 1
|
||||||
|
assert data["stop_reason"] == "endTurn"
|
||||||
|
assert data["content"] == [{"type": "text", "text": "Fair. You can't size value..."}]
|
||||||
|
assert data["tool_calls"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_text_with_tool_call_iteration() -> None:
|
||||||
|
"""An assistant turn that emits text and then calls a tool ships both."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="Logging this…")],
|
||||||
|
tool_calls={
|
||||||
|
"toolu_1": _tool_call("toolu_1", "time__get_current_time"),
|
||||||
|
},
|
||||||
|
stop_reason=LlmStopReason.TOOL_USE,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run(emitter._emit(msg))
|
||||||
|
|
||||||
|
assert len(ctx.session.calls) == 1
|
||||||
|
data = ctx.session.calls[0]["data"]
|
||||||
|
assert data["stop_reason"] == "toolUse"
|
||||||
|
assert data["content"] == [{"type": "text", "text": "Logging this…"}]
|
||||||
|
assert data["tool_calls"] == [{"id": "toolu_1", "name": "time__get_current_time"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_skips_completely_empty_iteration() -> None:
|
||||||
|
"""A turn with no text blocks and no tool calls emits nothing.
|
||||||
|
|
||||||
|
Tool-call lifecycle is already covered by notifications/progress. An
|
||||||
|
empty assistant_chunk would just be noise on the wire and a no-op
|
||||||
|
update on the live bubble.
|
||||||
|
"""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(role="assistant", content=[])
|
||||||
|
|
||||||
|
_run(emitter._emit(msg))
|
||||||
|
|
||||||
|
assert ctx.session.calls == []
|
||||||
|
# Iteration counter still bumps so subsequent chunks aren't mis-numbered.
|
||||||
|
assert emitter._iteration == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_image_block_passes_through_with_mime_type_renamed() -> None:
|
||||||
|
"""ImageContent blocks are serialized with ``mime_type`` (snake-case)."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id=None)
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[ImageContent(type="image", data="ZmFrZQ==", mimeType="image/png")],
|
||||||
|
stop_reason=LlmStopReason.END_TURN,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run(emitter._emit(msg))
|
||||||
|
|
||||||
|
data = ctx.session.calls[0]["data"]
|
||||||
|
assert data["content"] == [
|
||||||
|
{"type": "image", "data": "ZmFrZQ==", "mime_type": "image/png"}
|
||||||
|
]
|
||||||
|
assert data["conversation_id"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_iterations_are_numbered_in_order() -> None:
|
||||||
|
"""A multi-iteration loop produces sequentially numbered chunks."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
|
||||||
|
|
||||||
|
async def drive() -> None:
|
||||||
|
await emitter._emit(
|
||||||
|
PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="first")],
|
||||||
|
stop_reason=LlmStopReason.TOOL_USE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await emitter._emit(
|
||||||
|
PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="second")],
|
||||||
|
stop_reason=LlmStopReason.TOOL_USE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await emitter._emit(
|
||||||
|
PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="third — done")],
|
||||||
|
stop_reason=LlmStopReason.END_TURN,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_run(drive())
|
||||||
|
|
||||||
|
assert [c["data"]["iteration"] for c in ctx.session.calls] == [1, 2, 3]
|
||||||
|
assert [c["data"]["stop_reason"] for c in ctx.session.calls] == [
|
||||||
|
"toolUse",
|
||||||
|
"toolUse",
|
||||||
|
"endTurn",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_emit_swallows_session_failure() -> None:
|
||||||
|
"""If ``send_log_message`` raises, the hook does not propagate."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
ctx.session.fail_with = RuntimeError("transport closed")
|
||||||
|
emitter = AssistantChunkEmitter(ctx, agent_name="alan", conversation_id="conv-1")
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="hi")],
|
||||||
|
stop_reason=LlmStopReason.END_TURN,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Must not raise.
|
||||||
|
_run(emitter._emit(msg))
|
||||||
|
# And no successful calls were recorded (fail_with raised before append).
|
||||||
|
assert ctx.session.calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_install_for_request_merges_with_existing_after_llm_call() -> None:
|
||||||
|
"""A pre-existing ``after_llm_call`` hook is composed, not replaced."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
agent = _FakeAgent()
|
||||||
|
|
||||||
|
seen: list[str] = []
|
||||||
|
|
||||||
|
async def base_after(_runner: Any, message: PromptMessageExtended) -> None:
|
||||||
|
seen.append("base")
|
||||||
|
|
||||||
|
agent.tool_runner_hooks = ToolRunnerHooks(after_llm_call=base_after)
|
||||||
|
|
||||||
|
restore = install_for_request(
|
||||||
|
agent, ctx=ctx, agent_name="alan", conversation_id="conv-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = PromptMessageExtended(
|
||||||
|
role="assistant",
|
||||||
|
content=[TextContent(type="text", text="hi")],
|
||||||
|
stop_reason=LlmStopReason.END_TURN,
|
||||||
|
)
|
||||||
|
|
||||||
|
_run(agent.tool_runner_hooks.after_llm_call(None, msg))
|
||||||
|
|
||||||
|
# Base ran, and the assistant-stream emitter shipped the chunk.
|
||||||
|
assert seen == ["base"]
|
||||||
|
assert len(ctx.session.calls) == 1
|
||||||
|
assert ctx.session.calls[0]["data"]["content"] == [
|
||||||
|
{"type": "text", "text": "hi"}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Other hook slots stay untouched.
|
||||||
|
assert agent.tool_runner_hooks.before_llm_call is None
|
||||||
|
assert agent.tool_runner_hooks.before_tool_call is None
|
||||||
|
assert agent.tool_runner_hooks.after_tool_call is None
|
||||||
|
assert agent.tool_runner_hooks.after_turn_complete is None
|
||||||
|
|
||||||
|
# Restore puts the original hooks back exactly.
|
||||||
|
restore()
|
||||||
|
assert agent.tool_runner_hooks is not None
|
||||||
|
assert agent.tool_runner_hooks.after_llm_call is base_after
|
||||||
|
|
||||||
|
|
||||||
|
def test_install_for_request_with_no_existing_hooks() -> None:
|
||||||
|
"""When the agent has no prior hooks, ours is installed cleanly."""
|
||||||
|
ctx = _FakeContext()
|
||||||
|
agent = _FakeAgent()
|
||||||
|
assert agent.tool_runner_hooks is None
|
||||||
|
|
||||||
|
restore = install_for_request(
|
||||||
|
agent, ctx=ctx, agent_name="alan", conversation_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.tool_runner_hooks is not None
|
||||||
|
assert agent.tool_runner_hooks.after_llm_call is not None
|
||||||
|
|
||||||
|
restore()
|
||||||
|
assert agent.tool_runner_hooks is None
|
||||||
Reference in New Issue
Block a user