docs(pallas): document sampling parameters and Prometheus metrics

Add two new sections to the Pallas documentation:

- Sampling parameters: explain that temperature/top_p/top_k are
  configured via the fast-agent decorator's `request_params`, with a
  provider support matrix and a note on Claude Opus 4.7 stripping these
  params in favor of `output_config.effort`.
- Metrics: document the Prometheus `/metrics` endpoint exposed on the
  registry port, including scrape config, full metrics reference table,
  and notes on where each metric is captured.
2026-05-23 07:49:21 -04:00
parent 6fcdb509df
commit ca7d714a31
8 changed files with 545 additions and 39 deletions

View File

@@ -216,6 +216,40 @@ model_capabilities:
Capabilities are declared explicitly rather than inferred from the model name: naming conventions vary across model families, which makes regex heuristics brittle. These values are used both to register unknown models with fast-agent's `ModelDatabase` and to populate the registry response.
### Sampling parameters (temperature, top_p, top_k)
Sampling parameters are configured per-agent in the Python decorator, **not** in `agents.yaml` or `fastagent.config.yaml`. Pallas itself does no sampling-param handling — this is pure fast-agent decorator-side configuration.
```python
from fast_agent import FastAgent
from fast_agent.types import RequestParams

fast = FastAgent("Jeffrey", parse_cli_args=False)


@fast.agent(
    name="jeffrey",
    instruction="...",
    servers=[...],
    request_params=RequestParams(temperature=0.6, top_p=0.9),
)
async def _jeffrey():
    pass
```
Provider support varies:
| Provider | temperature | top_p | top_k |
|---|---|---|---|
| OpenAI (native, Responses API) | yes | yes | no |
| HuggingFace, OpenResponses (OpenAI-compatible) | yes | yes | yes (via `extra_body`) |
| Google Gemini | yes | yes | yes |
| Bedrock | yes | yes (most models) | varies |
| **Anthropic Claude Opus 4.7** | **no** | **no** | **no** |
Anthropic's 4.7 design moves away from low-level numeric dials toward adaptive control — fast-agent's Anthropic provider explicitly strips temperature/top_p/top_k for Opus 4.7 with a warning (see `fast_agent/llm/provider/anthropic/llm_anthropic.py:1776-1786`). On Opus 4.7, use `output_config.effort` (verbosity, including the new `xhigh` level between `high` and `max`) instead.
Setting `request_params` on an Opus 4.7 agent is a safe no-op: the parameters are dropped with a warning, and they take effect automatically as soon as the agent is routed to a model that accepts them.
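The stripping behaviour described above can be pictured with a small sketch. This is not fast-agent's code: the function name, the model-id prefix `claude-opus-4-7`, and the plain-dict parameter shape are all assumptions made for illustration; the real check lives in fast-agent's Anthropic provider.

```python
import warnings

# Hypothetical model-id prefixes that reject sampling dials (assumption).
MODELS_WITHOUT_SAMPLING = ("claude-opus-4-7",)

SAMPLING_KEYS = ("temperature", "top_p", "top_k")


def effective_sampling_params(model: str, params: dict) -> dict:
    """Return the request params that would actually reach the provider."""
    if not model.startswith(MODELS_WITHOUT_SAMPLING):
        # Model accepts sampling dials: pass everything through.
        return dict(params)
    dropped = sorted(k for k in SAMPLING_KEYS if k in params)
    if dropped:
        warnings.warn(f"{model}: ignoring unsupported sampling params {dropped}")
    # Keep every non-sampling key (for example max_tokens) untouched.
    return {k: v for k, v in params.items() if k not in SAMPLING_KEYS}
```

The same `params` dict can therefore stay on the agent regardless of which model it is routed to; only the sampling keys are filtered, and only for the models in the hypothetical deny-list.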
### `fastagent.secrets.yaml`
```yaml ```yaml
@@ -496,6 +530,95 @@ Registered on each agent's MCP server. Checks:
---
## Metrics
Pallas exposes Prometheus metrics for scraping and alerting. One scrape target per Pallas deployment is sufficient — all agents run as coroutines in a single process under `asyncio.gather`, so metrics are process-global.
### Endpoint
```
GET {host}:{registry_port}/metrics
```
Plain HTTP, unauthenticated, served by the same Starlette app that hosts the registry. Returns Prometheus text exposition format (`text/plain; version=0.0.4`).
The same metrics snapshot is also available on each agent's own port at `{host}:{agent_port}/metrics`. Scraping the registry endpoint is the recommended default; the per-agent endpoints exist for setups that scrape a single agent directly, for example behind a load balancer with per-backend routing.
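For a quick check without a Prometheus server, the text exposition format is simple enough to read by hand. The sketch below is a minimal parser, assuming samples carry no trailing timestamp; the sample payload, including the agent name and port `8201`, is made up for illustration.

```python
def parse_samples(text: str) -> dict[str, float]:
    """Map each series (name plus label set) to its value."""
    samples: dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines and the "# HELP" / "# TYPE" comment lines.
        if not line or line.startswith("#"):
            continue
        # The value is the last space-separated field on the line.
        series, _, value = line.rpartition(" ")
        samples[series] = float(value)
    return samples


# Hypothetical response body from GET /metrics.
SAMPLE = """\
# HELP pallas_up 1 when the Pallas process is running
# TYPE pallas_up gauge
pallas_up 1.0
pallas_agent_info{agent="jeffrey",port="8201"} 1.0
"""

samples = parse_samples(SAMPLE)
```

In practice `SAMPLE` would be the body returned by the endpoint, fetched with any HTTP client.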
### Scrape Config
```yaml
scrape_configs:
  - job_name: pallas
    static_configs:
      - targets: ['my-host.example.com:8200']  # registry_port
        labels:
          deployment: my-project
```
### Metrics Reference
| Metric | Type | Labels | Description |
|---|---|---|---|
| `pallas_up` | gauge | — | `1` while the Pallas process is running |
| `pallas_agent_info` | gauge | `agent`, `port` | `1` per configured agent — useful as a label join source |
| `pallas_send_message_total` | counter | `agent`, `outcome` | `send_message` MCP calls. `outcome` is `ok` or `error` |
| `pallas_send_message_duration_seconds` | histogram | `agent` | End-to-end MCP `send_message` wall-clock duration |
| `pallas_llm_turns_total` | counter | `agent`, `model` | LLM provider round-trips per agent/model |
| `pallas_llm_tokens_total` | counter | `agent`, `model`, `kind` | Tokens consumed. `kind` is one of `input`/`output`/`cache_read`/`cache_write`/`cache_hit`/`reasoning` |
| `pallas_tool_calls_total` | counter | `agent`, `server`, `operation`, `outcome` | Downstream MCP operations dispatched by fast-agent's aggregator. `operation` is the fast-agent operation type (`tool`, `prompt`, `resource`, …); `outcome` is `ok` or `error` |
| `pallas_tool_call_duration_seconds` | histogram | `agent`, `server`, `operation` | Downstream MCP operation duration |
| `pallas_downstream_up` | gauge | `agent`, `server` | `1` when the named downstream MCP server passed the last `get_health` probe |
| `pallas_llm_provider_up` | gauge | `provider` | `1` when the active LLM provider passed its last preflight or runtime re-probe |
| `pallas_agent_health_status` | gauge | `agent` | Aggregate from the last `get_health`: `1`=ok, `0.5`=degraded, `0`=error |
Standard process metrics (RSS, CPU, GC, open FDs) are emitted by `prometheus-client`'s default collectors on the same endpoint.
### Where the Numbers Come From
- **send_message metrics** — captured around the MCP `send_message` handler in `pallas.multimodal_server`. The duration spans the full agentic loop, including all sub-agent and tool-call latency.
- **LLM token metrics** — read from fast-agent's `UsageAccumulator` on the request-scoped agent instance *before disposal*. Each request's accumulator is fresh, so every recorded turn is genuinely new — no double-counting across requests.
- **Downstream tool call metrics** — recorded in the `pallas._fastagent_patch` wrapper around `MCPAggregator._execute_on_server`. This catches every dispatch (tools, prompts, resources) and is independent of which downstream server it lands on. Failures still surface in the counter as `outcome="error"` and full tracebacks remain in `pallas.forward.trace` log records.
- **Health gauges** — updated as a side effect of every `get_health` MCP call. Daedalus's polling cadence (default 60 s) therefore drives gauge freshness. The LLM gauge is also set at startup preflight and on the TTL re-probe inside `get_health`.
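The capture points above share one pattern: wrap the awaited call, decide the outcome in `except`, and record in `finally` so that failures are measured too. A minimal sketch of that pattern follows, assuming a plain list as the sink instead of Prometheus collectors; `instrument`, `recorded`, and the two demo coroutines are hypothetical names.

```python
import asyncio
import time

# Stand-in for the Prometheus counter and histogram: (operation, outcome, seconds).
recorded: list[tuple[str, str, float]] = []


def instrument(operation: str, fn):
    """Wrap an async callable so every call is timed and classified."""
    async def wrapper(*args, **kwargs):
        start = time.perf_counter()
        outcome = "ok"
        try:
            return await fn(*args, **kwargs)
        except BaseException:
            outcome = "error"
            raise  # never swallow the original failure
        finally:
            # Runs on success and on failure, so both are recorded.
            recorded.append((operation, outcome, time.perf_counter() - start))
    return wrapper


async def _ok():
    return "done"


async def _boom():
    raise RuntimeError("downstream failed")


async def demo():
    result = await instrument("tool", _ok)()
    try:
        await instrument("tool", _boom)()
    except RuntimeError:
        pass  # the wrapper re-raised; the error was still recorded
    return result
```

Calling `asyncio.run(demo())` leaves one `ok` and one `error` entry in `recorded`, each with its own duration.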
### Useful Queries
```promql
# Error rate per agent
sum by (agent) (rate(pallas_send_message_total{outcome="error"}[5m]))
  / sum by (agent) (rate(pallas_send_message_total[5m]))

# p95 send_message latency per agent
histogram_quantile(0.95,
  sum by (agent, le) (rate(pallas_send_message_duration_seconds_bucket[5m]))
)

# Token spend per model (1h)
sum by (model, kind) (rate(pallas_llm_tokens_total[1h]))

# Cache hit ratio (Anthropic)
sum(rate(pallas_llm_tokens_total{kind="cache_read"}[5m]))
  / sum(rate(pallas_llm_tokens_total{kind=~"input|cache_read|cache_write"}[5m]))

# Any downstream MCP server unreachable
min by (server) (pallas_downstream_up) == 0

# Active LLM provider down
pallas_llm_provider_up == 0
```
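The cache hit ratio query is plain arithmetic over three `kind` values of `pallas_llm_tokens_total`. As a worked example, the sketch below applies the same formula to token counts; the function name and the numbers are illustrative assumptions, not measured data.

```python
def cache_hit_ratio(input_tokens: float, cache_read: float, cache_write: float) -> float:
    """Share of prompt-side tokens that were served from cache."""
    total = input_tokens + cache_read + cache_write
    # Guard against an idle window with no tokens at all.
    return cache_read / total if total else 0.0


# Hypothetical 5-minute window: 200 uncached input tokens,
# 700 tokens read from cache, 100 tokens written to cache.
ratio = cache_hit_ratio(input_tokens=200, cache_read=700, cache_write=100)
```

With these numbers `ratio` is `0.7`: 700 of the 1000 prompt-side tokens came from cache.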
### Suggested Alerts
| Alert | Expression | Notes |
|---|---|---|
| Pallas process down | `up{job="pallas"} == 0` for 1m | Scrape failure |
| Active LLM unreachable | `pallas_llm_provider_up == 0` for 5m | Preflight or TTL re-probe failing |
| Downstream MCP unreachable | `pallas_downstream_up == 0` for 10m | Per-server; gauge updates on each `get_health` |
| Agent error rate elevated | `sum by (agent) (rate(pallas_send_message_total{outcome="error"}[10m])) / sum by (agent) (rate(pallas_send_message_total[10m])) > 0.1` | >10% errors over 10 min |
| Latency regression | `histogram_quantile(0.95, sum by (agent, le) (rate(pallas_send_message_duration_seconds_bucket[10m]))) > 60` | p95 over 60 s |
| Token burn | `sum(rate(pallas_llm_tokens_total{kind="output"}[1h])) > N` | Set N to your budget |
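When tuning thresholds it helps to keep per-second rates and ratios apart: `rate()` over a counter yields events per second, while a percentage needs one rate divided by another. The sketch below works through both for a single window; the helper names and the counts are hypothetical.

```python
def per_second_rate(delta: float, window_seconds: float) -> float:
    """What rate() reports for a counter that grew by `delta` in the window."""
    return delta / window_seconds


def error_ratio(error_delta: float, total_delta: float) -> float:
    """Fraction of calls that failed; 0.0 when there was no traffic."""
    return error_delta / total_delta if total_delta else 0.0


# Hypothetical 10-minute window: 12 send_message calls, 3 of them failed.
window = 600.0
errors, total = 3.0, 12.0

rate = per_second_rate(errors, window)  # errors per second
ratio = error_ratio(errors, total)      # share of failed calls
```

Here `rate` is `0.005` errors per second while `ratio` is `0.25`, so a threshold of `0.1` means very different things depending on which quantity it is compared against.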
---
## Model Registration
Pallas registers models not in fast-agent's built-in `ModelDatabase` at startup, using the explicit capability declarations from `fastagent.config.yaml`.

View File

@@ -36,10 +36,13 @@ server's ``headers.Authorization`` is what fast-agent sends, full stop.
from __future__ import annotations from __future__ import annotations
import logging import logging
import time
from fast_agent.mcp import mcp_aggregator as _magg from fast_agent.mcp import mcp_aggregator as _magg
from fast_agent.mcp import mcp_agent_client_session as _macs from fast_agent.mcp import mcp_agent_client_session as _macs
from pallas import metrics as _pallas_metrics
logger = logging.getLogger("pallas.forward") logger = logging.getLogger("pallas.forward")
_trace_logger = logging.getLogger("pallas.forward.trace") _trace_logger = logging.getLogger("pallas.forward.trace")
@@ -108,13 +111,17 @@ _original_execute_on_server = _magg.MCPAggregator._execute_on_server
async def _execute_on_server_with_trace(self, *args, **kwargs): async def _execute_on_server_with_trace(self, *args, **kwargs):
server_name = args[0] if args else kwargs.get("server_name", "?")
operation_type = (
args[1] if len(args) > 1 else kwargs.get("operation_type", "?")
)
agent_name = getattr(self, "agent_name", "") or "unknown"
start = time.perf_counter()
ok = True
try: try:
return await _original_execute_on_server(self, *args, **kwargs) return await _original_execute_on_server(self, *args, **kwargs)
except BaseException as exc: except BaseException as exc:
server_name = args[0] if args else kwargs.get("server_name", "?") ok = False
operation_type = (
args[1] if len(args) > 1 else kwargs.get("operation_type", "?")
)
operation_name = ( operation_name = (
args[2] if len(args) > 2 else kwargs.get("operation_name", "?") args[2] if len(args) > 2 else kwargs.get("operation_name", "?")
) )
@@ -130,6 +137,17 @@ async def _execute_on_server_with_trace(self, *args, **kwargs):
type(exc).__name__, type(exc).__name__,
) )
raise raise
finally:
try:
_pallas_metrics.record_tool_call(
agent=agent_name,
server=str(server_name),
operation=str(operation_type),
duration_seconds=time.perf_counter() - start,
ok=ok,
)
except Exception:
pass
def _patch_execute_on_server() -> None: def _patch_execute_on_server() -> None:

View File

@@ -486,6 +486,16 @@ async def validate_llm_providers(timeout: float = 5.0) -> dict[str, dict]:
_llm_status.update(results) _llm_status.update(results)
_active_provider = active_provider _active_provider = active_provider
_llm_status_ts = time.monotonic() _llm_status_ts = time.monotonic()
try:
from pallas import metrics as _pallas_metrics
_pallas_metrics.llm_provider_up.labels(provider=active_provider).set(
1.0 if results.get(active_provider, {}).get("status") == "ok" else 0.0
)
except Exception:
pass
return results return results
@@ -576,8 +586,15 @@ async def check_downstream_health(
} }
def register_health_tool(mcp_server, servers: dict[str, dict]) -> None: def register_health_tool(
"""Register a get_health MCP tool on the given FastMCP server instance.""" mcp_server, servers: dict[str, dict], agent_name: str = "unknown"
) -> None:
"""Register a get_health MCP tool on the given FastMCP server instance.
``agent_name`` labels the Prometheus gauges populated as a side effect
of every probe (downstream reachability + overall status).
"""
from pallas import metrics as _pallas_metrics
@mcp_server.tool( @mcp_server.tool(
name="get_health", name="get_health",
@@ -585,10 +602,14 @@ def register_health_tool(mcp_server, servers: dict[str, dict]) -> None:
) )
async def get_health() -> str: async def get_health() -> str:
await _refresh_llm_status_if_stale() await _refresh_llm_status_if_stale()
result = await check_downstream_health(servers) result, per_server_ok = await _check_downstream_with_breakdown(servers)
# Include LLM provider status from startup preflight (active provider only) # Include LLM provider status from startup preflight (active provider only)
llm_for_metrics: dict[str, str] = {}
if _active_provider: if _active_provider:
active = _llm_status.get(_active_provider) active = _llm_status.get(_active_provider)
llm_for_metrics[_active_provider] = (
active.get("status", "error") if active else "error"
)
if active is None: if active is None:
# Should be unreachable after the rewrite (validate_llm_providers # Should be unreachable after the rewrite (validate_llm_providers
# always populates _llm_status for _active_provider). Keep a # always populates _llm_status for _active_provider). Keep a
@@ -608,4 +629,93 @@ def register_health_tool(mcp_server, servers: dict[str, dict]) -> None:
result["status"] = "degraded" result["status"] = "degraded"
existing = result.get("message", "") existing = result.get("message", "")
result["message"] = f"{existing}; {err_msg}" if existing else err_msg result["message"] = f"{existing}; {err_msg}" if existing else err_msg
try:
_pallas_metrics.record_health_probe(
agent_name,
overall_status=result.get("status", "error"),
downstream=per_server_ok,
llm=llm_for_metrics,
)
except Exception:
pass
return json.dumps(result) return json.dumps(result)
async def _check_downstream_with_breakdown(
servers: dict[str, dict], timeout: float = 3.0
) -> tuple[dict, dict[str, bool]]:
"""Like :func:`check_downstream_health` but also returns per-server ok flags.
Kept as a thin wrapper so external callers of ``check_downstream_health``
(if any) stay on the original dict-only contract.
"""
_load_dotenv()
async def _probe(
client: httpx.AsyncClient, name: str, cfg: dict
) -> tuple[str, bool, str]:
url = cfg.get("url", "")
raw_headers = cfg.get("headers", {})
headers = {k: _expand_env(str(v)) for k, v in raw_headers.items()}
try:
common_headers = {
"Accept": "application/json, text/event-stream",
"Content-Type": "application/json",
**headers,
}
resp = await client.post(
url,
headers=common_headers,
json={
"jsonrpc": "2.0",
"method": "initialize",
"id": 1,
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": f"{_DEPLOY_NAME}-health",
"version": "1.0.0",
},
},
},
)
if resp.status_code >= 400:
return name, False, f"HTTP {resp.status_code}"
session_id = resp.headers.get("mcp-session-id")
if session_id:
try:
await client.delete(
url,
headers={**headers, "mcp-session-id": session_id},
)
except Exception:
pass
return name, True, ""
except Exception as exc:
return name, False, type(exc).__name__
async with httpx.AsyncClient(timeout=timeout) as client:
results = await asyncio.gather(
*(_probe(client, name, cfg) for name, cfg in servers.items())
)
now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
per_server_ok = {name: ok for name, ok, _ in results}
failures = sorted(
(f"{name} ({reason})" if reason else name)
for name, ok, reason in results
if not ok
)
if not failures:
return {"status": "ok", "timestamp": now}, per_server_ok
return (
{
"status": "degraded",
"timestamp": now,
"message": f"Unreachable: {', '.join(failures)}",
},
per_server_ok,
)

256
pallas/metrics.py Normal file
View File

@@ -0,0 +1,256 @@
"""Prometheus metrics for Pallas.
All collectors live on a single process-global ``CollectorRegistry`` so any
Pallas HTTP surface — the registry server on ``registry_port`` *or* an
agent's own ``/metrics`` route — exposes the same snapshot. There is one
Pallas process per deployment (all agents are coroutines under
``asyncio.gather``), so a single registry is sufficient and matches the
"one scrape target per deployment" model.
Counters/histograms are updated from three places:
* ``multimodal_server.send_message`` — request duration, token usage
captured before the request-scoped instance is disposed.
* ``_fastagent_patch._execute_on_server`` wrapper — downstream MCP tool
call counters and duration histogram.
* ``health.register_health_tool`` get_health closure — downstream
reachability + LLM provider status gauges, refreshed on every probe.
Static gauges: ``pallas_up`` is set at import time; ``pallas_agent_info`` is
populated once at startup by ``set_agent_info`` from ``agents.yaml``.
"""
from __future__ import annotations
import logging
import time
from contextlib import contextmanager
from typing import Any
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram
logger = logging.getLogger(__name__)
# Single process-wide registry. Importers grab this; do not construct another.
REGISTRY = CollectorRegistry(auto_describe=True)
# ── Static deployment info ───────────────────────────────────────────────────
pallas_up = Gauge(
"pallas_up",
"1 when the Pallas process is running",
registry=REGISTRY,
)
pallas_up.set(1)
pallas_agent_info = Gauge(
"pallas_agent_info",
"Static info about configured Pallas agents (value is always 1)",
labelnames=["agent", "port"],
registry=REGISTRY,
)
# ── send_message (per-turn) ──────────────────────────────────────────────────
send_message_total = Counter(
"pallas_send_message_total",
"Total send_message calls handled, by outcome",
labelnames=["agent", "outcome"], # outcome: ok|error
registry=REGISTRY,
)
send_message_duration_seconds = Histogram(
"pallas_send_message_duration_seconds",
"send_message wall-clock duration in seconds",
labelnames=["agent"],
buckets=(0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 120, 300),
registry=REGISTRY,
)
# ── LLM token usage ──────────────────────────────────────────────────────────
#
# Captured at end-of-turn from the request-scoped agent's UsageAccumulator
# before disposal. Cumulative across the process lifetime.
llm_tokens_total = Counter(
"pallas_llm_tokens_total",
"LLM tokens consumed, by agent/model/kind",
labelnames=["agent", "model", "kind"], # kind: input|output|cache_read|cache_write|cache_hit|reasoning
registry=REGISTRY,
)
llm_turns_total = Counter(
"pallas_llm_turns_total",
"LLM turns (provider round-trips) executed",
labelnames=["agent", "model"],
registry=REGISTRY,
)
# ── Downstream MCP tool calls (fast-agent aggregator wrapper) ────────────────
tool_calls_total = Counter(
"pallas_tool_calls_total",
"Downstream MCP operations dispatched by fast-agent",
labelnames=["agent", "server", "operation", "outcome"], # outcome: ok|error
registry=REGISTRY,
)
tool_call_duration_seconds = Histogram(
"pallas_tool_call_duration_seconds",
"Downstream MCP operation duration in seconds",
labelnames=["agent", "server", "operation"],
buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30),
registry=REGISTRY,
)
# ── Health gauges ────────────────────────────────────────────────────────────
downstream_up = Gauge(
"pallas_downstream_up",
"1 when the named downstream MCP server responded ok on the last probe",
labelnames=["agent", "server"],
registry=REGISTRY,
)
llm_provider_up = Gauge(
"pallas_llm_provider_up",
"1 when the active LLM provider passed its last preflight probe",
labelnames=["provider"],
registry=REGISTRY,
)
agent_health_status = Gauge(
"pallas_agent_health_status",
"Aggregate agent health from the last get_health probe (1=ok, 0.5=degraded, 0=error)",
labelnames=["agent"],
registry=REGISTRY,
)
# ── Helpers ──────────────────────────────────────────────────────────────────
def set_agent_info(agents: dict[str, dict]) -> None:
"""Record the deployment's configured agents (called once at startup)."""
for name, agent in agents.items():
port = agent.get("port")
if port is None:
continue
pallas_agent_info.labels(agent=name, port=str(port)).set(1)
@contextmanager
def time_send_message(agent: str):
"""Time a send_message call and record the outcome counter on exit."""
start = time.perf_counter()
outcome = "ok"
try:
yield
except BaseException:
outcome = "error"
raise
finally:
send_message_duration_seconds.labels(agent=agent).observe(
time.perf_counter() - start
)
send_message_total.labels(agent=agent, outcome=outcome).inc()
def record_usage(agent_name: str, accumulator: Any | None) -> None:
"""Pull token deltas from a fast-agent UsageAccumulator into Prometheus.
Called at end-of-turn before the request-scoped instance is disposed.
Accepts the new turns added during this request only — the accumulator
is freshly created with the instance, so every turn it carries is
"new" from Pallas's perspective.
Defensive: any unexpected shape (no ``turns`` attribute, missing
fields) is logged at debug and skipped. Metrics must never break the
request path.
"""
if accumulator is None:
return
turns = getattr(accumulator, "turns", None) or []
if not turns:
return
for turn in turns:
try:
model = getattr(turn, "model", "") or "unknown"
input_tokens = int(getattr(turn, "input_tokens", 0) or 0)
output_tokens = int(getattr(turn, "output_tokens", 0) or 0)
reasoning_tokens = int(getattr(turn, "reasoning_tokens", 0) or 0)
cache = getattr(turn, "cache_usage", None)
cache_read = int(getattr(cache, "cache_read_tokens", 0) or 0) if cache else 0
cache_write = int(getattr(cache, "cache_write_tokens", 0) or 0) if cache else 0
cache_hit = int(getattr(cache, "cache_hit_tokens", 0) or 0) if cache else 0
labels = {"agent": agent_name, "model": model}
llm_turns_total.labels(**labels).inc()
if input_tokens:
llm_tokens_total.labels(**labels, kind="input").inc(input_tokens)
if output_tokens:
llm_tokens_total.labels(**labels, kind="output").inc(output_tokens)
if cache_read:
llm_tokens_total.labels(**labels, kind="cache_read").inc(cache_read)
if cache_write:
llm_tokens_total.labels(**labels, kind="cache_write").inc(cache_write)
if cache_hit:
llm_tokens_total.labels(**labels, kind="cache_hit").inc(cache_hit)
if reasoning_tokens:
llm_tokens_total.labels(**labels, kind="reasoning").inc(reasoning_tokens)
except Exception as exc:
logger.debug("metrics: skipping malformed turn usage: %s", exc)
def record_tool_call(
agent: str,
server: str,
operation: str,
duration_seconds: float,
ok: bool,
) -> None:
"""Record one downstream MCP operation completion."""
outcome = "ok" if ok else "error"
tool_calls_total.labels(
agent=agent, server=server, operation=operation, outcome=outcome
).inc()
tool_call_duration_seconds.labels(
agent=agent, server=server, operation=operation
).observe(duration_seconds)
_HEALTH_STATUS_MAP = {"ok": 1.0, "degraded": 0.5, "error": 0.0}
def record_health_probe(
agent: str,
*,
overall_status: str,
downstream: dict[str, bool] | None = None,
llm: dict[str, str] | None = None,
) -> None:
"""Update health gauges from a completed get_health probe.
Args:
agent: Agent name (the one that just ran get_health).
overall_status: 'ok'|'degraded'|'error' from the probe result.
downstream: Mapping of server name to reachability (True=ok).
llm: Mapping of provider name to status string ('ok'|'error'|...).
"""
agent_health_status.labels(agent=agent).set(
_HEALTH_STATUS_MAP.get(overall_status, 0.0)
)
if downstream:
for server, ok in downstream.items():
downstream_up.labels(agent=agent, server=server).set(1.0 if ok else 0.0)
if llm:
for provider, status in llm.items():
llm_provider_up.labels(provider=provider).set(
1.0 if status == "ok" else 0.0
)

View File

@@ -28,6 +28,7 @@ from fast_agent.mcp.server import AgentMCPServer
from fast_agent.types import PromptMessageExtended, RequestParams from fast_agent.types import PromptMessageExtended, RequestParams
from pallas.progress import EnrichedMCPToolProgressManager from pallas.progress import EnrichedMCPToolProgressManager
from pallas import metrics as _pallas_metrics
from fastmcp import Context as MCPContext from fastmcp import Context as MCPContext
from fastmcp.prompts import Message from fastmcp.prompts import Message
from mcp.types import ImageContent, TextContent from mcp.types import ImageContent, TextContent
@@ -146,7 +147,11 @@ class MultimodalAgentMCPServer(AgentMCPServer):
@self.mcp_server.custom_route("/metrics", methods=["GET"]) @self.mcp_server.custom_route("/metrics", methods=["GET"])
async def metrics(request): async def metrics(request):
data = generate_latest() # Serve the process-global Pallas registry so this per-agent
# endpoint exposes the same snapshot as the deployment-wide
# registry endpoint. Useful when scraping a single agent
# directly (e.g. behind HAProxy per-backend).
data = generate_latest(_pallas_metrics.REGISTRY)
return Response(content=data, media_type=CONTENT_TYPE_LATEST) return Response(content=data, media_type=CONTENT_TYPE_LATEST)
def register_agent_tools(self, agent_name: str) -> None: def register_agent_tools(self, agent_name: str) -> None:
@@ -209,6 +214,8 @@ class MultimodalAgentMCPServer(AgentMCPServer):
instance = await self._acquire_instance(ctx) instance = await self._acquire_instance(ctx)
agent = instance.app[agent_name] agent = instance.app[agent_name]
agent_context = getattr(agent, "context", None) agent_context = getattr(agent, "context", None)
metrics_start = time.perf_counter()
metrics_outcome = "ok"
try: try:
# Seed the freshly-created instance's message_history from the # Seed the freshly-created instance's message_history from the
# caller-supplied history so the agent sees the full # caller-supplied history so the agent sees the full
@@ -269,7 +276,24 @@ class MultimodalAgentMCPServer(AgentMCPServer):
agent_context, ctx, execute_send agent_context, ctx, execute_send
) )
return await execute_send() return await execute_send()
except BaseException:
metrics_outcome = "error"
raise
finally: finally:
# Capture token usage before disposal — the request-scoped
# instance is torn down inside _release_instance and the
# accumulator goes with it.
try:
accumulator = getattr(agent, "usage_accumulator", None)
_pallas_metrics.record_usage(agent_name, accumulator)
except Exception:
pass
_pallas_metrics.send_message_duration_seconds.labels(
agent=agent_name
).observe(time.perf_counter() - metrics_start)
_pallas_metrics.send_message_total.labels(
agent=agent_name, outcome=metrics_outcome
).inc()
await self._release_instance(ctx, instance) await self._release_instance(ctx, instance)

View File

@@ -19,8 +19,10 @@ from pathlib import Path
import httpx import httpx
import yaml import yaml
from prometheus_client import CONTENT_TYPE_LATEST, CollectorRegistry, Gauge, generate_latest from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from starlette.applications import Starlette from starlette.applications import Starlette
from pallas.metrics import REGISTRY as _metrics_registry, set_agent_info
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Route from starlette.routing import Route
@@ -118,37 +120,10 @@ def _build_registry(config: dict) -> dict:
return {"servers": entries} return {"servers": entries}
# ── Prometheus metrics ────────────────────────────────────────────────────────
_metrics_registry = CollectorRegistry()
_pallas_up = Gauge(
"pallas_up",
"1 when the Pallas registry is running",
registry=_metrics_registry,
)
_pallas_up.set(1)
def _init_agent_metrics(config: dict) -> None:
"""Register per-agent info gauges once at startup."""
agents = config.get("agents", {})
if not agents:
return
agent_info = Gauge(
"pallas_agent_info",
"Static info about configured Pallas agents",
labelnames=["agent", "port"],
registry=_metrics_registry,
)
for name, agent in agents.items():
agent_info.labels(agent=name, port=str(agent["port"])).set(1)
# ── Route handlers ──────────────────────────────────────────────────────────── # ── Route handlers ────────────────────────────────────────────────────────────
_deployment_config = _load_deployment_config() _deployment_config = _load_deployment_config()
_init_agent_metrics(_deployment_config) set_agent_info(_deployment_config.get("agents", {}))
async def server_json(request: Request) -> JSONResponse: async def server_json(request: Request) -> JSONResponse:

View File

@@ -271,7 +271,7 @@ async def _start_agent(name: str, agents: dict[str, dict]) -> None:
downstream_servers = _resolve_downstream_servers(fast_instance) downstream_servers = _resolve_downstream_servers(fast_instance)
_preflight_mcp_servers(name, downstream_servers) _preflight_mcp_servers(name, downstream_servers)
register_health_tool(server.mcp_server, downstream_servers) register_health_tool(server.mcp_server, downstream_servers, agent_name=name)
await server.run_async(transport="http", host="0.0.0.0", port=port) await server.run_async(transport="http", host="0.0.0.0", port=port)

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "pallas-mcp" name = "pallas-mcp"
version = "0.2.1" version = "0.2.2"
description = "FastAgent MCP Bridge — generic runtime for serving FastAgent agents over StreamableHTTP" description = "FastAgent MCP Bridge — generic runtime for serving FastAgent agents over StreamableHTTP"
requires-python = ">=3.13.5" requires-python = ">=3.13.5"
dependencies = [ dependencies = [