feat: support per-agent model and capabilities overrides in agents.yaml

Add optional `model` and `model_capabilities` fields to agent definitions
in agents.yaml, allowing each agent to target a different model/provider
with its own capability parameters (vision, context_window, etc.).

- Refactor `_build_agents_table` to return rich dicts instead of tuples
- Extract `_register_one_model` from `_register_unknown_models` for reuse
- Register per-agent models in addition to the global default_model,
  falling back to top-level model_capabilities when agent-specific ones
  are not provided
- Override `AgentConfig.model` at startup when an agent declares a model
- Thread deployment_config through `_preflight` and `_start_agent`
This commit is contained in:
2026-04-15 13:50:20 -04:00
parent 35cc2143b1
commit 065ce0b0dd

View File

@@ -46,10 +46,23 @@ def _load_deployment_config() -> dict:
return config
def _build_agents_table(config: dict) -> dict[str, tuple[str, int]]:
"""Build {name: (module_path, port)} from agents.yaml."""
def _build_agents_table(config: dict) -> dict[str, dict]:
"""Build {name: {module, port, model?, model_capabilities?}} from agents.yaml.
The ``model`` and ``model_capabilities`` fields are optional. When ``model``
is set, Pallas overrides the agent's ``AgentConfig.model`` at startup so
fast-agent routes that agent to the specified model/provider. When
``model_capabilities`` is set, those capabilities are used to register the
model with fast-agent's ``ModelDatabase`` instead of the top-level defaults
from ``fastagent.config.yaml``.
"""
return {
name: (agent["module"], agent["port"])
name: {
"module": agent["module"],
"port": agent["port"],
"model": agent.get("model"),
"model_capabilities": agent.get("model_capabilities"),
}
for name, agent in config["agents"].items()
}
@@ -109,34 +122,15 @@ def _preflight_mcp_servers(agent_name: str, servers: dict[str, dict]) -> None:
# ── Model registration ────────────────────────────────────────────────────────
def _register_unknown_models() -> None:
"""Register runtime model params for models not in fast-agent's ModelDatabase.
Reads ``default_model`` and ``model_capabilities`` from the active
fastagent.config.yaml. Capabilities (vision, context_window,
max_output_tokens) are declared explicitly in the config rather than
inferred from the model name, since naming conventions vary across
model families.
"""
config_path = _config_root() / "fastagent.config.yaml"
if not config_path.exists():
return
def _register_one_model(model_spec: str, capabilities: dict) -> None:
"""Register a single model with fast-agent's ModelDatabase if unknown."""
from fast_agent.llm.model_database import ModelDatabase, ModelParameters
with open(config_path) as f:
config = yaml.safe_load(f) or {}
default_model = config.get("default_model", "")
if not default_model:
return
model_name = default_model.split(".", 1)[-1] if "." in default_model else default_model
model_name = model_spec.split(".", 1)[-1] if "." in model_spec else model_spec
if ModelDatabase.get_model_params(model_name) is not None:
return
capabilities = config.get("model_capabilities", {})
is_vision = capabilities.get("vision", False)
context_window = capabilities.get("context_window", 131072)
max_output_tokens = capabilities.get("max_output_tokens", 16384)
@@ -158,22 +152,67 @@ def _register_unknown_models() -> None:
)
def _register_unknown_models(deployment_config: dict) -> None:
"""Register runtime model params for models not in fast-agent's ModelDatabase.
Registers the ``default_model`` from ``fastagent.config.yaml`` plus every
per-agent ``model`` declared in ``agents.yaml``. Capabilities are resolved
per model: if the agent carries its own ``model_capabilities`` block, those
take effect; otherwise the top-level ``model_capabilities`` from
``fastagent.config.yaml`` apply.
"""
fastagent_config_path = _config_root() / "fastagent.config.yaml"
if not fastagent_config_path.exists():
return
with open(fastagent_config_path) as f:
fa_config = yaml.safe_load(f) or {}
default_model = fa_config.get("default_model", "")
default_capabilities = fa_config.get("model_capabilities", {})
seen: set[str] = set()
if default_model:
_register_one_model(default_model, default_capabilities)
seen.add(default_model)
for agent_name, agent in deployment_config.get("agents", {}).items():
agent_model = agent.get("model")
if not agent_model or agent_model in seen:
continue
agent_caps = agent.get("model_capabilities") or default_capabilities
_register_one_model(agent_model, agent_caps)
seen.add(agent_model)
# ── Agent lifecycle ───────────────────────────────────────────────────────────
async def _preflight() -> None:
async def _preflight(deployment_config: dict) -> None:
from pallas.health import validate_llm_providers
_register_unknown_models()
_register_unknown_models(deployment_config)
await validate_llm_providers()
async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
async def _start_agent(name: str, agents: dict[str, dict]) -> None:
from pallas.health import register_health_tool
module_path, port = agents[name]
entry = agents[name]
module_path = entry["module"]
port = entry["port"]
model_override = entry.get("model")
module = importlib.import_module(module_path)
fast_instance = module.fast
if model_override:
for agent_data in fast_instance.agents.values():
agent_cfg = agent_data.get("config")
if agent_cfg is not None:
agent_cfg.model = model_override
logger.info("%s model override → %s", name, model_override)
logger.info("Starting %s agent on port %d", name, port)
async with fast_instance.run():
@@ -198,12 +237,12 @@ async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
async def _wait_for_agent(
name: str,
agents: dict[str, tuple[str, int]],
agents: dict[str, dict],
timeout: float = 60.0,
) -> None:
import httpx
_, port = agents[name]
port = agents[name]["port"]
url = f"http://127.0.0.1:{port}/mcp"
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
@@ -217,8 +256,8 @@ async def _wait_for_agent(
logger.warning("%s did not become ready within %.0fs", name, timeout)
async def _run_single(name: str, agents: dict[str, tuple[str, int]]) -> None:
await _preflight()
async def _run_single(name: str, agents: dict[str, dict], deployment_config: dict) -> None:
await _preflight(deployment_config)
await _start_agent(name, agents)
@@ -229,7 +268,7 @@ async def _start_all(config: dict) -> None:
agent_deps = _build_agent_deps(config)
registry_port = config.get("registry_port", 24200)
await _preflight()
await _preflight(config)
# Identify subagents that must start first.
subagents: set[str] = set()
@@ -271,8 +310,8 @@ def main() -> None:
epilog=(
"Port assignments:\n"
+ "\n".join(
f" {name:16s} port {port}"
for name, (_, port) in agents.items()
f" {name:16s} port {entry['port']}"
for name, entry in agents.items()
)
+ f"\n {'registry':16s} port {registry_port}"
),
@@ -288,16 +327,16 @@ def main() -> None:
setup_logging()
if args.agent:
_, port = agents[args.agent]
port = agents[args.agent]["port"]
logger.info("Starting %s agent on port %d", args.agent, port)
asyncio.run(_run_single(args.agent, agents))
asyncio.run(_run_single(args.agent, agents, config))
else:
logger.info("Starting all agents + registry for %s", deploy_name)
logger.info(
"registry → http://0.0.0.0:%d/.well-known/mcp/server.json", registry_port
)
for name, (_, port) in agents.items():
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, port)
for name, entry in agents.items():
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, entry["port"])
asyncio.run(_start_all(config))