feat: support per-agent model and capabilities overrides in agents.yaml
Add optional `model` and `model_capabilities` fields to agent definitions in agents.yaml, allowing each agent to target a different model/provider with its own capability parameters (vision, context_window, etc.). - Refactor `_build_agents_table` to return rich dicts instead of tuples - Extract `_register_one_model` from `_register_unknown_models` for reuse - Register per-agent models in addition to the global default_model, falling back to top-level model_capabilities when agent-specific ones are not provided - Override `AgentConfig.model` at startup when an agent declares a model - Thread deployment_config through `_preflight` and `_start_agent`
This commit is contained in:
119
pallas/server.py
119
pallas/server.py
@@ -46,10 +46,23 @@ def _load_deployment_config() -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _build_agents_table(config: dict) -> dict[str, tuple[str, int]]:
|
||||
"""Build {name: (module_path, port)} from agents.yaml."""
|
||||
def _build_agents_table(config: dict) -> dict[str, dict]:
|
||||
"""Build {name: {module, port, model?, model_capabilities?}} from agents.yaml.
|
||||
|
||||
The ``model`` and ``model_capabilities`` fields are optional. When ``model``
|
||||
is set, Pallas overrides the agent's ``AgentConfig.model`` at startup so
|
||||
fast-agent routes that agent to the specified model/provider. When
|
||||
``model_capabilities`` is set, those capabilities are used to register the
|
||||
model with fast-agent's ``ModelDatabase`` instead of the top-level defaults
|
||||
from ``fastagent.config.yaml``.
|
||||
"""
|
||||
return {
|
||||
name: (agent["module"], agent["port"])
|
||||
name: {
|
||||
"module": agent["module"],
|
||||
"port": agent["port"],
|
||||
"model": agent.get("model"),
|
||||
"model_capabilities": agent.get("model_capabilities"),
|
||||
}
|
||||
for name, agent in config["agents"].items()
|
||||
}
|
||||
|
||||
@@ -109,34 +122,15 @@ def _preflight_mcp_servers(agent_name: str, servers: dict[str, dict]) -> None:
|
||||
|
||||
# ── Model registration ────────────────────────────────────────────────────────
|
||||
|
||||
def _register_unknown_models() -> None:
|
||||
"""Register runtime model params for models not in fast-agent's ModelDatabase.
|
||||
|
||||
Reads ``default_model`` and ``model_capabilities`` from the active
|
||||
fastagent.config.yaml. Capabilities (vision, context_window,
|
||||
max_output_tokens) are declared explicitly in the config rather than
|
||||
inferred from the model name, since naming conventions vary across
|
||||
model families.
|
||||
"""
|
||||
config_path = _config_root() / "fastagent.config.yaml"
|
||||
if not config_path.exists():
|
||||
return
|
||||
|
||||
def _register_one_model(model_spec: str, capabilities: dict) -> None:
|
||||
"""Register a single model with fast-agent's ModelDatabase if unknown."""
|
||||
from fast_agent.llm.model_database import ModelDatabase, ModelParameters
|
||||
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
|
||||
default_model = config.get("default_model", "")
|
||||
if not default_model:
|
||||
return
|
||||
|
||||
model_name = default_model.split(".", 1)[-1] if "." in default_model else default_model
|
||||
model_name = model_spec.split(".", 1)[-1] if "." in model_spec else model_spec
|
||||
|
||||
if ModelDatabase.get_model_params(model_name) is not None:
|
||||
return
|
||||
|
||||
capabilities = config.get("model_capabilities", {})
|
||||
is_vision = capabilities.get("vision", False)
|
||||
context_window = capabilities.get("context_window", 131072)
|
||||
max_output_tokens = capabilities.get("max_output_tokens", 16384)
|
||||
@@ -158,22 +152,67 @@ def _register_unknown_models() -> None:
|
||||
)
|
||||
|
||||
|
||||
def _register_unknown_models(deployment_config: dict) -> None:
|
||||
"""Register runtime model params for models not in fast-agent's ModelDatabase.
|
||||
|
||||
Registers the ``default_model`` from ``fastagent.config.yaml`` plus every
|
||||
per-agent ``model`` declared in ``agents.yaml``. Capabilities are resolved
|
||||
per model: if the agent carries its own ``model_capabilities`` block, those
|
||||
take effect; otherwise the top-level ``model_capabilities`` from
|
||||
``fastagent.config.yaml`` apply.
|
||||
"""
|
||||
fastagent_config_path = _config_root() / "fastagent.config.yaml"
|
||||
if not fastagent_config_path.exists():
|
||||
return
|
||||
|
||||
with open(fastagent_config_path) as f:
|
||||
fa_config = yaml.safe_load(f) or {}
|
||||
|
||||
default_model = fa_config.get("default_model", "")
|
||||
default_capabilities = fa_config.get("model_capabilities", {})
|
||||
|
||||
seen: set[str] = set()
|
||||
|
||||
if default_model:
|
||||
_register_one_model(default_model, default_capabilities)
|
||||
seen.add(default_model)
|
||||
|
||||
for agent_name, agent in deployment_config.get("agents", {}).items():
|
||||
agent_model = agent.get("model")
|
||||
if not agent_model or agent_model in seen:
|
||||
continue
|
||||
agent_caps = agent.get("model_capabilities") or default_capabilities
|
||||
_register_one_model(agent_model, agent_caps)
|
||||
seen.add(agent_model)
|
||||
|
||||
|
||||
# ── Agent lifecycle ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _preflight() -> None:
|
||||
async def _preflight(deployment_config: dict) -> None:
|
||||
from pallas.health import validate_llm_providers
|
||||
|
||||
_register_unknown_models()
|
||||
_register_unknown_models(deployment_config)
|
||||
await validate_llm_providers()
|
||||
|
||||
|
||||
async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
||||
async def _start_agent(name: str, agents: dict[str, dict]) -> None:
|
||||
from pallas.health import register_health_tool
|
||||
|
||||
module_path, port = agents[name]
|
||||
entry = agents[name]
|
||||
module_path = entry["module"]
|
||||
port = entry["port"]
|
||||
model_override = entry.get("model")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
fast_instance = module.fast
|
||||
|
||||
if model_override:
|
||||
for agent_data in fast_instance.agents.values():
|
||||
agent_cfg = agent_data.get("config")
|
||||
if agent_cfg is not None:
|
||||
agent_cfg.model = model_override
|
||||
logger.info("%s model override → %s", name, model_override)
|
||||
|
||||
logger.info("Starting %s agent on port %d", name, port)
|
||||
|
||||
async with fast_instance.run():
|
||||
@@ -198,12 +237,12 @@ async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
||||
|
||||
async def _wait_for_agent(
|
||||
name: str,
|
||||
agents: dict[str, tuple[str, int]],
|
||||
agents: dict[str, dict],
|
||||
timeout: float = 60.0,
|
||||
) -> None:
|
||||
import httpx
|
||||
|
||||
_, port = agents[name]
|
||||
port = agents[name]["port"]
|
||||
url = f"http://127.0.0.1:{port}/mcp"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
@@ -217,8 +256,8 @@ async def _wait_for_agent(
|
||||
logger.warning("%s did not become ready within %.0fs", name, timeout)
|
||||
|
||||
|
||||
async def _run_single(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
||||
await _preflight()
|
||||
async def _run_single(name: str, agents: dict[str, dict], deployment_config: dict) -> None:
|
||||
await _preflight(deployment_config)
|
||||
await _start_agent(name, agents)
|
||||
|
||||
|
||||
@@ -229,7 +268,7 @@ async def _start_all(config: dict) -> None:
|
||||
agent_deps = _build_agent_deps(config)
|
||||
registry_port = config.get("registry_port", 24200)
|
||||
|
||||
await _preflight()
|
||||
await _preflight(config)
|
||||
|
||||
# Identify subagents that must start first.
|
||||
subagents: set[str] = set()
|
||||
@@ -271,8 +310,8 @@ def main() -> None:
|
||||
epilog=(
|
||||
"Port assignments:\n"
|
||||
+ "\n".join(
|
||||
f" {name:16s} port {port}"
|
||||
for name, (_, port) in agents.items()
|
||||
f" {name:16s} port {entry['port']}"
|
||||
for name, entry in agents.items()
|
||||
)
|
||||
+ f"\n {'registry':16s} port {registry_port}"
|
||||
),
|
||||
@@ -288,16 +327,16 @@ def main() -> None:
|
||||
setup_logging()
|
||||
|
||||
if args.agent:
|
||||
_, port = agents[args.agent]
|
||||
port = agents[args.agent]["port"]
|
||||
logger.info("Starting %s agent on port %d", args.agent, port)
|
||||
asyncio.run(_run_single(args.agent, agents))
|
||||
asyncio.run(_run_single(args.agent, agents, config))
|
||||
else:
|
||||
logger.info("Starting all agents + registry for %s", deploy_name)
|
||||
logger.info(
|
||||
"registry → http://0.0.0.0:%d/.well-known/mcp/server.json", registry_port
|
||||
)
|
||||
for name, (_, port) in agents.items():
|
||||
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, port)
|
||||
for name, entry in agents.items():
|
||||
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, entry["port"])
|
||||
asyncio.run(_start_all(config))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user