diff --git a/pallas/server.py b/pallas/server.py index 77df015..6ee2022 100644 --- a/pallas/server.py +++ b/pallas/server.py @@ -46,10 +46,23 @@ def _load_deployment_config() -> dict: return config -def _build_agents_table(config: dict) -> dict[str, tuple[str, int]]: - """Build {name: (module_path, port)} from agents.yaml.""" +def _build_agents_table(config: dict) -> dict[str, dict]: + """Build {name: {module, port, model?, model_capabilities?}} from agents.yaml. + + The ``model`` and ``model_capabilities`` fields are optional. When ``model`` + is set, Pallas overrides the agent's ``AgentConfig.model`` at startup so + fast-agent routes that agent to the specified model/provider. When + ``model_capabilities`` is set, those capabilities are used to register the + model with fast-agent's ``ModelDatabase`` instead of the top-level defaults + from ``fastagent.config.yaml``. + """ return { - name: (agent["module"], agent["port"]) + name: { + "module": agent["module"], + "port": agent["port"], + "model": agent.get("model"), + "model_capabilities": agent.get("model_capabilities"), + } for name, agent in config["agents"].items() } @@ -109,34 +122,15 @@ def _preflight_mcp_servers(agent_name: str, servers: dict[str, dict]) -> None: # ── Model registration ──────────────────────────────────────────────────────── -def _register_unknown_models() -> None: - """Register runtime model params for models not in fast-agent's ModelDatabase. - - Reads ``default_model`` and ``model_capabilities`` from the active - fastagent.config.yaml. Capabilities (vision, context_window, - max_output_tokens) are declared explicitly in the config rather than - inferred from the model name, since naming conventions vary across - model families. - """ - config_path = _config_root() / "fastagent.config.yaml" - if not config_path.exists(): - return - +def _register_one_model(model_spec: str, capabilities: dict) -> None: + """Register a single model with fast-agent's ModelDatabase if unknown.""" from fast_agent.llm.model_database import ModelDatabase, ModelParameters - with open(config_path) as f: - config = yaml.safe_load(f) or {} - - default_model = config.get("default_model", "") - if not default_model: - return - - model_name = default_model.split(".", 1)[-1] if "." in default_model else default_model + model_name = model_spec.split(".", 1)[-1] if "." in model_spec else model_spec if ModelDatabase.get_model_params(model_name) is not None: return - capabilities = config.get("model_capabilities", {}) is_vision = capabilities.get("vision", False) context_window = capabilities.get("context_window", 131072) max_output_tokens = capabilities.get("max_output_tokens", 16384) @@ -158,22 +152,67 @@ def _register_unknown_models() -> None: ) +def _register_unknown_models(deployment_config: dict) -> None: + """Register runtime model params for models not in fast-agent's ModelDatabase. + + Registers the ``default_model`` from ``fastagent.config.yaml`` plus every + per-agent ``model`` declared in ``agents.yaml``. Capabilities are resolved + per model: if the agent carries its own ``model_capabilities`` block, those + take effect; otherwise the top-level ``model_capabilities`` from + ``fastagent.config.yaml`` apply. + """ + fastagent_config_path = _config_root() / "fastagent.config.yaml" + if not fastagent_config_path.exists(): + return + + with open(fastagent_config_path) as f: + fa_config = yaml.safe_load(f) or {} + + default_model = fa_config.get("default_model", "") + default_capabilities = fa_config.get("model_capabilities", {}) + + seen: set[str] = set() + + if default_model: + _register_one_model(default_model, default_capabilities) + seen.add(default_model) + + for agent_name, agent in deployment_config.get("agents", {}).items(): + agent_model = agent.get("model") + if not agent_model or agent_model in seen: + continue + agent_caps = agent.get("model_capabilities") or default_capabilities + _register_one_model(agent_model, agent_caps) + seen.add(agent_model) + + # ── Agent lifecycle ─────────────────────────────────────────────────────────── -async def _preflight() -> None: +async def _preflight(deployment_config: dict) -> None: from pallas.health import validate_llm_providers - _register_unknown_models() + _register_unknown_models(deployment_config) await validate_llm_providers() -async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None: +async def _start_agent(name: str, agents: dict[str, dict]) -> None: from pallas.health import register_health_tool - module_path, port = agents[name] + entry = agents[name] + module_path = entry["module"] + port = entry["port"] + model_override = entry.get("model") + module = importlib.import_module(module_path) fast_instance = module.fast + if model_override: + for agent_data in fast_instance.agents.values(): + agent_cfg = agent_data.get("config") + if agent_cfg is not None: + agent_cfg.model = model_override + logger.info("%s model override → %s", name, model_override) + logger.info("Starting %s agent on port %d", name, port) async with fast_instance.run(): @@ -198,12 +237,12 @@ async def _start_agent(name: str, agents: dict[str, tuple[str, int]]) -> None: async def _wait_for_agent( name: str, - agents: dict[str, tuple[str, int]], + agents: dict[str, dict], timeout: float = 60.0, ) -> None: import httpx - _, port = agents[name] + port = agents[name]["port"] url = f"http://127.0.0.1:{port}/mcp" deadline = asyncio.get_event_loop().time() + timeout while asyncio.get_event_loop().time() < deadline: @@ -217,8 +256,8 @@ async def _wait_for_agent( logger.warning("%s did not become ready within %.0fs", name, timeout) -async def _run_single(name: str, agents: dict[str, tuple[str, int]]) -> None: - await _preflight() +async def _run_single(name: str, agents: dict[str, dict], deployment_config: dict) -> None: + await _preflight(deployment_config) await _start_agent(name, agents) @@ -229,7 +268,7 @@ async def _start_all(config: dict) -> None: agent_deps = _build_agent_deps(config) registry_port = config.get("registry_port", 24200) - await _preflight() + await _preflight(config) # Identify subagents that must start first. subagents: set[str] = set() @@ -271,8 +310,8 @@ def main() -> None: epilog=( "Port assignments:\n" + "\n".join( - f" {name:16s} port {port}" - for name, (_, port) in agents.items() + f" {name:16s} port {entry['port']}" + for name, entry in agents.items() ) + f"\n {'registry':16s} port {registry_port}" ), @@ -288,16 +327,16 @@ def main() -> None: setup_logging() if args.agent: - _, port = agents[args.agent] + port = agents[args.agent]["port"] logger.info("Starting %s agent on port %d", args.agent, port) - asyncio.run(_run_single(args.agent, agents)) + asyncio.run(_run_single(args.agent, agents, config)) else: logger.info("Starting all agents + registry for %s", deploy_name) logger.info( "registry → http://0.0.0.0:%d/.well-known/mcp/server.json", registry_port ) - for name, (_, port) in agents.items(): - logger.info("%-16s → http://0.0.0.0:%d/mcp", name, port) + for name, entry in agents.items(): + logger.info("%-16s → http://0.0.0.0:%d/mcp", name, entry["port"]) asyncio.run(_start_all(config))