diff --git a/pallas/multimodal_server.py b/pallas/multimodal_server.py index 001a88f..a0d9ce8 100644 --- a/pallas/multimodal_server.py +++ b/pallas/multimodal_server.py @@ -23,8 +23,9 @@ import fast_agent.core.prompt from fast_agent.core.logging.logger import get_logger from fast_agent.mcp.auth.context import request_bearer_token from fast_agent.mcp.server import AgentMCPServer -from fast_agent.mcp.tool_progress import MCPToolProgressManager from fast_agent.types import PromptMessageExtended, RequestParams + +from pallas.progress import EnrichedMCPToolProgressManager from fastmcp import Context as MCPContext from fastmcp.prompts import Message from mcp.types import ImageContent, TextContent @@ -117,7 +118,7 @@ class MultimodalAgentMCPServer(AgentMCPServer): saved_token = request_bearer_token.set(_get_request_bearer_token()) report_progress = self._build_progress_reporter(ctx) request_params = RequestParams( - tool_execution_handler=MCPToolProgressManager(report_progress), + tool_execution_handler=EnrichedMCPToolProgressManager(report_progress), emit_loop_progress=True, ) try: diff --git a/pallas/progress.py b/pallas/progress.py new file mode 100644 index 0000000..97d81a8 --- /dev/null +++ b/pallas/progress.py @@ -0,0 +1,149 @@ +""" +Enriched MCP progress reporting for Pallas agents. + +Wraps fast-agent's ``MCPToolProgressManager`` so that progress notifications +sent over MCP carry semantic detail — argument previews on tool start, +result summaries on tool completion — instead of bare ``started`` / +``completed`` markers. + +The MCP ``notifications/progress`` payload only has room for a string +``message``, so all extra detail is packed into that one field. Daedalus +renders it verbatim, which is enough for the operator to tell whether an +agent is making real progress or spinning. +""" + +from __future__ import annotations + +import json +from typing import Any + +from fast_agent.mcp.tool_progress import MCPToolProgressManager +from mcp.types import ContentBlock, ImageContent, TextContent + + +_MAX_ARGS_PREVIEW = 120 +_MAX_RESULT_PREVIEW = 160 + + +def _stringify(value: Any) -> str: + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False, default=str) + except (TypeError, ValueError): + return str(value) + + +def _truncate(text: str, limit: int) -> str: + text = " ".join(text.split()) + if len(text) <= limit: + return text + return text[: limit - 1].rstrip() + "…" + + +def format_args_preview(arguments: dict | None) -> str: + """Render a one-line preview of tool arguments for progress messages. + + Picks the most informative value when there are multiple keys: the + longest string wins (queries, prompts, code), otherwise falls back + to a compact ``key=value, ...`` rendering. + """ + if not arguments: + return "" + + string_values = [ + (k, v) for k, v in arguments.items() if isinstance(v, str) and v.strip() + ] + if string_values: + key, value = max(string_values, key=lambda kv: len(kv[1])) + return _truncate(f"{key}={value}" if len(arguments) > 1 else value, _MAX_ARGS_PREVIEW) + + parts = [f"{k}={_stringify(v)}" for k, v in arguments.items()] + return _truncate(", ".join(parts), _MAX_ARGS_PREVIEW) + + +def format_result_preview(content: list[ContentBlock] | None) -> str: + """Render a one-line preview of a tool's result content.""" + if not content: + return "ok" + + text_parts: list[str] = [] + image_count = 0 + other_count = 0 + for block in content: + if isinstance(block, TextContent) and block.text: + text_parts.append(block.text) + elif isinstance(block, ImageContent): + image_count += 1 + else: + other_count += 1 + + if text_parts: + joined = " ".join(text_parts) + preview = _truncate(joined, _MAX_RESULT_PREVIEW) + suffix_bits = [] + if image_count: + suffix_bits.append(f"+{image_count} image{'s' if image_count != 1 else ''}") + if other_count: + suffix_bits.append(f"+{other_count} block{'s' if other_count != 1 else ''}") + if suffix_bits: + return f"{preview} ({', '.join(suffix_bits)})" + return preview + + if image_count or other_count: + bits = [] + if image_count: + bits.append(f"{image_count} image{'s' if image_count != 1 else ''}") + if other_count: + bits.append(f"{other_count} block{'s' if other_count != 1 else ''}") + return ", ".join(bits) + + return "ok" + + +class EnrichedMCPToolProgressManager(MCPToolProgressManager): + """Pallas progress manager that adds arg/result previews to MCP messages. + + Drop-in replacement for ``MCPToolProgressManager``. Overrides only the + label and start/complete messaging — progress and permission events + keep the parent's behaviour. + """ + + async def on_tool_start( + self, + tool_name: str, + server_name: str, + arguments: dict | None, + tool_use_id: str | None = None, + ) -> str: + tool_call_id = await super().on_tool_start( + tool_name, server_name, arguments, tool_use_id + ) + preview = format_args_preview(arguments) + if preview: + label = self._tool_labels.get(tool_call_id, f"{server_name}/{tool_name}") + self._tool_labels[tool_call_id] = f"{label}({preview})" + return tool_call_id + + async def on_tool_complete( + self, + tool_call_id: str, + success: bool, + content: list[ContentBlock] | None, + error: str | None, + ) -> None: + if success: + await self._report_progress( + 1.0, + 1.0, + self._format_message( + tool_call_id, message=f"→ {format_result_preview(content)}" + ), + ) + self._tool_labels.pop(tool_call_id, None) + tool_use_ids = self._tool_use_by_call.pop(tool_call_id, set()) + for tool_use_id in tool_use_ids: + self._tool_use_map.pop(tool_use_id, None) + return + + await super().on_tool_complete(tool_call_id, success, content, error)