Files
pallas/pallas/health.py
Robert Helewka 0cea5ece3a feat: add /healthz and /metrics endpoints, replace print with logging
- Add /healthz endpoint returning LLM provider validation status
- Add /metrics endpoint serving Prometheus metrics via prometheus_client
- Replace all print() calls in health.py with proper logging module
- Remove _PREFIX variable in favor of structured logger context
2026-04-10 11:22:26 +00:00

323 lines
12 KiB
Python

"""
Health check module for Pallas.
Probes downstream MCP server connectivity and exposes a get_health MCP tool.
Validates LLM provider API keys and model availability at startup.
"""
import asyncio
import json
import logging
import os
import re
from datetime import datetime, timezone
from pathlib import Path
import httpx
import yaml
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def _config_root() -> Path:
"""Return the working directory where agents.yaml and fastagent configs live."""
return Path.cwd()
def _load_deployment_name() -> str:
    """Read the deployment name from agents.yaml.

    The file name can be overridden with the PALLAS_AGENTS_CONFIG environment
    variable; it is resolved relative to the config root.

    Returns:
        The top-level ``name`` value as a string, or ``"pallas"`` when the
        file is missing, is not a YAML mapping, or has no non-empty ``name``.
    """
    default_name = "pallas"
    config_path = _config_root() / os.environ.get("PALLAS_AGENTS_CONFIG", "agents.yaml")
    if not config_path.exists():
        return default_name
    data = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    # An empty file parses to None and a list-rooted file has no .get().
    # This function is called at import time, so fall back to the default
    # instead of failing the import.
    if not isinstance(data, dict):
        return default_name
    name = data.get("name")
    # `name:` with no value parses to None; never return a non-string.
    return str(name) if name else default_name
# Deployment name, resolved once at import time. Used to build the clientInfo
# name sent by the downstream health probe.
_DEPLOY_NAME = _load_deployment_name()
# ── Provider API endpoints ───────────────────────────────────────────────────
# Base URL for Anthropic model lookups.
_ANTHROPIC_API = "https://api.anthropic.com/v1"
# Fallback OpenAI base URL used when none is configured in secrets, config or
# the OPENAI_BASE_URL environment variable.
_OPENAI_DEFAULT_API = "https://api.openai.com/v1"
# Populated by validate_llm_providers() at startup, read by get_health()
# Maps provider name -> result dict with "status" and, depending on the
# outcome, "model", "models" and/or "message".
_llm_status: dict[str, dict] = {}
# Provider prefix parsed from default_model; "" when there is none.
_active_provider: str = ""
def _load_dotenv() -> None:
    """Load the .env file from the config root into os.environ.

    Variables already present in the environment are never overwritten.
    Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
    A leading ``export `` on the key and one pair of matching quotes around
    the value are removed, so ``export KEY="value"`` sets ``KEY`` to
    ``value``. Does nothing when the file does not exist.
    """
    env_path = _config_root() / ".env"
    if not env_path.exists():
        return
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Accept files that are also written to be `source`d by a shell.
        key = key.strip().removeprefix("export ").strip()
        value = value.strip()
        # Surrounding quotes are .env syntax, not part of the value; keeping
        # them would corrupt secrets written as KEY="...".
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        if key and key not in os.environ:
            os.environ[key] = value
def _expand_env(value: str) -> str:
"""Replace ${VAR} placeholders with environment variable values."""
return re.sub(
r"\$\{([^}]+)\}",
lambda m: os.environ.get(m.group(1), ""),
value,
)
def _load_config() -> tuple[dict, dict]:
    """Load the fastagent config and secrets YAML from the config root.

    Returns:
        ``(config, secrets)``. Each is ``{}`` when its file is empty; secrets
        is also ``{}`` when fastagent.secrets.yaml does not exist.

    Raises:
        OSError: if fastagent.config.yaml is missing or cannot be read.
    """
    root = _config_root()
    config_text = (root / "fastagent.config.yaml").read_text(encoding="utf-8")
    config = yaml.safe_load(config_text) or {}

    secrets: dict = {}
    secrets_path = root / "fastagent.secrets.yaml"
    if secrets_path.exists():
        # An empty or comment-only file parses to None. Normalise it the same
        # way as config so callers can always call .get() on the result.
        secrets = yaml.safe_load(secrets_path.read_text(encoding="utf-8")) or {}
    return config, secrets
async def _check_anthropic(client: httpx.AsyncClient, api_key: str, model_id: str) -> str | None:
    """Look up *model_id* on the Anthropic API to verify the key and model.

    Returns:
        None when the lookup answers 200, otherwise a short error message.
    """
    request_headers = {
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
    }
    try:
        resp = await client.get(f"{_ANTHROPIC_API}/models/{model_id}", headers=request_headers)
    except Exception as exc:
        # Only the exception class name is reported, not its message.
        return f"API unreachable ({type(exc).__name__})"

    status = resp.status_code
    if status == 200:
        return None
    if status == 404:
        return f"model '{model_id}' not found"
    return f"API request failed ({status})"
async def _check_openai(
client: httpx.AsyncClient, api_key: str, model_id: str, base_url: str
) -> str | None:
"""Validate an OpenAI-compatible model. Returns None on success, error message on failure."""
try:
resp = await client.get(
f"{base_url.rstrip('/')}/models/{model_id}",
headers={"Authorization": f"Bearer {api_key}"},
)
except Exception as exc:
return f"API unreachable ({type(exc).__name__})"
if resp.status_code == 200:
return None
if resp.status_code == 404:
return f"model '{model_id}' not found"
return f"API request failed ({resp.status_code})"
async def _list_openai_models(
client: httpx.AsyncClient, api_key: str, base_url: str
) -> tuple[str | None, list[str]]:
"""List models from an OpenAI-compatible API. Returns (error, model_ids)."""
try:
resp = await client.get(
f"{base_url.rstrip('/')}/models",
headers={"Authorization": f"Bearer {api_key}"},
)
except Exception as exc:
return f"API unreachable ({type(exc).__name__})", []
if resp.status_code != 200:
return f"API request failed ({resp.status_code})", []
data = resp.json()
models = [m["id"] for m in data.get("data", []) if "id" in m]
return None, models
async def validate_llm_providers(timeout: float = 5.0) -> dict[str, dict]:
    """
    Validate configured LLM provider API keys and model availability.

    Reads fastagent.config.yaml for default_model ("provider.model-name") and
    fastagent.secrets.yaml for API keys, falling back to the
    ANTHROPIC_API_KEY, OPENAI_API_KEY and OPENAI_BASE_URL environment
    variables. Every provider that has a key is checked; the active provider
    is also reported when its key is missing.

    Args:
        timeout: Timeout in seconds for the shared HTTP client.

    Returns:
        A dict keyed by provider name ("anthropic", "openai"). Each value has
        "status" ("ok" or "error") and, depending on the outcome, "model",
        "models" and/or "message".

    Side effects:
        Replaces the contents of the module-level _llm_status and sets
        _active_provider.
    """
    global _active_provider

    _load_dotenv()
    config, secrets = _load_config()
    # An empty YAML document parses to None; treat anything that is not a
    # mapping as "nothing configured" instead of failing on .get().
    config = config if isinstance(config, dict) else {}
    secrets = secrets if isinstance(secrets, dict) else {}

    def _setting(source: dict, provider: str, key: str) -> str:
        """Return source[provider][key] as a string, or "" if absent or null."""
        section = source.get(provider)
        value = section.get(key) if isinstance(section, dict) else None
        return str(value) if value else ""

    default_model = str(config.get("default_model") or "")
    # Parse provider and model from "provider.model-name" format.
    # NOTE(review): a bare model name that itself contains a dot (no provider
    # prefix) is split here too; confirm default_model is always prefixed.
    provider_part, dot, model_part = default_model.partition(".")
    active_provider = provider_part if dot else ""
    active_model = model_part if dot else default_model

    # Resolve API keys from secrets (expanding ${ENV_VAR} references), falling
    # back to env vars directly so that .env alone is sufficient.
    anthropic_key = (
        _expand_env(_setting(secrets, "anthropic", "api_key"))
        or os.environ.get("ANTHROPIC_API_KEY", "")
    )
    openai_key = (
        _expand_env(_setting(secrets, "openai", "api_key"))
        or os.environ.get("OPENAI_API_KEY", "")
    )
    openai_base = (
        _expand_env(_setting(secrets, "openai", "base_url"))
        or _setting(config, "openai", "base_url")
        or os.environ.get("OPENAI_BASE_URL", "")
        or _OPENAI_DEFAULT_API
    )

    results: dict[str, dict] = {}
    async with httpx.AsyncClient(timeout=timeout) as client:
        # ── Anthropic ────────────────────────────────────────────────────
        if anthropic_key:
            model_id = active_model if active_provider == "anthropic" else None
            if model_id:
                err = await _check_anthropic(client, anthropic_key, model_id)
                if err:
                    results["anthropic"] = {"status": "error", "model": model_id, "message": err}
                    logger.warning("anthropic: %s", err)
                else:
                    results["anthropic"] = {"status": "ok", "model": model_id}
                    logger.info("anthropic: %s ready", model_id)
            else:
                # Key is set but Anthropic isn't the active provider — just
                # verify API access. A "not found" answer still proves the key
                # was accepted, so only other errors count as failures.
                err = await _check_anthropic(client, anthropic_key, "claude-sonnet-4-5")
                if err and "not found" not in err:
                    results["anthropic"] = {"status": "error", "message": err}
                    logger.warning("anthropic: %s", err)
                else:
                    results["anthropic"] = {"status": "ok"}
                    logger.info("anthropic: API key valid")
        elif active_provider == "anthropic":
            results["anthropic"] = {"status": "error", "message": "API key not configured"}
            logger.warning("anthropic: API key not configured")

        # ── OpenAI ───────────────────────────────────────────────────────
        if openai_key:
            model_id = active_model if active_provider == "openai" else None
            err, models = await _list_openai_models(client, openai_key, openai_base)
            if err:
                results["openai"] = {"status": "error", "message": err}
                logger.warning("openai (%s): %s", openai_base, err)
            elif model_id:
                if model_id in models:
                    results["openai"] = {"status": "ok", "model": model_id}
                    logger.info("openai (%s): %s ready", openai_base, model_id)
                else:
                    label = ", ".join(models) if models else "none"
                    results["openai"] = {
                        "status": "error",
                        "model": model_id,
                        "message": f"model '{model_id}' not found (available: {label})",
                    }
                    logger.warning(
                        "openai (%s): model '%s' not found (available: %s)",
                        openai_base, model_id, label,
                    )
            else:
                results["openai"] = {"status": "ok", "models": models}
                label = ", ".join(models) if models else "no models loaded"
                logger.info("openai (%s): %s", openai_base, label)
        elif active_provider == "openai":
            results["openai"] = {"status": "error", "message": "API key not configured"}
            logger.warning("openai: API key not configured")

    # Publish the outcome for get_health(); mutate in place so existing
    # references to _llm_status see the new results.
    _llm_status.clear()
    _llm_status.update(results)
    _active_provider = active_provider
    return results
async def check_downstream_health(
    servers: dict[str, dict], timeout: float = 3.0
) -> dict:
    """
    Probe downstream MCP servers and return aggregate health status.

    Each server is sent an MCP ``initialize`` request; all probes run
    concurrently. A server counts as healthy when the request completes with
    an HTTP status below 400.

    Args:
        servers: Mapping of server name to {"url": str, "headers": dict}.
            Headers may contain ${ENV_VAR} placeholders which are expanded
            before the request is sent. Missing or null headers are treated
            as no headers.
        timeout: Per-request timeout in seconds.

    Returns:
        {"status": "ok"|"degraded", "timestamp": "...", "message": "..."}
        "message" is present only when at least one server failed.
    """
    _load_dotenv()

    async def _probe(
        client: httpx.AsyncClient, name: str, cfg: dict
    ) -> tuple[str, bool, str]:
        """Probe one server; return (name, ok, failure_reason)."""
        try:
            # Config parsing stays inside the try so that one malformed entry
            # (e.g. `headers:` with no value) degrades only that server
            # instead of aborting the whole gather.
            url = cfg.get("url", "")
            raw_headers = cfg.get("headers") or {}
            headers = {k: _expand_env(str(v)) for k, v in raw_headers.items()}
            common_headers = {
                "Accept": "application/json, text/event-stream",
                "Content-Type": "application/json",
                **headers,
            }
            resp = await client.post(
                url,
                headers=common_headers,
                json={
                    "jsonrpc": "2.0",
                    "method": "initialize",
                    "id": 1,
                    "params": {
                        "protocolVersion": "2025-03-26",
                        "capabilities": {},
                        "clientInfo": {
                            "name": f"{_DEPLOY_NAME}-health",
                            "version": "1.0.0",
                        },
                    },
                },
            )
            if resp.status_code >= 400:
                return name, False, f"HTTP {resp.status_code}"
            # Tear down the session so we don't leak server-side state.
            session_id = resp.headers.get("mcp-session-id")
            if session_id:
                try:
                    await client.delete(
                        url,
                        headers={**headers, "mcp-session-id": session_id},
                    )
                except Exception:
                    # Best-effort cleanup: the probe itself succeeded, so a
                    # failed teardown must not mark the server unhealthy.
                    logger.debug("%s: session teardown failed", name, exc_info=True)
            return name, True, ""
        except Exception as exc:
            return name, False, type(exc).__name__

    async with httpx.AsyncClient(timeout=timeout) as client:
        results = await asyncio.gather(
            *(_probe(client, name, cfg) for name, cfg in servers.items())
        )

    now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    # Sorted so the message is stable regardless of probe completion order.
    failures = sorted(
        (f"{name} ({reason})" if reason else name)
        for name, ok, reason in results
        if not ok
    )
    if not failures:
        return {"status": "ok", "timestamp": now}
    return {
        "status": "degraded",
        "timestamp": now,
        "message": f"Unreachable: {', '.join(failures)}",
    }
def register_health_tool(mcp_server, servers: dict[str, dict]) -> None:
    """Attach a ``get_health`` tool to *mcp_server*.

    The tool probes *servers* with check_downstream_health() and folds in the
    active LLM provider's startup validation result, returning the combined
    report as a JSON string.
    """
    as_tool = mcp_server.tool(
        name="get_health",
        description="Returns the health status of this agent and its downstream dependencies.",
    )

    async def get_health() -> str:
        report = await check_downstream_health(servers)

        # Only the active provider's preflight result affects the report.
        # NOTE(review): an active provider with no entry in _llm_status is
        # also reported as degraded (message "error") — confirm this is
        # intended for providers the startup validation does not check.
        provider = _active_provider
        provider_state = _llm_status.get(provider, {})
        llm_ok = provider_state.get("status") == "ok"
        if provider and not llm_ok:
            err_msg = f"LLM: {provider}: {provider_state.get('message', 'error')}"
            report["status"] = "degraded"
            previous = report.get("message", "")
            if previous:
                report["message"] = f"{previous}; {err_msg}"
            else:
                report["message"] = err_msg
        return json.dumps(report)

    as_tool(get_health)