feat: support per-agent model and capabilities overrides in agents.yaml
Add optional `model` and `model_capabilities` fields to agent definitions in agents.yaml, allowing each agent to target a different model/provider with its own capability parameters (vision, context_window, etc.). - Refactor `_build_agents_table` to return rich dicts instead of tuples - Extract `_register_one_model` from `_register_unknown_models` for reuse - Register per-agent models in addition to the global default_model, falling back to top-level model_capabilities when agent-specific ones are not provided - Override `AgentConfig.model` at startup when an agent declares a model - Thread deployment_config through `_preflight` and `_start_agent`
This commit is contained in:
119
pallas/server.py
119
pallas/server.py
@@ -46,10 +46,23 @@ def _load_deployment_config() -> dict:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _build_agents_table(config: dict) -> dict[str, tuple[str, int]]:
|
def _build_agents_table(config: dict) -> dict[str, dict]:
|
||||||
"""Build {name: (module_path, port)} from agents.yaml."""
|
"""Build {name: {module, port, model?, model_capabilities?}} from agents.yaml.
|
||||||
|
|
||||||
|
The ``model`` and ``model_capabilities`` fields are optional. When ``model``
|
||||||
|
is set, Pallas overrides the agent's ``AgentConfig.model`` at startup so
|
||||||
|
fast-agent routes that agent to the specified model/provider. When
|
||||||
|
``model_capabilities`` is set, those capabilities are used to register the
|
||||||
|
model with fast-agent's ``ModelDatabase`` instead of the top-level defaults
|
||||||
|
from ``fastagent.config.yaml``.
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
name: (agent["module"], agent["port"])
|
name: {
|
||||||
|
"module": agent["module"],
|
||||||
|
"port": agent["port"],
|
||||||
|
"model": agent.get("model"),
|
||||||
|
"model_capabilities": agent.get("model_capabilities"),
|
||||||
|
}
|
||||||
for name, agent in config["agents"].items()
|
for name, agent in config["agents"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,34 +122,15 @@ def _preflight_mcp_servers(agent_name: str, servers: dict[str, dict]) -> None:
|
|||||||
|
|
||||||
# ── Model registration ────────────────────────────────────────────────────────
|
# ── Model registration ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _register_unknown_models() -> None:
|
def _register_one_model(model_spec: str, capabilities: dict) -> None:
|
||||||
"""Register runtime model params for models not in fast-agent's ModelDatabase.
|
"""Register a single model with fast-agent's ModelDatabase if unknown."""
|
||||||
|
|
||||||
Reads ``default_model`` and ``model_capabilities`` from the active
|
|
||||||
fastagent.config.yaml. Capabilities (vision, context_window,
|
|
||||||
max_output_tokens) are declared explicitly in the config rather than
|
|
||||||
inferred from the model name, since naming conventions vary across
|
|
||||||
model families.
|
|
||||||
"""
|
|
||||||
config_path = _config_root() / "fastagent.config.yaml"
|
|
||||||
if not config_path.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
from fast_agent.llm.model_database import ModelDatabase, ModelParameters
|
from fast_agent.llm.model_database import ModelDatabase, ModelParameters
|
||||||
|
|
||||||
with open(config_path) as f:
|
model_name = model_spec.split(".", 1)[-1] if "." in model_spec else model_spec
|
||||||
config = yaml.safe_load(f) or {}
|
|
||||||
|
|
||||||
default_model = config.get("default_model", "")
|
|
||||||
if not default_model:
|
|
||||||
return
|
|
||||||
|
|
||||||
model_name = default_model.split(".", 1)[-1] if "." in default_model else default_model
|
|
||||||
|
|
||||||
if ModelDatabase.get_model_params(model_name) is not None:
|
if ModelDatabase.get_model_params(model_name) is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
capabilities = config.get("model_capabilities", {})
|
|
||||||
is_vision = capabilities.get("vision", False)
|
is_vision = capabilities.get("vision", False)
|
||||||
context_window = capabilities.get("context_window", 131072)
|
context_window = capabilities.get("context_window", 131072)
|
||||||
max_output_tokens = capabilities.get("max_output_tokens", 16384)
|
max_output_tokens = capabilities.get("max_output_tokens", 16384)
|
||||||
@@ -158,22 +152,67 @@ def _register_unknown_models() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_unknown_models(deployment_config: dict) -> None:
|
||||||
|
"""Register runtime model params for models not in fast-agent's ModelDatabase.
|
||||||
|
|
||||||
|
Registers the ``default_model`` from ``fastagent.config.yaml`` plus every
|
||||||
|
per-agent ``model`` declared in ``agents.yaml``. Capabilities are resolved
|
||||||
|
per model: if the agent carries its own ``model_capabilities`` block, those
|
||||||
|
take effect; otherwise the top-level ``model_capabilities`` from
|
||||||
|
``fastagent.config.yaml`` apply.
|
||||||
|
"""
|
||||||
|
fastagent_config_path = _config_root() / "fastagent.config.yaml"
|
||||||
|
if not fastagent_config_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(fastagent_config_path) as f:
|
||||||
|
fa_config = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
default_model = fa_config.get("default_model", "")
|
||||||
|
default_capabilities = fa_config.get("model_capabilities", {})
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
|
||||||
|
if default_model:
|
||||||
|
_register_one_model(default_model, default_capabilities)
|
||||||
|
seen.add(default_model)
|
||||||
|
|
||||||
|
for agent_name, agent in deployment_config.get("agents", {}).items():
|
||||||
|
agent_model = agent.get("model")
|
||||||
|
if not agent_model or agent_model in seen:
|
||||||
|
continue
|
||||||
|
agent_caps = agent.get("model_capabilities") or default_capabilities
|
||||||
|
_register_one_model(agent_model, agent_caps)
|
||||||
|
seen.add(agent_model)
|
||||||
|
|
||||||
|
|
||||||
# ── Agent lifecycle ───────────────────────────────────────────────────────────
|
# ── Agent lifecycle ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _preflight() -> None:
|
async def _preflight(deployment_config: dict) -> None:
|
||||||
from pallas.health import validate_llm_providers
|
from pallas.health import validate_llm_providers
|
||||||
|
|
||||||
_register_unknown_models()
|
_register_unknown_models(deployment_config)
|
||||||
await validate_llm_providers()
|
await validate_llm_providers()
|
||||||
|
|
||||||
|
|
||||||
async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
async def _start_agent(name: str, agents: dict[str, dict]) -> None:
|
||||||
from pallas.health import register_health_tool
|
from pallas.health import register_health_tool
|
||||||
|
|
||||||
module_path, port = agents[name]
|
entry = agents[name]
|
||||||
|
module_path = entry["module"]
|
||||||
|
port = entry["port"]
|
||||||
|
model_override = entry.get("model")
|
||||||
|
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
fast_instance = module.fast
|
fast_instance = module.fast
|
||||||
|
|
||||||
|
if model_override:
|
||||||
|
for agent_data in fast_instance.agents.values():
|
||||||
|
agent_cfg = agent_data.get("config")
|
||||||
|
if agent_cfg is not None:
|
||||||
|
agent_cfg.model = model_override
|
||||||
|
logger.info("%s model override → %s", name, model_override)
|
||||||
|
|
||||||
logger.info("Starting %s agent on port %d", name, port)
|
logger.info("Starting %s agent on port %d", name, port)
|
||||||
|
|
||||||
async with fast_instance.run():
|
async with fast_instance.run():
|
||||||
@@ -198,12 +237,12 @@ async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
|||||||
|
|
||||||
async def _wait_for_agent(
|
async def _wait_for_agent(
|
||||||
name: str,
|
name: str,
|
||||||
agents: dict[str, tuple[str, int]],
|
agents: dict[str, dict],
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
_, port = agents[name]
|
port = agents[name]["port"]
|
||||||
url = f"http://127.0.0.1:{port}/mcp"
|
url = f"http://127.0.0.1:{port}/mcp"
|
||||||
deadline = asyncio.get_event_loop().time() + timeout
|
deadline = asyncio.get_event_loop().time() + timeout
|
||||||
while asyncio.get_event_loop().time() < deadline:
|
while asyncio.get_event_loop().time() < deadline:
|
||||||
@@ -217,8 +256,8 @@ async def _wait_for_agent(
|
|||||||
logger.warning("%s did not become ready within %.0fs", name, timeout)
|
logger.warning("%s did not become ready within %.0fs", name, timeout)
|
||||||
|
|
||||||
|
|
||||||
async def _run_single(name: str, agents: dict[str, tuple[str, int]]) -> None:
|
async def _run_single(name: str, agents: dict[str, dict], deployment_config: dict) -> None:
|
||||||
await _preflight()
|
await _preflight(deployment_config)
|
||||||
await _start_agent(name, agents)
|
await _start_agent(name, agents)
|
||||||
|
|
||||||
|
|
||||||
@@ -229,7 +268,7 @@ async def _start_all(config: dict) -> None:
|
|||||||
agent_deps = _build_agent_deps(config)
|
agent_deps = _build_agent_deps(config)
|
||||||
registry_port = config.get("registry_port", 24200)
|
registry_port = config.get("registry_port", 24200)
|
||||||
|
|
||||||
await _preflight()
|
await _preflight(config)
|
||||||
|
|
||||||
# Identify subagents that must start first.
|
# Identify subagents that must start first.
|
||||||
subagents: set[str] = set()
|
subagents: set[str] = set()
|
||||||
@@ -271,8 +310,8 @@ def main() -> None:
|
|||||||
epilog=(
|
epilog=(
|
||||||
"Port assignments:\n"
|
"Port assignments:\n"
|
||||||
+ "\n".join(
|
+ "\n".join(
|
||||||
f" {name:16s} port {port}"
|
f" {name:16s} port {entry['port']}"
|
||||||
for name, (_, port) in agents.items()
|
for name, entry in agents.items()
|
||||||
)
|
)
|
||||||
+ f"\n {'registry':16s} port {registry_port}"
|
+ f"\n {'registry':16s} port {registry_port}"
|
||||||
),
|
),
|
||||||
@@ -288,16 +327,16 @@ def main() -> None:
|
|||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
if args.agent:
|
if args.agent:
|
||||||
_, port = agents[args.agent]
|
port = agents[args.agent]["port"]
|
||||||
logger.info("Starting %s agent on port %d", args.agent, port)
|
logger.info("Starting %s agent on port %d", args.agent, port)
|
||||||
asyncio.run(_run_single(args.agent, agents))
|
asyncio.run(_run_single(args.agent, agents, config))
|
||||||
else:
|
else:
|
||||||
logger.info("Starting all agents + registry for %s", deploy_name)
|
logger.info("Starting all agents + registry for %s", deploy_name)
|
||||||
logger.info(
|
logger.info(
|
||||||
"registry → http://0.0.0.0:%d/.well-known/mcp/server.json", registry_port
|
"registry → http://0.0.0.0:%d/.well-known/mcp/server.json", registry_port
|
||||||
)
|
)
|
||||||
for name, (_, port) in agents.items():
|
for name, entry in agents.items():
|
||||||
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, port)
|
logger.info("%-16s → http://0.0.0.0:%d/mcp", name, entry["port"])
|
||||||
asyncio.run(_start_all(config))
|
asyncio.run(_start_all(config))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user