Files
pallas/pallas/multimodal_server.py
Robert Helewka 440f7fb60c feat: add per-agent loop safeguards for tool-call turns
Introduce three optional per-agent config fields to bound tool-call loop
execution: `max_iterations` (default 15), `streaming_timeout` (default
120s), and `turn_timeout` (default 300s wall-clock).

- Plumb limits from agent config through `_build_agents_table` and
  `_start_agent` into `MultimodalAgentMCPServer` via `request_limits`
- Apply `max_iterations` and `streaming_timeout` to `RequestParams`
- Wrap turn dispatch in `asyncio.wait_for` to enforce `turn_timeout`,
  logging a warning on timeout
- Document the new fields in README
2026-05-27 05:41:08 -04:00

350 lines
15 KiB
Python

"""
MultimodalAgentMCPServer — AgentMCPServer subclass with image support.
Overrides register_agent_tools to:
* accept an optional ``images`` parameter on each agent's ``send_message``
tool so callers can attach base64-encoded images alongside the text,
* accept an optional ``history`` parameter (list of role/content dicts)
so callers own conversation state and seed it on every turn,
* accept an optional ``conversation_id`` string that is recorded in
structured logs and progress notification metadata for end-to-end
trace correlation.
Drop-in replacement for AgentMCPServer. When combined with
``instance_scope="request"`` (the Pallas default), this gives a fully
stateless bridge: each MCP ``tools/call`` is handled by a freshly-created
fast-agent instance whose ``message_history`` is seeded from the caller's
``history`` argument — no cross-conversation bleed, no process-lifetime
memory, no restart amnesia.
"""
import asyncio
import time
from typing import Any
import fast_agent.core.prompt
from fast_agent.core.logging.logger import get_logger
from fast_agent.mcp.server import AgentMCPServer
from fast_agent.types import PromptMessageExtended, RequestParams
from pallas.progress import EnrichedMCPToolProgressManager
from pallas import metrics as _pallas_metrics
from fastmcp import Context as MCPContext
from fastmcp.prompts import Message
from mcp.types import ImageContent, TextContent
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from starlette.responses import JSONResponse, Response
# Module-level logger. Call sites in this module pass structured fields
# (``name=``, ``agent=``, ``index=``, ...) as keyword arguments alongside the
# message text.
logger = get_logger(__name__)
def _history_to_fastmcp_messages(
    message_history: list[PromptMessageExtended],
) -> list[Message]:
    """Render an agent's stored multipart history as FastMCP prompt messages.

    The history is first converted with ``Prompt.from_multipart`` and the
    result is then handed to fast-agent's ``convert_to_fastmcp_messages``.
    """
    # Imported at call time rather than at module import time.
    from fast_agent.mcp.prompts.prompt_server import (
        convert_to_fastmcp_messages as _to_fastmcp,
    )

    converted = fast_agent.core.prompt.Prompt.from_multipart(message_history)
    return _to_fastmcp(converted)
def _history_payload_to_multipart(
    history: list[dict] | None,
) -> list[PromptMessageExtended]:
    """Convert the caller-supplied ``history`` argument to PromptMessageExtended.

    Each entry should be a mapping with ``role`` ("user" | "assistant") and
    ``content``. Non-string ``content`` is coerced with ``str()``; ``None``
    and other falsy values become the empty string. An optional ``images``
    list may contain ``{"data": base64, "mime_type": str}`` entries
    (``mimeType`` is accepted as an alias); they are appended to the same
    turn as additional ``ImageContent`` blocks.

    Malformed input is skipped with a warning instead of raising, so a single
    bad row cannot wipe an entire conversation:

    * entries that are not dicts, or whose role is missing/invalid, are dropped;
    * a non-list ``images`` value is ignored;
    * images that are not dicts, or lack non-empty string ``data`` and
      ``mime_type``, are dropped (the rest of the turn is kept).

    Turns left with neither text nor images are dropped silently.

    Returns an empty list when ``history`` is ``None`` or empty.
    """
    if not history:
        return []
    out: list[PromptMessageExtended] = []
    for idx, entry in enumerate(history):
        if not isinstance(entry, dict):
            logger.warning(
                f"history entry {idx} is not a dict; skipping",
                name="history_entry_invalid",
                index=idx,
            )
            continue
        role = entry.get("role")
        if role not in ("user", "assistant"):
            logger.warning(
                f"history entry {idx} has invalid role {role!r}; skipping",
                name="history_entry_invalid_role",
                index=idx,
                role=role,
            )
            continue
        content_text = entry.get("content", "")
        if not isinstance(content_text, str):
            # Coerce rather than drop so e.g. numeric content survives.
            content_text = str(content_text or "")
        blocks: list[Any] = []
        if content_text:
            blocks.append(TextContent(type="text", text=content_text))
        images = entry.get("images") or []
        if not isinstance(images, list):
            logger.warning(
                f"history entry {idx} images is not a list; ignoring",
                name="history_images_invalid",
                index=idx,
            )
            images = []
        for img_idx, img in enumerate(images):
            if isinstance(img, dict):
                data = img.get("data")
                mime = img.get("mime_type") or img.get("mimeType")
            else:
                data = mime = None
            # ImageContent requires string fields; validating here keeps one
            # bad image from raising and aborting the whole history.
            if not (
                isinstance(data, str) and data and isinstance(mime, str) and mime
            ):
                logger.warning(
                    f"history entry {idx} image {img_idx} missing data/mime_type",
                    name="history_image_invalid",
                    index=idx,
                    image_index=img_idx,
                )
                continue
            blocks.append(ImageContent(type="image", data=data, mimeType=mime))
        if not blocks:
            # An empty turn conveys nothing — skip rather than emit a zero-block
            # PromptMessageExtended which the LLM adapter would reject.
            continue
        out.append(PromptMessageExtended(role=role, content=blocks))
    return out
class MultimodalAgentMCPServer(AgentMCPServer):
    """AgentMCPServer with optional image + history support on send_message.

    On top of the base server this subclass:

    * registers a per-agent ``send_message`` tool that accepts ``images``,
      ``history`` and ``conversation_id`` in addition to the text message;
    * adds ``/live``, ``/ready`` and ``/metrics`` HTTP routes;
    * applies optional per-agent ``request_limits`` (``max_iterations``,
      ``streaming_timeout``, ``turn_timeout``) to every turn.
    """

    # Fallbacks used when ``request_limits`` omits a key or sets it to None.
    _DEFAULT_MAX_ITERATIONS = 15
    _DEFAULT_STREAMING_TIMEOUT = 120.0
    _DEFAULT_TURN_TIMEOUT = 300.0

    def __init__(self, *args, request_limits: dict | None = None, **kwargs) -> None:
        """Create the server; other positional/keyword args go to AgentMCPServer.

        Parameters
        ----------
        request_limits:
            Optional mapping with any of ``max_iterations``,
            ``streaming_timeout`` (seconds) and ``turn_timeout`` (seconds of
            wall-clock time per ``send_message`` call). Missing keys, or keys
            set to ``None``, fall back to the class defaults.
        """
        super().__init__(*args, **kwargs)
        self._request_limits = request_limits or {}
        self._register_health_routes()

    def _resolve_request_limit(self, key: str, default: Any) -> Any:
        """Return ``request_limits[key]``, or ``default`` when it is unset.

        ``None`` is treated like a missing key. An optional config field that
        was plumbed through unset must not silently disable a limit; for
        example, ``asyncio.wait_for(timeout=None)`` never times out.
        """
        value = self._request_limits.get(key)
        return default if value is None else value

    def _register_health_routes(self) -> None:
        """Add /live, /ready, and /metrics to this agent's HTTP server.

        Uses FastMCP's custom_route decorator — the same mechanism used by
        fast-agent itself for the root ``/`` info route. HAProxy can health
        check individual agent backends at ``/ready``.
        """

        @self.mcp_server.custom_route("/live", methods=["GET"])
        async def live(request):
            return JSONResponse({"status": "alive"})

        @self.mcp_server.custom_route("/ready", methods=["GET"])
        async def ready(request):
            return JSONResponse({"status": "ready"})

        @self.mcp_server.custom_route("/metrics", methods=["GET"])
        async def metrics(request):
            # Serve the process-global Pallas registry so this per-agent
            # endpoint exposes the same snapshot as the deployment-wide
            # registry endpoint. Useful when scraping a single agent
            # directly (e.g. behind HAProxy per-backend).
            data = generate_latest(_pallas_metrics.REGISTRY)
            return Response(content=data, media_type=CONTENT_TYPE_LATEST)

    @staticmethod
    def _build_send_payload(
        message: str, images: list[dict] | None
    ) -> str | PromptMessageExtended:
        """Build the ``agent.send`` payload from the turn's text and images.

        Without images the plain text is returned unchanged. Otherwise a
        single user message is built: the text block followed by one
        ``ImageContent`` per image. Each image must be a mapping with
        non-empty ``data`` and ``mime_type`` (``mimeType`` is accepted as an
        alias, matching the images inside ``history``).

        Raises
        ------
        ValueError
            If an image entry is not a mapping or lacks data / mime type.
        """
        if not images:
            return message
        content: list[Any] = [TextContent(type="text", text=message)]
        for img_idx, img in enumerate(images):
            if not isinstance(img, dict):
                raise ValueError(
                    f"images[{img_idx}] must be an object with 'data' and 'mime_type'"
                )
            data = img.get("data")
            mime = img.get("mime_type") or img.get("mimeType")
            if not data or not mime:
                raise ValueError(
                    f"images[{img_idx}] is missing 'data' or 'mime_type'"
                )
            content.append(ImageContent(type="image", data=data, mimeType=mime))
        return PromptMessageExtended(role="user", content=content)

    @staticmethod
    def _record_send_metrics(
        agent_name: str, agent: Any, started: float, outcome: str
    ) -> None:
        """Record token usage, duration and outcome for one send_message call.

        Must run before the instance is released: a request-scoped instance
        is torn down inside ``_release_instance`` and its usage accumulator
        goes with it. ``agent`` may be ``None`` if the agent lookup failed.
        """
        try:
            accumulator = getattr(agent, "usage_accumulator", None)
            _pallas_metrics.record_usage(agent_name, accumulator)
        except Exception as exc:
            # Usage metrics are best-effort; never fail a turn over them,
            # but leave a trace instead of swallowing silently.
            logger.debug(
                f"Failed to record token usage for agent '{agent_name}'",
                name="usage_metrics_error",
                agent=agent_name,
                error=repr(exc),
            )
        _pallas_metrics.send_message_duration_seconds.labels(
            agent=agent_name
        ).observe(time.perf_counter() - started)
        _pallas_metrics.send_message_total.labels(
            agent=agent_name, outcome=outcome
        ).inc()

    def register_agent_tools(self, agent_name: str) -> None:
        """Register a send_message tool that accepts text + optional images + history.

        Also registers an ``{agent}_history`` prompt. With
        ``instance_scope == "request"`` that prompt always returns ``[]``,
        because the caller owns history.
        """
        self._registered_agents.add(agent_name)
        tool_description = (
            self._tool_description.format(agent=agent_name)
            if self._tool_description and "{agent}" in self._tool_description
            else self._tool_description
        )
        agent_obj = self.primary_instance.agents.get(agent_name)
        agent_description = None
        if agent_obj is not None:
            config = getattr(agent_obj, "config", None)
            agent_description = getattr(config, "description", None)
        tool_name = self._tool_name_template.format(agent=agent_name)

        @self.mcp_server.tool(
            name=tool_name,
            description=tool_description
            or agent_description
            or f"Send a message to the {agent_name} agent",
        )
        async def send_message(
            message: str,
            ctx: MCPContext,
            images: list[dict] | None = None,
            history: list[dict] | None = None,
            conversation_id: str | None = None,
        ) -> str:
            """Send a single turn to the agent.

            Parameters
            ----------
            message:
                The new user turn, plain text.
            images:
                Optional list of ``{"data": base64, "mime_type": str}`` image
                attachments sent with this turn (``mimeType`` is accepted as
                an alias). Requires a vision-capable model. A malformed entry
                raises ``ValueError``.
            history:
                Optional prior conversation history as a list of
                ``{"role": "user"|"assistant", "content": str, "images": [...]}``
                entries in chronological order. When provided, seeds the
                freshly-created agent's ``message_history`` before executing
                the new turn. Pallas never persists this — the caller
                (typically Daedalus) owns conversation state.
            conversation_id:
                Optional opaque identifier, logged for trace correlation.
                Pallas does not interpret it.
            """
            report_progress = self._build_progress_reporter(ctx)
            request_params = RequestParams(
                tool_execution_handler=EnrichedMCPToolProgressManager(report_progress),
                emit_loop_progress=True,
                max_iterations=self._resolve_request_limit(
                    "max_iterations", self._DEFAULT_MAX_ITERATIONS
                ),
                streaming_timeout=self._resolve_request_limit(
                    "streaming_timeout", self._DEFAULT_STREAMING_TIMEOUT
                ),
            )
            turn_timeout = self._resolve_request_limit(
                "turn_timeout", self._DEFAULT_TURN_TIMEOUT
            )
            instance = await self._acquire_instance(ctx)
            # Bound before the try so the finally block can reference it even
            # when the agent lookup itself fails.
            agent = None
            metrics_start = time.perf_counter()
            metrics_outcome = "ok"
            try:
                # Looked up inside the try so a failure still releases the
                # acquired instance.
                agent = instance.app[agent_name]
                agent_context = getattr(agent, "context", None)
                # Validate/build the new turn before touching history, so bad
                # input cannot leave a shared agent with replaced history.
                payload = self._build_send_payload(message, images)
                # Seed the instance's message_history from caller-supplied
                # history so the agent sees the conversation the caller is
                # tracking. load_message_history replaces existing history,
                # so with a "shared" instance scope this overwrites the
                # shared agent's state — callers should only pass history to
                # "request"-scoped agents. With an empty/absent history this
                # is skipped, so shared-mode deployments keep today's
                # behaviour.
                history_count = 0
                if history:
                    seeded = _history_payload_to_multipart(history)
                    if seeded:
                        agent.load_message_history(seeded)
                        history_count = len(seeded)

                async def execute_send() -> str:
                    start = time.perf_counter()
                    logger.debug(
                        f"MCP request received for agent '{agent_name}'",
                        name="mcp_request_start",
                        agent=agent_name,
                        session=self._session_identifier(ctx),
                        conversation_id=conversation_id,
                        history_count=history_count,
                        image_count=len(images) if images else 0,
                    )
                    response = await agent.send(payload, request_params=request_params)
                    duration = time.perf_counter() - start
                    logger.debug(
                        f"Agent '{agent_name}' completed MCP request",
                        name="mcp_request_complete",
                        agent=agent_name,
                        duration=duration,
                        session=self._session_identifier(ctx),
                        conversation_id=conversation_id,
                    )
                    return response

                async def _dispatch() -> str:
                    if agent_context and ctx:
                        return await self.with_bridged_context(
                            agent_context, ctx, execute_send
                        )
                    return await execute_send()

                try:
                    return await asyncio.wait_for(_dispatch(), timeout=turn_timeout)
                except asyncio.TimeoutError:
                    logger.warning(
                        f"Agent '{agent_name}' turn exceeded {turn_timeout}s wall-clock limit",
                        name="turn_timeout",
                        agent=agent_name,
                        turn_timeout=turn_timeout,
                        conversation_id=conversation_id,
                    )
                    raise
            except BaseException:
                metrics_outcome = "error"
                raise
            finally:
                try:
                    self._record_send_metrics(
                        agent_name, agent, metrics_start, metrics_outcome
                    )
                finally:
                    # Always hand the instance back, even if metrics fail.
                    await self._release_instance(ctx, instance)

        if self._instance_scope == "request":
            # With request-scoped instances there is no persistent server-side
            # history to expose — the caller owns it. We still register the
            # prompt so clients that query `{agent}_history` get a well-formed
            # empty response rather than a 404, but it always returns [].
            @self.mcp_server.prompt(
                name=f"{agent_name}_history",
                description=(
                    f"Conversation history for the {agent_name} agent "
                    "(always empty — Pallas is stateless; the caller owns history)"
                ),
            )
            async def get_history_prompt_stateless(ctx: MCPContext) -> list[Message]:
                return []

            return

        @self.mcp_server.prompt(
            name=f"{agent_name}_history",
            description=f"Conversation history for the {agent_name} agent",
        )
        async def get_history_prompt(ctx: MCPContext) -> list[Message]:
            instance = await self._acquire_instance(ctx)
            try:
                # Inside the try so a failed lookup still releases the instance.
                agent = instance.app[agent_name]
                multipart_history = agent.message_history
                if not multipart_history:
                    return []
                return _history_to_fastmcp_messages(multipart_history)
            finally:
                await self._release_instance(ctx, instance, reuse_connection=True)