feat: scaffold stentor-gateway with FastAPI voice pipeline
Initialize the stentor-gateway project with a WebSocket-based voice pipeline
orchestrating STT → Agent → TTS via OpenAI-compatible APIs.

- Add FastAPI app with WebSocket endpoint for audio streaming
- Add pipeline orchestration (stt_client, tts_client, agent_client)
- Add Pydantic Settings configuration and message models
- Add audio utilities for PCM/WAV conversion and resampling
- Add health check endpoints
- Add Dockerfile and pyproject.toml with dependencies
- Add initial test suite (pipeline, STT, TTS, WebSocket)
- Add comprehensive README covering gateway and ESP32 ear design
- Clean up .gitignore for Python/uv project
This commit is contained in:
49
stentor-gateway/Dockerfile
Normal file
49
stentor-gateway/Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
||||
# Stentor Gateway — Multi-stage Dockerfile
# Builds a minimal production image for the voice gateway.

# --- Build stage ---
FROM python:3.12-slim AS builder

WORKDIR /build

# Install uv for fast dependency resolution
# (static binary copied from the official astral-sh/uv image)
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

# Copy project files
COPY pyproject.toml .
COPY src/ src/

# Install dependencies into a virtual environment.
# --no-cache keeps the builder layer small; the venv is copied wholesale
# into the runtime stage below.
RUN uv venv /opt/venv && \
    uv pip install --python /opt/venv/bin/python . --no-cache

# --- Runtime stage ---
FROM python:3.12-slim AS runtime

LABEL maintainer="robert"
LABEL description="Stentor Gateway — Voice gateway for AI agents"
LABEL version="0.1.0"

# Copy virtual environment from builder
COPY --from=builder /opt/venv /opt/venv

# Copy application source (for templates/static)
COPY src/ /app/src/

WORKDIR /app

# Use the venv Python
ENV PATH="/opt/venv/bin:$PATH"
ENV PYTHONPATH="/app/src"
ENV PYTHONUNBUFFERED=1

# Default configuration (overridable with -e at run time)
ENV STENTOR_HOST=0.0.0.0
ENV STENTOR_PORT=8600

EXPOSE 8600

# Probe the liveness endpoint; httpx is available because it is an app
# dependency installed in the venv.
# NOTE(review): the port is hard-coded to 8600 here, in EXPOSE, and in CMD —
# overriding STENTOR_PORT at run time will not be reflected. Confirm intended.
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
    CMD python -c "import httpx; r = httpx.get('http://localhost:8600/api/live/'); r.raise_for_status()"

CMD ["uvicorn", "stentor.main:app", "--host", "0.0.0.0", "--port", "8600"]
|
||||
45
stentor-gateway/pyproject.toml
Normal file
45
stentor-gateway/pyproject.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
# Stentor Gateway — project metadata, build config, and tool settings.

[project]
name = "stentor-gateway"
version = "0.1.0"
description = "Voice gateway connecting ESP32 audio hardware to AI agents via speech services"
requires-python = ">=3.12"
license = "MIT"
dependencies = [
    "fastapi>=0.115",
    "uvicorn[standard]>=0.34",
    "websockets>=14.0",
    "httpx>=0.28",
    "pydantic>=2.10",
    "pydantic-settings>=2.7",
    "jinja2>=3.1",
    "prometheus-client>=0.21",
]

# Development-only extras: install with `pip install -e .[dev]` or uv equivalent.
[project.optional-dependencies]
dev = [
    "pytest>=8.0",
    "pytest-asyncio>=0.25",
    "pytest-httpx>=0.35",
    "ruff>=0.9",
]

# Console entry point: `stentor` runs stentor.main:main (uvicorn launcher).
[project.scripts]
stentor = "stentor.main:main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# src-layout: the wheel packages src/stentor as the `stentor` package.
[tool.hatch.build.targets.wheel]
packages = ["src/stentor"]

[tool.ruff]
target-version = "py312"
line-length = 100

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP", "B", "SIM", "RUF"]

# asyncio_mode = "auto" lets async tests run without per-test markers.
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
|
||||
3
stentor-gateway/src/stentor/__init__.py
Normal file
3
stentor-gateway/src/stentor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Stentor — Voice Gateway for AI Agents."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
5
stentor-gateway/src/stentor/__main__.py
Normal file
5
stentor-gateway/src/stentor/__main__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Allow running Stentor Gateway with ``python -m stentor``."""
|
||||
|
||||
from stentor.main import main
|
||||
|
||||
main()
|
||||
62
stentor-gateway/src/stentor/agent_client.py
Normal file
62
stentor-gateway/src/stentor/agent_client.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""FastAgent HTTP client.
|
||||
|
||||
Thin adapter for communicating with FastAgent's HTTP transport.
|
||||
Designed to be easily swappable if the API shape changes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from stentor.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentClient:
    """Async client for FastAgent HTTP transport."""

    def __init__(self, settings: Settings, http_client: httpx.AsyncClient) -> None:
        self._settings = settings
        self._http = http_client
        # Single message endpoint, normalized to avoid a double slash.
        self._url = f"{settings.agent_url.rstrip('/')}/message"

    async def send_message(self, content: str) -> str:
        """Send a message to the agent and return the response.

        Args:
            content: The user's transcribed message.

        Returns:
            The agent's response text.

        Raises:
            httpx.HTTPStatusError: If the agent returns an error.
        """
        started = time.monotonic()
        logger.debug("Agent request to %s: %r", self._url, content[:80])

        response = await self._http.post(
            self._url, json={"content": content}, timeout=60.0
        )
        response.raise_for_status()

        text = response.json().get("content", "")

        logger.info("Agent responded in %.2fs: %r", time.monotonic() - started, text[:80])
        return text

    async def is_available(self) -> bool:
        """Check if the agent service is reachable."""
        try:
            probe = await self._http.get(
                self._settings.agent_url.rstrip("/"), timeout=5.0
            )
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
        # Accept any non-5xx response as "available"
        return probe.status_code < 500
|
||||
127
stentor-gateway/src/stentor/audio.py
Normal file
127
stentor-gateway/src/stentor/audio.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Audio utilities for Stentor Gateway.
|
||||
|
||||
Handles PCM↔WAV conversion, resampling, and base64 encoding/decoding.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
|
||||
def pcm_to_wav(
    pcm_data: bytes,
    sample_rate: int = 16000,
    channels: int = 1,
    sample_width: int = 2,
) -> bytes:
    """Wrap raw PCM bytes in a WAV header.

    Args:
        pcm_data: Raw PCM audio bytes (signed 16-bit little-endian).
        sample_rate: Sample rate in Hz.
        channels: Number of audio channels.
        sample_width: Bytes per sample (2 for 16-bit).

    Returns:
        Complete WAV file as bytes.
    """
    sink = io.BytesIO()
    writer = wave.open(sink, "wb")
    try:
        # Header parameters must be set before writing frames.
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(pcm_data)
    finally:
        writer.close()
    return sink.getvalue()
|
||||
|
||||
|
||||
def wav_to_pcm(wav_data: bytes) -> tuple[bytes, int, int, int]:
    """Extract raw PCM data from a WAV file.

    Args:
        wav_data: Complete WAV file as bytes.

    Returns:
        Tuple of (pcm_data, sample_rate, channels, sample_width).
    """
    reader = wave.open(io.BytesIO(wav_data), "rb")
    try:
        frames = reader.readframes(reader.getnframes())
        rate = reader.getframerate()
        nchannels = reader.getnchannels()
        width = reader.getsampwidth()
    finally:
        reader.close()
    return frames, rate, nchannels, width
|
||||
|
||||
|
||||
def resample_pcm(
    pcm_data: bytes,
    src_rate: int,
    dst_rate: int,
    sample_width: int = 2,
    channels: int = 1,
) -> bytes:
    """Resample PCM audio using linear interpolation.

    Simple resampler with no external dependencies. Adequate for speech audio
    where the primary use case is converting TTS output (24kHz) to the ESP32
    playback rate (16kHz).

    Args:
        pcm_data: Raw PCM bytes (signed 16-bit little-endian, frame-interleaved).
        src_rate: Source sample rate in Hz.
        dst_rate: Destination sample rate in Hz.
        sample_width: Bytes per sample (must be 2 for 16-bit).
        channels: Number of audio channels.

    Returns:
        Resampled PCM bytes, frame-interleaved like the input.

    Raises:
        ValueError: If sample_width is not 2 (only 16-bit PCM is supported).
    """
    if src_rate == dst_rate:
        return pcm_data

    if sample_width != 2:
        msg = f"Only 16-bit PCM supported for resampling, got {sample_width * 8}-bit"
        raise ValueError(msg)

    fmt = "<h"
    frame_size = sample_width * channels
    num_frames = len(pcm_data) // frame_size
    ratio = src_rate / dst_rate
    dst_frames = int(num_frames / ratio)

    # Deinterleave, then resample each channel independently.
    resampled: list[list[int]] = []
    for ch in range(channels):
        samples = [
            struct.unpack_from(fmt, pcm_data, i * frame_size + ch * sample_width)[0]
            for i in range(num_frames)
        ]

        out: list[int] = []
        for i in range(dst_frames):
            src_pos = i * ratio
            idx = int(src_pos)
            frac = src_pos - idx

            if idx + 1 < len(samples):
                # Linear interpolation between the two neighboring samples.
                value = samples[idx] * (1.0 - frac) + samples[idx + 1] * frac
            else:
                # Past the last sample pair: hold the final sample.
                value = samples[idx] if idx < len(samples) else 0

            # Clamp to the signed 16-bit range before packing.
            out.append(max(-32768, min(32767, int(value))))

        resampled.append(out)

    # Re-interleave per frame. (Bug fix: the previous version appended each
    # channel's samples back-to-back, producing planar rather than
    # frame-interleaved output for multi-channel audio. Mono is unaffected.)
    result = bytearray()
    for i in range(dst_frames):
        for ch in range(channels):
            result.extend(struct.pack(fmt, resampled[ch][i]))

    return bytes(result)
|
||||
|
||||
|
||||
def encode_audio(pcm_data: bytes) -> str:
    """Base64-encode PCM audio data for WebSocket transmission."""
    encoded = base64.b64encode(pcm_data)
    # JSON payloads need text, so decode the ASCII-safe base64 bytes.
    return encoded.decode("ascii")


def decode_audio(b64_data: str) -> bytes:
    """Decode base64-encoded PCM audio data from WebSocket."""
    return base64.b64decode(b64_data)
|
||||
43
stentor-gateway/src/stentor/config.py
Normal file
43
stentor-gateway/src/stentor/config.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Pydantic Settings configuration for Stentor Gateway.
|
||||
|
||||
All configuration is driven by environment variables (12-factor).
|
||||
"""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Stentor Gateway configuration.

    Every field can be overridden via an environment variable with the
    ``STENTOR_`` prefix (e.g. ``STENTOR_PORT=9000``), per the env_prefix
    in ``model_config``.
    """

    model_config = {"env_prefix": "STENTOR_"}

    # Server
    host: str = "0.0.0.0"
    port: int = 8600

    # Speaches endpoints (OpenAI-compatible)
    stt_url: str = "http://perseus.helu.ca:22070"
    tts_url: str = "http://perseus.helu.ca:22070"

    # FastAgent endpoint
    agent_url: str = "http://localhost:8001"

    # STT configuration
    stt_model: str = "Systran/faster-whisper-small"

    # TTS configuration
    tts_model: str = "kokoro"
    tts_voice: str = "af_heart"

    # Audio configuration
    audio_sample_rate: int = 16000
    audio_channels: int = 1
    # NOTE(review): the audio.py helpers take sample_width in BYTES
    # (2 = 16-bit), but this default looks like BITS (16). Confirm the
    # intended unit before wiring this value into the audio helpers.
    audio_sample_width: int = 16

    # Logging
    # NOTE(review): DEBUG is a chatty default for production — confirm.
    log_level: str = "DEBUG"


def get_settings() -> Settings:
    """Create and return a Settings instance (reads env vars on each call)."""
    return Settings()
|
||||
166
stentor-gateway/src/stentor/health.py
Normal file
166
stentor-gateway/src/stentor/health.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Health check and metrics endpoints for Stentor Gateway.
|
||||
|
||||
Follows Kubernetes health check conventions:
|
||||
- /api/live/ — Liveness probe (is the process alive?)
|
||||
- /api/ready/ — Readiness probe (are dependencies reachable?)
|
||||
- /api/metrics — Prometheus-compatible metrics
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Response
|
||||
from prometheus_client import (
|
||||
Counter,
|
||||
Gauge,
|
||||
Histogram,
|
||||
generate_latest,
|
||||
)
|
||||
|
||||
from stentor.agent_client import AgentClient
|
||||
from stentor.stt_client import STTClient
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prometheus metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Gauge: incremented on session.start, decremented when the socket closes.
SESSIONS_ACTIVE = Gauge(
    "stentor_sessions_active",
    "Current active WebSocket sessions",
)

TRANSCRIPTIONS_TOTAL = Counter(
    "stentor_transcriptions_total",
    "Total STT transcription calls",
)

TTS_REQUESTS_TOTAL = Counter(
    "stentor_tts_requests_total",
    "Total TTS synthesis calls",
)

AGENT_REQUESTS_TOTAL = Counter(
    "stentor_agent_requests_total",
    "Total agent message calls",
)

# End-to-end latency; buckets sized for multi-second voice round trips.
PIPELINE_DURATION = Histogram(
    "stentor_pipeline_duration_seconds",
    "Full pipeline latency (STT + Agent + TTS)",
    buckets=(0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 30.0),
)

STT_DURATION = Histogram(
    "stentor_stt_duration_seconds",
    "STT transcription latency",
    buckets=(0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0),
)

TTS_DURATION = Histogram(
    "stentor_tts_duration_seconds",
    "TTS synthesis latency",
    buckets=(0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0),
)

# Agent latency gets an extra 30s bucket — LLM responses can be slow.
AGENT_DURATION = Histogram(
    "stentor_agent_duration_seconds",
    "Agent response latency",
    buckets=(0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0),
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup time tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Monotonic timestamp captured at startup; 0.0 means "not started yet".
_start_time: float = 0.0


def record_start_time() -> None:
    """Record the application start time."""
    global _start_time
    _start_time = time.monotonic()


def get_uptime() -> float:
    """Return uptime in seconds (0.0 if the start time was never recorded)."""
    return 0.0 if _start_time == 0.0 else time.monotonic() - _start_time
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dependency references (set during app startup)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Module-level client handles, populated once at app startup via set_clients().
_stt_client: STTClient | None = None
_tts_client: TTSClient | None = None
_agent_client: AgentClient | None = None


def set_clients(stt: STTClient, tts: TTSClient, agent: AgentClient) -> None:
    """Register service clients for health checks."""
    global _stt_client, _tts_client, _agent_client
    _stt_client, _tts_client, _agent_client = stt, tts, agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/live/")
async def liveness() -> dict:
    """Liveness probe — is the process alive and responding?

    Returns 200 if the application is running.
    Used by Kubernetes to determine if the pod should be restarted.
    """
    return {"status": "ok"}


@router.get("/ready/")
async def readiness(response: Response) -> dict:
    """Readiness probe — are all dependencies reachable?

    Checks connectivity to STT, TTS, and Agent services.
    Returns 200 if ready, 503 if any dependency is unavailable.
    Used by load balancers to determine if pod should receive traffic.
    """
    # An unregistered client (set_clients never called) counts as not ready.
    registered = {
        "stt": _stt_client,
        "tts": _tts_client,
        "agent": _agent_client,
    }
    checks: dict[str, bool] = {}
    for name, client in registered.items():
        checks[name] = await client.is_available() if client else False

    all_ready = all(checks.values())
    if not all_ready:
        response.status_code = 503

    return {
        "status": "ready" if all_ready else "not_ready",
        "checks": checks,
    }


@router.get("/metrics")
async def metrics() -> Response:
    """Prometheus-compatible metrics endpoint."""
    payload = generate_latest()
    return Response(
        content=payload,
        media_type="text/plain; version=0.0.4; charset=utf-8",
    )
|
||||
421
stentor-gateway/src/stentor/main.py
Normal file
421
stentor-gateway/src/stentor/main.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""Stentor Gateway — FastAPI application.
|
||||
|
||||
Central orchestrator connecting ESP32 audio hardware to AI agents via speech services.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from stentor import __version__
|
||||
from stentor.agent_client import AgentClient
|
||||
from stentor.audio import decode_audio
|
||||
from stentor.config import Settings, get_settings
|
||||
from stentor.health import (
|
||||
AGENT_DURATION,
|
||||
AGENT_REQUESTS_TOTAL,
|
||||
PIPELINE_DURATION,
|
||||
SESSIONS_ACTIVE,
|
||||
STT_DURATION,
|
||||
TRANSCRIPTIONS_TOTAL,
|
||||
TTS_DURATION,
|
||||
TTS_REQUESTS_TOTAL,
|
||||
get_uptime,
|
||||
record_start_time,
|
||||
set_clients,
|
||||
)
|
||||
from stentor.health import (
|
||||
router as health_router,
|
||||
)
|
||||
from stentor.models import (
|
||||
ErrorEvent,
|
||||
InputAudioBufferAppend,
|
||||
SessionCreated,
|
||||
SessionStart,
|
||||
StatusUpdate,
|
||||
)
|
||||
from stentor.pipeline import Pipeline, PipelineState
|
||||
from stentor.stt_client import STTClient
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Paths for templates and static files
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
TEMPLATES_DIR = BASE_DIR / "templates"
|
||||
STATIC_DIR = BASE_DIR / "static"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class Session:
    """Active WebSocket session.

    Accumulates incoming PCM chunks in ``audio_buffer`` until the client
    commits the buffer, at which point ``reset_buffer`` drains it.
    """

    session_id: str
    client_id: str
    websocket: WebSocket
    audio_buffer: bytearray = field(default_factory=bytearray)

    def reset_buffer(self) -> bytes:
        """Return buffered audio and reset the buffer."""
        drained = bytes(self.audio_buffer)
        self.audio_buffer.clear()
        return drained
|
||||
|
||||
|
||||
# Active sessions keyed by session_id; shared module-level registry used by
# the WebSocket handler and the dashboard.
_sessions: dict[str, Session] = {}


def get_active_sessions() -> dict[str, Session]:
    """Return the active sessions dict."""
    return _sessions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup and shutdown.

    On startup: configures logging, builds one shared httpx.AsyncClient,
    wires the STT/TTS/Agent clients and the pipeline, and stashes everything
    on ``app.state``. On shutdown: closes the shared HTTP client.
    """
    settings = get_settings()

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, settings.log_level.upper(), logging.INFO),
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )

    logger.info("Stentor Gateway v%s starting", __version__)
    logger.info("STT: %s (model: %s)", settings.stt_url, settings.stt_model)
    logger.info(
        "TTS: %s (model: %s, voice: %s)",
        settings.tts_url, settings.tts_model, settings.tts_voice,
    )
    logger.info("Agent: %s", settings.agent_url)

    # Baseline for the /api uptime figure.
    record_start_time()

    # Create shared HTTP client (one connection pool for all three services)
    http_client = httpx.AsyncClient()

    # Create service clients
    stt_client = STTClient(settings, http_client)
    tts_client = TTSClient(settings, http_client)
    agent_client = AgentClient(settings, http_client)

    # Register clients for health checks (health.py module globals)
    set_clients(stt_client, tts_client, agent_client)

    # Create pipeline state and pipeline
    pipeline_state = PipelineState()
    pipeline = Pipeline(settings, stt_client, tts_client, agent_client, pipeline_state)

    # Store on app state for access in endpoints
    app.state.settings = settings
    app.state.http_client = http_client
    app.state.stt_client = stt_client
    app.state.tts_client = tts_client
    app.state.agent_client = agent_client
    app.state.pipeline = pipeline
    app.state.pipeline_state = pipeline_state

    logger.info("Stentor Gateway ready on %s:%d", settings.host, settings.port)

    # Application serves requests while suspended here.
    yield

    # Shutdown
    logger.info("Stentor Gateway shutting down")
    await http_client.aclose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI app
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The single FastAPI application instance; uvicorn targets "stentor.main:app".
app = FastAPI(
    title="Stentor Gateway",
    version=__version__,
    description="Voice gateway connecting ESP32 audio hardware to AI agents via speech services.",
    lifespan=lifespan,
    docs_url="/api/docs",
    openapi_url="/api/openapi.json",
)

# Include health/metrics routes (/api/live/, /api/ready/, /api/metrics)
app.include_router(health_router)

# Mount static files (skipped gracefully if the directory is absent)
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Templates for the HTML dashboard served at "/"
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dashboard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
async def dashboard(request: Request) -> HTMLResponse:
    """Serve the Bootstrap dashboard.

    Renders templates/dashboard.html with the current settings, uptime,
    pipeline state, and a per-session summary of the active WebSocket
    sessions (client id and buffered audio size).
    """
    settings: Settings = request.app.state.settings
    pipeline_state: PipelineState = request.app.state.pipeline_state

    return templates.TemplateResponse(
        "dashboard.html",
        {
            "request": request,
            "version": __version__,
            "settings": settings,
            "active_sessions": len(_sessions),
            # Summarize sessions rather than exposing the objects themselves.
            "sessions": {
                sid: {"client_id": s.client_id, "buffer_size": len(s.audio_buffer)}
                for sid, s in _sessions.items()
            },
            "uptime": get_uptime(),
            "pipeline_state": pipeline_state,
        },
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API info endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.get("/api/v1/info")
async def api_info() -> dict:
    """Return gateway information.

    Exposes the service name/version, the available endpoints, and the
    (non-secret) runtime configuration for client discovery.
    """
    settings: Settings = app.state.settings

    endpoints = {
        "realtime": "/api/v1/realtime",
        "live": "/api/live/",
        "ready": "/api/ready/",
        "metrics": "/api/metrics",
    }
    config = {
        "stt_url": settings.stt_url,
        "tts_url": settings.tts_url,
        "agent_url": settings.agent_url,
        "stt_model": settings.stt_model,
        "tts_model": settings.tts_model,
        "tts_voice": settings.tts_voice,
        "audio_sample_rate": settings.audio_sample_rate,
        "audio_channels": settings.audio_channels,
        "audio_sample_width": settings.audio_sample_width,
    }

    return {
        "name": "stentor-gateway",
        "version": __version__,
        "endpoints": endpoints,
        "config": config,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket endpoint (OpenAI Realtime API-inspired)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@app.websocket("/api/v1/realtime")
async def realtime_websocket(websocket: WebSocket) -> None:
    """WebSocket endpoint for real-time audio conversations.

    Protocol inspired by the OpenAI Realtime API:
    - Client sends session.start, then streams input_audio_buffer.append events
    - Client sends input_audio_buffer.commit to trigger pipeline
    - Gateway responds with status updates, transcript, response text, and audio
    """
    await websocket.accept()

    session: Session | None = None
    pipeline: Pipeline = websocket.app.state.pipeline

    try:
        while True:
            raw = await websocket.receive_text()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await _send_error(websocket, "Invalid JSON", "invalid_json")
                continue

            msg_type = data.get("type", "")

            # --- session.start ---
            if msg_type == "session.start":
                try:
                    msg = SessionStart.model_validate(data)
                except ValidationError as e:
                    await _send_error(websocket, str(e), "validation_error")
                    continue

                # Server-generated id; registered globally for the dashboard.
                session_id = str(uuid.uuid4())
                session = Session(
                    session_id=session_id,
                    client_id=msg.client_id,
                    websocket=websocket,
                )
                _sessions[session_id] = session
                SESSIONS_ACTIVE.inc()

                logger.info(
                    "Session started: %s (client: %s)",
                    session_id,
                    msg.client_id,
                )

                await websocket.send_text(
                    SessionCreated(session_id=session_id).model_dump_json()
                )
                await websocket.send_text(
                    StatusUpdate(state="listening").model_dump_json()
                )

            # --- input_audio_buffer.append ---
            elif msg_type == "input_audio_buffer.append":
                if not session:
                    await _send_error(websocket, "No active session", "no_session")
                    continue

                try:
                    msg = InputAudioBufferAppend.model_validate(data)
                except ValidationError as e:
                    await _send_error(websocket, str(e), "validation_error")
                    continue

                # Accumulate base64-decoded PCM until the client commits.
                pcm_chunk = decode_audio(msg.audio)
                session.audio_buffer.extend(pcm_chunk)

            # --- input_audio_buffer.commit ---
            elif msg_type == "input_audio_buffer.commit":
                if not session:
                    await _send_error(websocket, "No active session", "no_session")
                    continue

                # Drain the buffer atomically before running the pipeline.
                audio_data = session.reset_buffer()

                if not audio_data:
                    await _send_error(websocket, "Empty audio buffer", "empty_buffer")
                    await websocket.send_text(
                        StatusUpdate(state="listening").model_dump_json()
                    )
                    continue

                logger.info(
                    "Processing %d bytes of audio for session %s",
                    len(audio_data),
                    session.session_id,
                )

                # Run the pipeline and stream events back
                async for event in pipeline.process(audio_data):
                    await websocket.send_text(event.model_dump_json())

                    # Record Prometheus metrics from pipeline events
                    if event.type == "transcript.done":
                        TRANSCRIPTIONS_TOTAL.inc()
                    elif event.type == "response.text.done":
                        AGENT_REQUESTS_TOTAL.inc()
                    elif event.type == "response.audio.done":
                        TTS_REQUESTS_TOTAL.inc()

                # Record pipeline duration from state
                pipeline_state: PipelineState = websocket.app.state.pipeline_state
                if pipeline_state.recent_metrics:
                    last = pipeline_state.recent_metrics[-1]
                    PIPELINE_DURATION.observe(last.total_duration)
                    STT_DURATION.observe(last.stt_duration)
                    AGENT_DURATION.observe(last.agent_duration)
                    TTS_DURATION.observe(last.tts_duration)

                # Return to listening state
                await websocket.send_text(
                    StatusUpdate(state="listening").model_dump_json()
                )

            # --- session.close ---
            elif msg_type == "session.close":
                logger.info("Session close requested")
                break

            else:
                await _send_error(
                    websocket, f"Unknown event type: {msg_type}", "unknown_event"
                )

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")

    except Exception:
        # Top-level boundary: log with traceback and try to notify the client
        # best-effort (the socket may already be gone).
        logger.exception("WebSocket error")
        with suppress(Exception):
            await _send_error(websocket, "Internal server error", "internal_error")

    finally:
        # Always unregister the session so the gauge and registry stay accurate.
        if session and session.session_id in _sessions:
            del _sessions[session.session_id]
            SESSIONS_ACTIVE.dec()
            logger.info("Session ended: %s", session.session_id)
|
||||
|
||||
|
||||
async def _send_error(websocket: WebSocket, message: str, code: str) -> None:
    """Send an error event to the client."""
    payload = ErrorEvent(message=message, code=code).model_dump_json()
    await websocket.send_text(payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
    """Start the Stentor Gateway server.

    This is the main entry point used by:
    - ``python -m stentor``
    - The ``stentor`` console script (installed via pip)
    - ``python main.py`` (when run directly)
    """
    # Imported lazily so merely importing this module does not require the
    # server runner.
    import uvicorn

    settings = get_settings()
    level_name = settings.log_level

    # Configure root logging early so startup messages are visible
    logging.basicConfig(
        level=getattr(logging, level_name.upper(), logging.INFO),
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    )

    logger.info("Starting Stentor Gateway v%s on %s:%d", __version__, settings.host, settings.port)

    uvicorn.run(
        "stentor.main:app",
        host=settings.host,
        port=settings.port,
        log_level=level_name.lower(),
    )


if __name__ == "__main__":
    main()
|
||||
129
stentor-gateway/src/stentor/models.py
Normal file
129
stentor-gateway/src/stentor/models.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""WebSocket message models for Stentor Gateway.
|
||||
|
||||
Inspired by the OpenAI Realtime API event naming conventions.
|
||||
Messages are JSON with base64-encoded audio data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client → Gateway events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AudioConfig(BaseModel):
    """Audio configuration sent by the client on session start."""

    sample_rate: int = 16000  # samples per second (Hz)
    channels: int = 1  # 1 = mono
    sample_width: int = 16  # bits per sample
    encoding: str = "pcm_s16le"  # signed 16-bit little-endian PCM
|
||||
|
||||
|
||||
class SessionStart(BaseModel):
    """Client requests a new session."""

    type: Literal["session.start"] = "session.start"
    client_id: str = ""  # optional caller-supplied identifier; empty if not provided
    # Defaults to 16 kHz mono 16-bit PCM when the client omits it.
    audio_config: AudioConfig = Field(default_factory=AudioConfig)
|
||||
|
||||
|
||||
class InputAudioBufferAppend(BaseModel):
    """Client sends a chunk of audio data."""

    type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
    audio: str  # base64-encoded PCM chunk
|
||||
|
||||
|
||||
class InputAudioBufferCommit(BaseModel):
    """Client signals end of speech / commits the audio buffer.

    Carries no payload beyond the type tag.
    NOTE(review): presumably triggers the STT → Agent → TTS pipeline on the
    buffered audio — confirm against the WebSocket handler.
    """

    type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
|
||||
|
||||
|
||||
class SessionClose(BaseModel):
    """Client requests session termination.

    Carries no payload beyond the type tag.
    """

    type: Literal["session.close"] = "session.close"
|
||||
|
||||
|
||||
# Discriminated union of every event a client may send. The "type" field
# selects the concrete model during pydantic validation, so one parse call
# routes an incoming JSON message to the right event class.
ClientEvent = Annotated[
    SessionStart | InputAudioBufferAppend | InputAudioBufferCommit | SessionClose,
    Field(discriminator="type"),
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gateway → Client events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SessionCreated(BaseModel):
    """Acknowledge session creation."""

    type: Literal["session.created"] = "session.created"
    session_id: str  # gateway-assigned identifier for this session
|
||||
|
||||
|
||||
class StatusUpdate(BaseModel):
    """Gateway processing status update."""

    type: Literal["status"] = "status"
    # Pipeline stage: "transcribing" (STT), "thinking" (agent), "speaking"
    # (TTS / audio streaming); "listening" indicates the gateway is idle and
    # accepting audio.
    state: Literal["listening", "transcribing", "thinking", "speaking"]
|
||||
|
||||
|
||||
class TranscriptDone(BaseModel):
    """Transcript of what the user said."""

    type: Literal["transcript.done"] = "transcript.done"
    text: str  # full STT transcript of the committed audio buffer
|
||||
|
||||
|
||||
class ResponseTextDone(BaseModel):
    """AI agent response text."""

    type: Literal["response.text.done"] = "response.text.done"
    text: str  # complete agent reply (sent before audio synthesis begins)
|
||||
|
||||
|
||||
class ResponseAudioDelta(BaseModel):
    """Streamed audio response chunk."""

    type: Literal["response.audio.delta"] = "response.audio.delta"
    delta: str  # base64-encoded PCM chunk
|
||||
|
||||
|
||||
class ResponseAudioDone(BaseModel):
    """Audio response streaming complete.

    Carries no payload; marks the end of the response.audio.delta stream.
    """

    type: Literal["response.audio.done"] = "response.audio.done"
|
||||
|
||||
|
||||
class ResponseDone(BaseModel):
    """Full response cycle complete.

    Final event of a pipeline run; the client may send new audio after this.
    """

    type: Literal["response.done"] = "response.done"
|
||||
|
||||
|
||||
class ErrorEvent(BaseModel):
    """Error event."""

    type: Literal["error"] = "error"
    message: str  # human-readable description
    code: str = "unknown_error"  # machine-readable code, e.g. "pipeline_error"
|
||||
|
||||
|
||||
# Union of every event the gateway may emit to a client. Unlike ClientEvent
# this is a plain (non-discriminated) union: the gateway only serializes
# these models and never validates incoming JSON against ServerEvent.
# NOTE(review): mirror ClientEvent with Field(discriminator="type") if
# clients or tests ever need to parse server events back into models.
ServerEvent = (
    SessionCreated
    | StatusUpdate
    | TranscriptDone
    | ResponseTextDone
    | ResponseAudioDelta
    | ResponseAudioDone
    | ResponseDone
    | ErrorEvent
)
|
||||
175
stentor-gateway/src/stentor/pipeline.py
Normal file
175
stentor-gateway/src/stentor/pipeline.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Voice pipeline orchestrator: STT → Agent → TTS.
|
||||
|
||||
Ties together the three service clients into a single processing pipeline.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from stentor.agent_client import AgentClient
|
||||
from stentor.audio import encode_audio, pcm_to_wav, resample_pcm
|
||||
from stentor.config import Settings
|
||||
from stentor.models import (
|
||||
ErrorEvent,
|
||||
ResponseAudioDelta,
|
||||
ResponseAudioDone,
|
||||
ResponseDone,
|
||||
ResponseTextDone,
|
||||
ServerEvent,
|
||||
StatusUpdate,
|
||||
TranscriptDone,
|
||||
)
|
||||
from stentor.stt_client import STTClient
|
||||
from stentor.tts_client import TTS_OUTPUT_SAMPLE_RATE, TTSClient
|
||||
|
||||
# Module-level logger, named after this module for filterable log output.
logger = logging.getLogger(__name__)

# Audio chunk size for streaming to client (2KB ≈ 64ms at 16kHz/16bit/mono)
AUDIO_CHUNK_SIZE = 2048
|
||||
|
||||
|
||||
@dataclass
class PipelineMetrics:
    """Accumulated metrics for a single pipeline run."""

    stt_duration: float = 0.0  # seconds spent in speech-to-text
    agent_duration: float = 0.0  # seconds spent waiting for the agent reply
    tts_duration: float = 0.0  # seconds spent in text-to-speech synthesis
    total_duration: float = 0.0  # wall-clock time of the whole run
    transcript: str = ""  # what STT heard
    response_text: str = ""  # what the agent answered
    audio_bytes_sent: int = 0  # total PCM bytes streamed back to the client
|
||||
|
||||
|
||||
@dataclass
class PipelineState:
    """Shared state for pipeline metrics collection.

    One instance is shared across pipeline runs; counters grow monotonically
    while the per-run history is capped at the last 100 entries.
    """

    total_transcriptions: int = 0
    total_tts_requests: int = 0
    total_agent_requests: int = 0
    recent_metrics: list[PipelineMetrics] = field(default_factory=list)

    def record(self, metrics: PipelineMetrics) -> None:
        """Record metrics from a pipeline run."""
        # Every run touches all three stages, so bump all three counters.
        self.total_transcriptions += 1
        self.total_agent_requests += 1
        self.total_tts_requests += 1

        self.recent_metrics.append(metrics)
        # Trim history in place so only the newest 100 runs are kept
        # (the slice is empty — a no-op — while the list is still short).
        del self.recent_metrics[:-100]
|
||||
|
||||
|
||||
class Pipeline:
    """Orchestrates the STT → Agent → TTS voice pipeline."""

    def __init__(
        self,
        settings: Settings,
        stt: STTClient,
        tts: TTSClient,
        agent: AgentClient,
        state: PipelineState,
    ) -> None:
        """Wire the pipeline to its three service clients.

        Args:
            settings: Gateway configuration (audio format, service URLs).
            stt: Speech-to-text client.
            tts: Text-to-speech client.
            agent: AI agent client.
            state: Shared metrics accumulator, updated after every run.
        """
        self._settings = settings
        self._stt = stt
        self._tts = tts
        self._agent = agent
        self._state = state

    async def process(self, audio_buffer: bytes) -> AsyncIterator[ServerEvent]:
        """Run the full voice pipeline on buffered audio.

        Yields server events as the pipeline progresses through each stage:
        status updates, transcript, response text, audio chunks, and completion.

        Args:
            audio_buffer: Raw PCM audio bytes from the client.

        Yields:
            ServerEvent instances to send back to the client.
        """
        pipeline_start = time.monotonic()
        metrics = PipelineMetrics()

        try:
            # --- Stage 1: STT ---
            yield StatusUpdate(state="transcribing")

            # Wrap raw PCM in a WAV container so the STT service can parse it.
            wav_data = pcm_to_wav(
                audio_buffer,
                sample_rate=self._settings.audio_sample_rate,
                channels=self._settings.audio_channels,
            )

            stt_start = time.monotonic()
            transcript = await self._stt.transcribe(wav_data)
            metrics.stt_duration = time.monotonic() - stt_start
            metrics.transcript = transcript

            # Empty transcript (silence/noise) short-circuits the run; the
            # agent and TTS stages are skipped entirely.
            if not transcript:
                logger.warning("STT returned empty transcript")
                yield ErrorEvent(message="No speech detected", code="empty_transcript")
                return

            yield TranscriptDone(text=transcript)

            # --- Stage 2: Agent ---
            yield StatusUpdate(state="thinking")

            agent_start = time.monotonic()
            response_text = await self._agent.send_message(transcript)
            metrics.agent_duration = time.monotonic() - agent_start
            metrics.response_text = response_text

            if not response_text:
                logger.warning("Agent returned empty response")
                yield ErrorEvent(message="Agent returned empty response", code="empty_response")
                return

            # Text is delivered before audio so the client can display it
            # while synthesis is still streaming.
            yield ResponseTextDone(text=response_text)

            # --- Stage 3: TTS ---
            yield StatusUpdate(state="speaking")

            tts_start = time.monotonic()
            tts_audio = await self._tts.synthesize(response_text)
            metrics.tts_duration = time.monotonic() - tts_start

            # Resample from TTS output rate to client playback rate if needed
            target_rate = self._settings.audio_sample_rate
            if target_rate != TTS_OUTPUT_SAMPLE_RATE:
                tts_audio = resample_pcm(
                    tts_audio,
                    src_rate=TTS_OUTPUT_SAMPLE_RATE,
                    dst_rate=target_rate,
                )

            # Stream audio in chunks (base64-encoded per event); the final
            # chunk may be shorter than AUDIO_CHUNK_SIZE.
            offset = 0
            while offset < len(tts_audio):
                chunk = tts_audio[offset : offset + AUDIO_CHUNK_SIZE]
                yield ResponseAudioDelta(delta=encode_audio(chunk))
                metrics.audio_bytes_sent += len(chunk)
                offset += AUDIO_CHUNK_SIZE

            yield ResponseAudioDone()
            yield ResponseDone()

        except Exception:
            # Catch-all boundary: any stage failure is logged with traceback
            # and surfaced to the client as a single error event rather than
            # tearing down the WebSocket.
            logger.exception("Pipeline error")
            yield ErrorEvent(message="Internal pipeline error", code="pipeline_error")

        finally:
            # Metrics are recorded for every run, successful or not.
            metrics.total_duration = time.monotonic() - pipeline_start
            self._state.record(metrics)
            logger.info(
                "Pipeline complete: stt=%.2fs agent=%.2fs tts=%.2fs total=%.2fs",
                metrics.stt_duration,
                metrics.agent_duration,
                metrics.tts_duration,
                metrics.total_duration,
            )
|
||||
0
stentor-gateway/src/stentor/static/.gitkeep
Normal file
0
stentor-gateway/src/stentor/static/.gitkeep
Normal file
68
stentor-gateway/src/stentor/stt_client.py
Normal file
68
stentor-gateway/src/stentor/stt_client.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Speaches STT client (OpenAI-compatible).
|
||||
|
||||
Posts audio as multipart/form-data to the /v1/audio/transcriptions endpoint.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from stentor.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class STTClient:
    """Async client for Speaches Speech-to-Text service.

    Wraps the OpenAI-compatible ``/v1/audio/transcriptions`` endpoint.
    """

    def __init__(self, settings: Settings, http_client: httpx.AsyncClient) -> None:
        self._settings = settings
        self._http = http_client
        # Endpoint URL is fixed at construction; trailing slashes in the
        # configured base URL are tolerated.
        self._url = f"{settings.stt_url.rstrip('/')}/v1/audio/transcriptions"

    async def transcribe(self, wav_data: bytes, language: str | None = None) -> str:
        """Send audio to Speaches STT and return the transcript.

        Args:
            wav_data: Complete WAV file bytes.
            language: Optional language code (e.g., "en"). Auto-detect if omitted.

        Returns:
            Transcribed text (stripped; may be empty).

        Raises:
            httpx.HTTPStatusError: If the STT service returns an error.
        """
        started = time.monotonic()

        form_fields: dict[str, str] = {
            "model": self._settings.stt_model,
            "response_format": "json",
        }
        if language:
            form_fields["language"] = language

        logger.debug("STT request to %s (model=%s)", self._url, self._settings.stt_model)

        resp = await self._http.post(
            self._url,
            files={"file": ("audio.wav", wav_data, "audio/wav")},
            data=form_fields,
            timeout=30.0,
        )
        resp.raise_for_status()

        text = resp.json().get("text", "").strip()

        logger.info("STT completed in %.2fs: %r", time.monotonic() - started, text[:80])
        return text

    async def is_available(self) -> bool:
        """Check if the STT service is reachable."""
        # Probe the models listing endpoint as a lightweight health check.
        probe_url = f"{self._settings.stt_url.rstrip('/')}/v1/models"
        try:
            probe = await self._http.get(probe_url, timeout=5.0)
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
        return probe.status_code == 200
|
||||
361
stentor-gateway/src/stentor/templates/dashboard.html
Normal file
361
stentor-gateway/src/stentor/templates/dashboard.html
Normal file
@@ -0,0 +1,361 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" data-bs-theme="dark">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Stentor Gateway — Dashboard</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.11.3/font/bootstrap-icons.min.css" rel="stylesheet">
|
||||
<style>
|
||||
.status-dot {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
display: inline-block;
|
||||
margin-right: 8px;
|
||||
}
|
||||
.status-dot.ok { background-color: #198754; }
|
||||
.status-dot.error { background-color: #dc3545; }
|
||||
.status-dot.unknown { background-color: #6c757d; }
|
||||
.status-dot.checking { background-color: #ffc107; animation: pulse 1s infinite; }
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.4; }
|
||||
}
|
||||
.card-metric {
|
||||
font-size: 2rem;
|
||||
font-weight: 700;
|
||||
}
|
||||
.card-metric-label {
|
||||
font-size: 0.85rem;
|
||||
color: var(--bs-secondary-color);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
.hero-quote {
|
||||
font-style: italic;
|
||||
color: var(--bs-secondary-color);
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<!-- Navbar -->
|
||||
<nav class="navbar navbar-expand-lg bg-body-tertiary border-bottom">
|
||||
<div class="container-fluid">
|
||||
<a class="navbar-brand d-flex align-items-center" href="/">
|
||||
<i class="bi bi-soundwave me-2 fs-4"></i>
|
||||
<strong>Stentor</strong> Gateway
|
||||
</a>
|
||||
<span class="navbar-text">
|
||||
<span class="badge bg-secondary">v{{ version }}</span>
|
||||
</span>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<div class="container-fluid py-4">
|
||||
<!-- Header -->
|
||||
<div class="row mb-4">
|
||||
<div class="col">
|
||||
<p class="hero-quote mb-0">
|
||||
“Stentor, whose voice was as powerful as fifty voices of other men.”
|
||||
— Homer, <em>Iliad</em>, Book V
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Service Connectivity -->
|
||||
<div class="row mb-4">
|
||||
<div class="col-12">
|
||||
<h5 class="mb-3"><i class="bi bi-diagram-3 me-2"></i>Service Connectivity</h5>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<div class="card h-100">
|
||||
<div class="card-body">
|
||||
<div class="d-flex align-items-center mb-2">
|
||||
<span class="status-dot checking" id="stt-status-dot"></span>
|
||||
<h6 class="card-title mb-0">Speaches STT</h6>
|
||||
</div>
|
||||
<p class="card-text text-body-secondary small mb-1">
|
||||
<i class="bi bi-mic me-1"></i>Speech-to-Text
|
||||
</p>
|
||||
<code class="small">{{ settings.stt_url }}</code>
|
||||
<div class="mt-2 small">
|
||||
<span class="text-body-secondary">Model:</span>
|
||||
<span>{{ settings.stt_model }}</span>
|
||||
</div>
|
||||
<div class="mt-2" id="stt-status-text">
|
||||
<span class="badge bg-warning"><i class="bi bi-hourglass-split me-1"></i>Checking...</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<div class="card h-100">
|
||||
<div class="card-body">
|
||||
<div class="d-flex align-items-center mb-2">
|
||||
<span class="status-dot checking" id="tts-status-dot"></span>
|
||||
<h6 class="card-title mb-0">Speaches TTS</h6>
|
||||
</div>
|
||||
<p class="card-text text-body-secondary small mb-1">
|
||||
<i class="bi bi-volume-up me-1"></i>Text-to-Speech
|
||||
</p>
|
||||
<code class="small">{{ settings.tts_url }}</code>
|
||||
<div class="mt-2 small">
|
||||
<span class="text-body-secondary">Model:</span>
|
||||
<span>{{ settings.tts_model }}</span>
|
||||
<span class="text-body-secondary ms-2">Voice:</span>
|
||||
<span>{{ settings.tts_voice }}</span>
|
||||
</div>
|
||||
<div class="mt-2" id="tts-status-text">
|
||||
<span class="badge bg-warning"><i class="bi bi-hourglass-split me-1"></i>Checking...</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<div class="card h-100">
|
||||
<div class="card-body">
|
||||
<div class="d-flex align-items-center mb-2">
|
||||
<span class="status-dot checking" id="agent-status-dot"></span>
|
||||
<h6 class="card-title mb-0">FastAgent</h6>
|
||||
</div>
|
||||
<p class="card-text text-body-secondary small mb-1">
|
||||
<i class="bi bi-robot me-1"></i>AI Agent
|
||||
</p>
|
||||
<code class="small">{{ settings.agent_url }}</code>
|
||||
<div class="mt-2" id="agent-status-text">
|
||||
<span class="badge bg-warning"><i class="bi bi-hourglass-split me-1"></i>Checking...</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Metrics Row -->
|
||||
<div class="row mb-4">
|
||||
<div class="col-12">
|
||||
<h5 class="mb-3"><i class="bi bi-speedometer2 me-2"></i>Metrics</h5>
|
||||
</div>
|
||||
<div class="col-sm-6 col-lg-3 mb-3">
|
||||
<div class="card text-center h-100">
|
||||
<div class="card-body">
|
||||
<div class="card-metric" id="metric-sessions">{{ active_sessions }}</div>
|
||||
<div class="card-metric-label">Active Sessions</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-sm-6 col-lg-3 mb-3">
|
||||
<div class="card text-center h-100">
|
||||
<div class="card-body">
|
||||
<div class="card-metric" id="metric-transcriptions">{{ pipeline_state.total_transcriptions }}</div>
|
||||
<div class="card-metric-label">Transcriptions</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-sm-6 col-lg-3 mb-3">
|
||||
<div class="card text-center h-100">
|
||||
<div class="card-body">
|
||||
<div class="card-metric" id="metric-tts">{{ pipeline_state.total_tts_requests }}</div>
|
||||
<div class="card-metric-label">TTS Requests</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-sm-6 col-lg-3 mb-3">
|
||||
<div class="card text-center h-100">
|
||||
<div class="card-body">
|
||||
<div class="card-metric" id="metric-uptime">—</div>
|
||||
<div class="card-metric-label">Uptime</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Connected Clients -->
|
||||
<div class="row mb-4">
|
||||
<div class="col-lg-6 mb-3">
|
||||
<h5 class="mb-3"><i class="bi bi-speaker me-2"></i>Connected Clients</h5>
|
||||
<div class="card">
|
||||
<div class="card-body">
|
||||
{% if sessions %}
|
||||
<div class="table-responsive">
|
||||
<table class="table table-sm table-hover mb-0">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Session ID</th>
|
||||
<th>Client ID</th>
|
||||
<th>Buffer</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for sid, info in sessions.items() %}
|
||||
<tr>
|
||||
<td><code class="small">{{ sid[:8] }}...</code></td>
|
||||
<td>{{ info.client_id or "—" }}</td>
|
||||
<td>{{ info.buffer_size }} bytes</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<p class="text-body-secondary mb-0">
|
||||
<i class="bi bi-info-circle me-1"></i>No clients connected.
|
||||
</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Audio Configuration -->
|
||||
<div class="col-lg-6 mb-3">
|
||||
<h5 class="mb-3"><i class="bi bi-sliders me-2"></i>Audio Configuration</h5>
|
||||
<div class="card">
|
||||
<div class="card-body">
|
||||
<table class="table table-sm mb-0">
|
||||
<tbody>
|
||||
<tr>
|
||||
<td class="text-body-secondary">Sample Rate</td>
|
||||
<td>{{ settings.audio_sample_rate }} Hz</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="text-body-secondary">Channels</td>
|
||||
<td>{{ settings.audio_channels }} (mono)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="text-body-secondary">Sample Width</td>
|
||||
<td>{{ settings.audio_sample_width }}-bit</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="text-body-secondary">Encoding</td>
|
||||
<td>PCM S16LE</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="text-body-secondary">WebSocket</td>
|
||||
<td><code>/api/v1/realtime</code></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Recent Pipeline Runs -->
|
||||
<div class="row mb-4">
|
||||
<div class="col-12">
|
||||
<h5 class="mb-3"><i class="bi bi-clock-history me-2"></i>Recent Pipeline Runs</h5>
|
||||
<div class="card">
|
||||
<div class="card-body">
|
||||
{% if pipeline_state.recent_metrics %}
|
||||
<div class="table-responsive">
|
||||
<table class="table table-sm table-hover mb-0">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Transcript</th>
|
||||
<th>Response</th>
|
||||
<th>STT</th>
|
||||
<th>Agent</th>
|
||||
<th>TTS</th>
|
||||
<th>Total</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for m in pipeline_state.recent_metrics[-10:]|reverse %}
|
||||
<tr>
|
||||
<td class="small">{{ m.transcript[:50] }}{% if m.transcript|length > 50 %}...{% endif %}</td>
|
||||
<td class="small">{{ m.response_text[:50] }}{% if m.response_text|length > 50 %}...{% endif %}</td>
|
||||
<td><span class="badge bg-info">{{ "%.2f"|format(m.stt_duration) }}s</span></td>
|
||||
<td><span class="badge bg-primary">{{ "%.2f"|format(m.agent_duration) }}s</span></td>
|
||||
<td><span class="badge bg-success">{{ "%.2f"|format(m.tts_duration) }}s</span></td>
|
||||
<td><span class="badge bg-secondary">{{ "%.2f"|format(m.total_duration) }}s</span></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
{% else %}
|
||||
<p class="text-body-secondary mb-0">
|
||||
<i class="bi bi-info-circle me-1"></i>No pipeline runs yet.
|
||||
</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Footer -->
|
||||
<div class="row">
|
||||
<div class="col text-center">
|
||||
<p class="text-body-secondary small">
|
||||
Stentor Gateway v{{ version }} ·
|
||||
<a href="/api/docs" class="text-decoration-none">API Docs</a> ·
|
||||
<a href="/api/metrics" class="text-decoration-none">Metrics</a> ·
|
||||
<a href="/api/ready/" class="text-decoration-none">Readiness</a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js"></script>
|
||||
<script>
|
||||
// Format uptime
|
||||
// Render a duration in seconds as a compact human-readable string
// ('42s', '3m 12s', or '2h 5m').
function formatUptime(seconds) {
  const total = Math.floor(seconds);
  if (total < 60) {
    return total + 's';
  }
  if (total < 3600) {
    return Math.floor(total / 60) + 'm ' + (total % 60) + 's';
  }
  return Math.floor(total / 3600) + 'h ' + Math.floor((total % 3600) / 60) + 'm';
}
|
||||
|
||||
// Update uptime display.
// The server renders its uptime ({{ uptime }}) into the page once; anchoring
// a client-side clock to that value lets the counter keep ticking every
// second without polling the server.
const uptimeSeconds = {{ uptime }};
const startTime = Date.now() / 1000 - uptimeSeconds;
function updateUptime() {
  const now = Date.now() / 1000;
  document.getElementById('metric-uptime').textContent = formatUptime(now - startTime);
}
updateUptime();
setInterval(updateUptime, 1000);
|
||||
|
||||
// Check service connectivity
|
||||
// Poll the readiness endpoint and reflect each backend's health in the UI.
// On fetch failure every service is marked unknown (null).
async function checkServices() {
  const services = ['stt', 'tts', 'agent'];
  try {
    const resp = await fetch('/api/ready/');
    const data = await resp.json();
    for (const name of services) {
      updateServiceStatus(name, data.checks[name]);
    }
  } catch (e) {
    console.error('Failed to check services:', e);
    for (const name of services) {
      updateServiceStatus(name, null);
    }
  }
}
|
||||
|
||||
// Paint the status dot and badge for one service card.
// `available` is true (connected), false (unavailable), or anything else
// (unknown, e.g. null when the readiness check itself failed).
function updateServiceStatus(service, available) {
  const dot = document.getElementById(service + '-status-dot');
  const text = document.getElementById(service + '-status-text');

  let dotClass;
  let badge;
  if (available === true) {
    dotClass = 'ok';
    badge = '<span class="badge bg-success"><i class="bi bi-check-circle me-1"></i>Connected</span>';
  } else if (available === false) {
    dotClass = 'error';
    badge = '<span class="badge bg-danger"><i class="bi bi-x-circle me-1"></i>Unavailable</span>';
  } else {
    dotClass = 'unknown';
    badge = '<span class="badge bg-secondary"><i class="bi bi-question-circle me-1"></i>Unknown</span>';
  }

  dot.className = 'status-dot ' + dotClass;
  text.innerHTML = badge;
}
|
||||
|
||||
// Kick off a connectivity check immediately on page load, then refresh
// every 30 seconds.
checkServices();
setInterval(checkServices, 30000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
102
stentor-gateway/src/stentor/tts_client.py
Normal file
102
stentor-gateway/src/stentor/tts_client.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Speaches TTS client (OpenAI-compatible).
|
||||
|
||||
Posts JSON to the /v1/audio/speech endpoint and streams back PCM audio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from stentor.config import Settings
|
||||
|
||||
# Module-level logger, named after this module for filterable log output.
logger = logging.getLogger(__name__)

# Speaches TTS typically outputs at 24kHz — TODO confirm against the
# deployed model; the pipeline resamples when this differs from the client rate.
TTS_OUTPUT_SAMPLE_RATE = 24000

# Chunk size for streaming TTS audio (4KB ≈ 85ms at 24kHz/16bit/mono)
TTS_STREAM_CHUNK_SIZE = 4096
|
||||
|
||||
|
||||
class TTSClient:
    """Async client for Speaches Text-to-Speech service.

    Wraps the OpenAI-compatible ``/v1/audio/speech`` endpoint, in both
    buffered (:meth:`synthesize`) and streaming (:meth:`synthesize_stream`)
    form.
    """

    def __init__(self, settings: Settings, http_client: httpx.AsyncClient) -> None:
        self._settings = settings
        self._http = http_client
        # Endpoint URL is fixed at construction; trailing slashes in the
        # configured base URL are tolerated.
        self._url = f"{settings.tts_url.rstrip('/')}/v1/audio/speech"

    def _speech_payload(self, text: str) -> dict:
        """Build the JSON request body for /v1/audio/speech.

        Shared by synthesize() and synthesize_stream() so the two request
        paths cannot drift apart.
        """
        return {
            "model": self._settings.tts_model,
            "voice": self._settings.tts_voice,
            "input": text,
            "response_format": "pcm",
            "speed": 1.0,
        }

    async def synthesize(self, text: str) -> bytes:
        """Synthesize speech from text, returning complete PCM audio.

        Args:
            text: Text to synthesize.

        Returns:
            Raw PCM audio bytes (24kHz, mono, 16-bit signed LE).

        Raises:
            httpx.HTTPStatusError: If the TTS service returns an error.
        """
        start = time.monotonic()

        payload = self._speech_payload(text)

        logger.debug(
            "TTS request to %s (model=%s, voice=%s): %r",
            self._url,
            self._settings.tts_model,
            self._settings.tts_voice,
            text[:80],
        )

        response = await self._http.post(self._url, json=payload, timeout=60.0)
        response.raise_for_status()

        pcm_data = response.content

        elapsed = time.monotonic() - start
        logger.info("TTS completed in %.2fs (%d bytes)", elapsed, len(pcm_data))

        return pcm_data

    async def synthesize_stream(self, text: str) -> AsyncIterator[bytes]:
        """Synthesize speech and yield PCM audio chunks as they arrive.

        Args:
            text: Text to synthesize.

        Yields:
            Chunks of raw PCM audio bytes.
        """
        payload = self._speech_payload(text)

        logger.debug("TTS streaming request: %r", text[:80])

        async with self._http.stream("POST", self._url, json=payload, timeout=60.0) as response:
            response.raise_for_status()
            async for chunk in response.aiter_bytes(chunk_size=TTS_STREAM_CHUNK_SIZE):
                yield chunk

    async def is_available(self) -> bool:
        """Check if the TTS service is reachable."""
        try:
            # Probe the models listing endpoint as a lightweight health check.
            base_url = self._settings.tts_url.rstrip("/")
            response = await self._http.get(f"{base_url}/v1/models", timeout=5.0)
            return response.status_code == 200
        except (httpx.ConnectError, httpx.TimeoutException):
            return False
|
||||
226
stentor-gateway/test_client.py
Normal file
226
stentor-gateway/test_client.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Stentor Gateway test client.
|
||||
|
||||
Sends a WAV file over WebSocket to the Stentor Gateway and plays back
|
||||
or saves the audio response. Useful for testing without ESP32 hardware.
|
||||
|
||||
Usage:
|
||||
# Send a WAV file and save the response
|
||||
python test_client.py --input recording.wav --output response.pcm
|
||||
|
||||
# Send a WAV file to a custom gateway URL
|
||||
python test_client.py --input recording.wav --gateway ws://10.10.0.5:8600
|
||||
|
||||
# Generate silent audio for testing connectivity
|
||||
python test_client.py --test-silence
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
async def run_client(
    gateway_url: str,
    audio_data: bytes,
    client_id: str = "test-client",
    output_file: str | None = None,
) -> None:
    """Connect to the gateway, send audio, and receive the response.

    Drives one full session: session.start → audio append chunks → commit →
    consume server events until response.done or error → optionally save the
    returned audio → session.close.

    Args:
        gateway_url: WebSocket URL of the Stentor Gateway.
        audio_data: Raw PCM audio bytes to send.
        client_id: Client identifier.
        output_file: Optional path to save response PCM audio.
    """
    # websockets is only needed by this test script, so import lazily and
    # fail with an actionable message instead of a bare ImportError.
    try:
        import websockets
    except ImportError:
        print("Error: 'websockets' package required. Install with: pip install websockets")
        sys.exit(1)

    print(f"Connecting to {gateway_url}...")

    async with websockets.connect(gateway_url) as ws:
        # 1. Start session
        await ws.send(json.dumps({
            "type": "session.start",
            "client_id": client_id,
            "audio_config": {
                "sample_rate": 16000,
                "channels": 1,
                "sample_width": 16,
                "encoding": "pcm_s16le",
            },
        }))

        # Wait for session.created
        msg = json.loads(await ws.recv())
        assert msg["type"] == "session.created", f"Expected session.created, got {msg}"
        session_id = msg["session_id"]
        print(f"Session created: {session_id}")

        # Wait for listening status
        msg = json.loads(await ws.recv())
        print(f"Status: {msg.get('state', msg)}")

        # 2. Stream audio in chunks (32ms chunks at 16kHz = 1024 bytes)
        chunk_size = 1024
        total_chunks = 0
        for offset in range(0, len(audio_data), chunk_size):
            chunk = audio_data[offset : offset + chunk_size]
            # Audio travels base64-encoded inside JSON text frames.
            b64_chunk = base64.b64encode(chunk).decode("ascii")
            await ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": b64_chunk,
            }))
            total_chunks += 1

        print(f"Sent {total_chunks} audio chunks ({len(audio_data)} bytes)")

        # 3. Commit the audio buffer
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
        print("Audio committed, waiting for response...")

        # 4. Receive response events until the gateway signals completion
        # (response.done) or an error; audio deltas are accumulated.
        response_audio = bytearray()
        done = False

        while not done:
            raw = await ws.recv()
            msg = json.loads(raw)
            msg_type = msg.get("type", "")

            if msg_type == "status":
                print(f" Status: {msg['state']}")

            elif msg_type == "transcript.done":
                print(f" Transcript: {msg['text']}")

            elif msg_type == "response.text.done":
                print(f" Response: {msg['text']}")

            elif msg_type == "response.audio.delta":
                chunk = base64.b64decode(msg["delta"])
                response_audio.extend(chunk)
                # \r keeps the per-chunk progress on a single console line.
                print(f" Audio chunk: {len(chunk)} bytes", end="\r")

            elif msg_type == "response.audio.done":
                print(f"\n Audio complete: {len(response_audio)} bytes total")

            elif msg_type == "response.done":
                print(" Response complete!")
                done = True

            elif msg_type == "error":
                print(f" ERROR [{msg.get('code', '?')}]: {msg['message']}")
                done = True

            else:
                print(f" Unknown event: {msg_type}")

        # 5. Save response audio (format chosen by output extension)
        if response_audio:
            if output_file:
                out_path = Path(output_file)
                if out_path.suffix == ".wav":
                    # Write as WAV (16kHz mono 16-bit container)
                    with wave.open(str(out_path), "wb") as wf:
                        wf.setnchannels(1)
                        wf.setsampwidth(2)
                        wf.setframerate(16000)
                        wf.writeframes(bytes(response_audio))
                else:
                    # Write raw PCM
                    out_path.write_bytes(bytes(response_audio))
                print(f"Response audio saved to {output_file}")
            else:
                print("(Use --output to save response audio)")

        # 6. Close session
        await ws.send(json.dumps({"type": "session.close"}))
        print("Session closed.")
|
||||
|
||||
|
||||
def load_wav_as_pcm(wav_path: str) -> bytes:
    """Read a WAV file from disk and return its raw PCM payload.

    Prints a short summary of the file's format and warns — without
    failing — when it deviates from the gateway's expected
    16 kHz / mono / 16-bit layout.
    """
    with wave.open(wav_path, "rb") as wf:
        # Hoist the header fields once instead of re-querying per print.
        channels = wf.getnchannels()
        rate = wf.getframerate()
        width = wf.getsampwidth()
        frames = wf.getnframes()

        print(f"Input: {wav_path}")
        print(f" Channels: {channels}")
        print(f" Sample rate: {rate} Hz")
        print(f" Sample width: {width * 8}-bit")
        print(f" Frames: {frames}")
        print(f" Duration: {frames / rate:.2f}s")

        # Warn-only checks: the caller may deliberately send mismatched audio.
        if rate != 16000:
            print(f" WARNING: Expected 16kHz, got {rate} Hz")
        if channels != 1:
            print(f" WARNING: Expected mono, got {channels} channels")
        if width != 2:
            print(f" WARNING: Expected 16-bit, got {width * 8}-bit")

        return wf.readframes(frames)
|
||||
|
||||
|
||||
def generate_silence(duration_ms: int = 2000, sample_rate: int = 16000) -> bytes:
    """Generate silent 16-bit mono PCM audio for testing.

    Args:
        duration_ms: Length of the clip in milliseconds.
        sample_rate: Sample rate in Hz (default 16 kHz, the gateway's
            expected input rate). New optional parameter; existing
            callers are unaffected.

    Returns:
        Raw little-endian 16-bit PCM bytes (2 bytes per sample), all zeros.
    """
    num_samples = int(sample_rate * duration_ms / 1000)
    print(f"Generated {duration_ms}ms of silence ({num_samples} samples)")
    # 16-bit zero samples are just zero bytes; this avoids building and
    # unpacking a huge argument list through struct.pack.
    return bytes(num_samples * 2)
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, select an audio source, and run the client."""
    arg_parser = argparse.ArgumentParser(
        description="Stentor Gateway test client",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    arg_parser.add_argument(
        "--gateway",
        default="ws://localhost:8600/api/v1/realtime",
        help="Gateway WebSocket URL (default: ws://localhost:8600/api/v1/realtime)",
    )
    arg_parser.add_argument(
        "--input", "-i",
        help="Path to input WAV file (16kHz, mono, 16-bit)",
    )
    arg_parser.add_argument(
        "--output", "-o",
        help="Path to save response audio (.wav or .pcm)",
    )
    arg_parser.add_argument(
        "--client-id",
        default="test-client",
        help="Client identifier (default: test-client)",
    )
    arg_parser.add_argument(
        "--test-silence",
        action="store_true",
        help="Send 2 seconds of silence (for connectivity testing)",
    )

    opts = arg_parser.parse_args()

    # Choose the audio payload: synthetic silence takes precedence over a file.
    if opts.test_silence:
        pcm = generate_silence()
    elif opts.input:
        pcm = load_wav_as_pcm(opts.input)
    else:
        # parser.error() prints usage and exits with status 2.
        arg_parser.error("Specify --input WAV_FILE or --test-silence")

    asyncio.run(
        run_client(
            gateway_url=opts.gateway,
            audio_data=pcm,
            client_id=opts.client_id,
            output_file=opts.output,
        )
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: only runs when executed directly, not on import.
    main()
|
||||
1
stentor-gateway/tests/__init__.py
Normal file
1
stentor-gateway/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Stentor Gateway tests."""
|
||||
34
stentor-gateway/tests/conftest.py
Normal file
34
stentor-gateway/tests/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Shared test fixtures for Stentor Gateway tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
from stentor.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
def settings() -> Settings:
    """Test settings with localhost endpoints."""
    # Point every downstream service at a local port and enable debug logging.
    overrides = {
        "stt_url": "http://localhost:9001",
        "tts_url": "http://localhost:9002",
        "agent_url": "http://localhost:9003",
        "log_level": "DEBUG",
    }
    return Settings(**overrides)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_pcm() -> bytes:
    """Generate 1 second of silent PCM audio (16kHz, mono, 16-bit)."""
    # 16,000 zero-valued 16-bit samples are exactly 32,000 zero bytes,
    # byte-identical to packing them with struct.
    sample_count = 16000  # 1 second at 16kHz
    return b"\x00\x00" * sample_count
|
||||
|
||||
|
||||
@pytest.fixture
def sample_pcm_short() -> bytes:
    """Generate 100ms of silent PCM audio."""
    # 1,600 zero 16-bit samples == 3,200 zero bytes (100ms at 16kHz).
    sample_count = 1600
    return b"\x00\x00" * sample_count
|
||||
120
stentor-gateway/tests/test_pipeline.py
Normal file
120
stentor-gateway/tests/test_pipeline.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for the voice pipeline orchestrator."""
|
||||
|
||||
import struct
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from stentor.agent_client import AgentClient
|
||||
from stentor.pipeline import Pipeline, PipelineState
|
||||
from stentor.stt_client import STTClient
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
|
||||
class TestPipeline:
    """Tests for the Pipeline orchestrator.

    All three downstream services (STT, agent, TTS) are replaced with
    AsyncMock doubles, so only Pipeline's orchestration logic executes.
    NOTE(review): async tests carry no @pytest.mark.asyncio marker —
    presumably pytest-asyncio (or anyio) runs in auto mode; confirm in
    the project's pytest configuration.
    """

    @pytest.fixture
    def mock_stt(self):
        """Create a mock STT client."""
        # spec=STTClient makes attribute typos raise instead of auto-mocking.
        stt = AsyncMock(spec=STTClient)
        stt.transcribe.return_value = "What is the weather?"
        return stt

    @pytest.fixture
    def mock_tts(self):
        """Create a mock TTS client."""
        tts = AsyncMock(spec=TTSClient)
        # Return 100 samples of silence as PCM (at 24kHz, will be resampled)
        tts.synthesize.return_value = struct.pack("<100h", *([0] * 100))
        return tts

    @pytest.fixture
    def mock_agent(self):
        """Create a mock agent client."""
        agent = AsyncMock(spec=AgentClient)
        agent.send_message.return_value = "I don't have weather tools yet."
        return agent

    @pytest.fixture
    def pipeline(self, settings, mock_stt, mock_tts, mock_agent):
        """Create a pipeline with mock clients."""
        state = PipelineState()
        return Pipeline(settings, mock_stt, mock_tts, mock_agent, state)

    async def test_full_pipeline(self, pipeline, sample_pcm, mock_stt, mock_tts, mock_agent):
        """Test the complete pipeline produces expected event sequence."""
        # Drain the async generator to collect all emitted events.
        events = []
        async for event in pipeline.process(sample_pcm):
            events.append(event)

        # Verify event sequence
        event_types = [e.type for e in events]

        assert "status" in event_types  # transcribing status
        assert "transcript.done" in event_types
        assert "response.text.done" in event_types
        assert "response.audio.delta" in event_types or "response.audio.done" in event_types
        assert "response.done" in event_types

        # Verify services were called
        mock_stt.transcribe.assert_called_once()
        mock_agent.send_message.assert_called_once_with("What is the weather?")
        mock_tts.synthesize.assert_called_once_with("I don't have weather tools yet.")

    async def test_pipeline_empty_transcript(self, settings, mock_tts, mock_agent):
        """Test pipeline handles empty transcript gracefully."""
        # Local STT mock that simulates "no speech detected".
        mock_stt = AsyncMock(spec=STTClient)
        mock_stt.transcribe.return_value = ""

        state = PipelineState()
        pipeline = Pipeline(settings, mock_stt, mock_tts, mock_agent, state)

        events = []
        sample_pcm = struct.pack("<100h", *([0] * 100))
        async for event in pipeline.process(sample_pcm):
            events.append(event)

        event_types = [e.type for e in events]
        assert "error" in event_types

        # Agent and TTS should NOT have been called
        mock_agent.send_message.assert_not_called()
        mock_tts.synthesize.assert_not_called()

    async def test_pipeline_empty_agent_response(self, settings, mock_stt, mock_tts):
        """Test pipeline handles empty agent response."""
        # Local agent mock that returns an empty reply.
        mock_agent = AsyncMock(spec=AgentClient)
        mock_agent.send_message.return_value = ""

        state = PipelineState()
        pipeline = Pipeline(settings, mock_stt, mock_tts, mock_agent, state)

        events = []
        sample_pcm = struct.pack("<100h", *([0] * 100))
        async for event in pipeline.process(sample_pcm):
            events.append(event)

        event_types = [e.type for e in events]
        assert "error" in event_types
        # An empty reply must short-circuit before synthesis.
        mock_tts.synthesize.assert_not_called()

    async def test_pipeline_metrics_recorded(self, pipeline, sample_pcm):
        """Test that pipeline metrics are recorded."""
        # Reaches into the pipeline's private state to inspect counters.
        state = pipeline._state

        assert state.total_transcriptions == 0

        events = []
        async for event in pipeline.process(sample_pcm):
            events.append(event)

        # One full run should bump each per-stage counter exactly once.
        assert state.total_transcriptions == 1
        assert state.total_agent_requests == 1
        assert state.total_tts_requests == 1
        assert len(state.recent_metrics) == 1

        last = state.recent_metrics[-1]
        assert last.total_duration > 0
        assert last.transcript == "What is the weather?"
        assert last.response_text == "I don't have weather tools yet."
|
||||
89
stentor-gateway/tests/test_stt_client.py
Normal file
89
stentor-gateway/tests/test_stt_client.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests for the Speaches STT client."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from stentor.audio import pcm_to_wav
|
||||
from stentor.stt_client import STTClient
|
||||
|
||||
|
||||
class TestSTTClient:
    """Tests for STTClient.

    NOTE(review): the `httpx_mock` fixture presumably comes from the
    pytest-httpx plugin — confirm it is declared in dev dependencies.
    """

    @pytest.fixture
    def stt_client(self, settings):
        """Create an STT client with a mock HTTP client."""
        # NOTE(review): this fixture is unused by the tests below and the
        # AsyncClient it creates is never closed — consider removing it.
        http_client = httpx.AsyncClient()
        return STTClient(settings, http_client)

    async def test_transcribe_success(self, settings, sample_pcm, httpx_mock):
        """Test successful transcription."""
        # Stub the OpenAI-compatible transcription endpoint.
        httpx_mock.add_response(
            url=f"{settings.stt_url}/v1/audio/transcriptions",
            method="POST",
            json={"text": "Hello world"},
        )

        async with httpx.AsyncClient() as http_client:
            client = STTClient(settings, http_client)
            wav_data = pcm_to_wav(sample_pcm)
            result = await client.transcribe(wav_data)

        assert result == "Hello world"

    async def test_transcribe_with_language(self, settings, sample_pcm, httpx_mock):
        """Test transcription with explicit language."""
        httpx_mock.add_response(
            url=f"{settings.stt_url}/v1/audio/transcriptions",
            method="POST",
            json={"text": "Bonjour le monde"},
        )

        async with httpx.AsyncClient() as http_client:
            client = STTClient(settings, http_client)
            wav_data = pcm_to_wav(sample_pcm)
            result = await client.transcribe(wav_data, language="fr")

        assert result == "Bonjour le monde"

    async def test_transcribe_empty_result(self, settings, sample_pcm, httpx_mock):
        """Test transcription returning empty text."""
        # Whitespace-only text should be normalized to an empty string.
        httpx_mock.add_response(
            url=f"{settings.stt_url}/v1/audio/transcriptions",
            method="POST",
            json={"text": "   "},
        )

        async with httpx.AsyncClient() as http_client:
            client = STTClient(settings, http_client)
            wav_data = pcm_to_wav(sample_pcm)
            result = await client.transcribe(wav_data)

        assert result == ""

    async def test_is_available_success(self, settings, httpx_mock):
        """Test availability check when service is up."""
        # Health probe targets the /v1/models listing endpoint.
        httpx_mock.add_response(
            url=f"{settings.stt_url}/v1/models",
            method="GET",
            json={"models": []},
        )

        async with httpx.AsyncClient() as http_client:
            client = STTClient(settings, http_client)
            available = await client.is_available()

        assert available is True

    async def test_is_available_failure(self, settings, httpx_mock):
        """Test availability check when service is down."""
        # Simulate a refused connection; is_available must not raise.
        httpx_mock.add_exception(
            httpx.ConnectError("Connection refused"),
            url=f"{settings.stt_url}/v1/models",
        )

        async with httpx.AsyncClient() as http_client:
            client = STTClient(settings, http_client)
            available = await client.is_available()

        assert available is False
|
||||
78
stentor-gateway/tests/test_tts_client.py
Normal file
78
stentor-gateway/tests/test_tts_client.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Tests for the Speaches TTS client."""
|
||||
|
||||
import struct
|
||||
|
||||
import httpx
|
||||
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
|
||||
class TestTTSClient:
    """Tests for TTSClient.

    NOTE(review): the `httpx_mock` fixture presumably comes from the
    pytest-httpx plugin — confirm it is declared in dev dependencies.
    """

    async def test_synthesize_success(self, settings, httpx_mock):
        """Test successful TTS synthesis."""
        # Generate fake PCM audio (100 samples of silence)
        fake_pcm = struct.pack("<100h", *([0] * 100))

        # Stub the OpenAI-compatible speech endpoint to echo the fake PCM.
        httpx_mock.add_response(
            url=f"{settings.tts_url}/v1/audio/speech",
            method="POST",
            content=fake_pcm,
        )

        async with httpx.AsyncClient() as http_client:
            client = TTSClient(settings, http_client)
            result = await client.synthesize("Hello world")

        assert result == fake_pcm

    async def test_synthesize_uses_correct_params(self, settings, httpx_mock):
        """Test that TTS requests include correct parameters."""
        httpx_mock.add_response(
            url=f"{settings.tts_url}/v1/audio/speech",
            method="POST",
            content=b"\x00\x00",
        )

        async with httpx.AsyncClient() as http_client:
            client = TTSClient(settings, http_client)
            await client.synthesize("Test text")

        # Inspect the captured outgoing request body.
        request = httpx_mock.get_request()
        assert request is not None

        import json
        body = json.loads(request.content)
        # The request must mirror the configured model/voice and ask for raw PCM.
        assert body["model"] == settings.tts_model
        assert body["voice"] == settings.tts_voice
        assert body["input"] == "Test text"
        assert body["response_format"] == "pcm"
        assert body["speed"] == 1.0

    async def test_is_available_success(self, settings, httpx_mock):
        """Test availability check when service is up."""
        # Health probe targets the /v1/models listing endpoint.
        httpx_mock.add_response(
            url=f"{settings.tts_url}/v1/models",
            method="GET",
            json={"models": []},
        )

        async with httpx.AsyncClient() as http_client:
            client = TTSClient(settings, http_client)
            available = await client.is_available()

        assert available is True

    async def test_is_available_failure(self, settings, httpx_mock):
        """Test availability check when service is down."""
        # Simulate a refused connection; is_available must not raise.
        httpx_mock.add_exception(
            httpx.ConnectError("Connection refused"),
            url=f"{settings.tts_url}/v1/models",
        )

        async with httpx.AsyncClient() as http_client:
            client = TTSClient(settings, http_client)
            available = await client.is_available()

        assert available is False
|
||||
185
stentor-gateway/tests/test_websocket.py
Normal file
185
stentor-gateway/tests/test_websocket.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Tests for the WebSocket endpoint and message models."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from stentor.audio import encode_audio
|
||||
from stentor.models import (
|
||||
AudioConfig,
|
||||
ErrorEvent,
|
||||
InputAudioBufferAppend,
|
||||
ResponseAudioDelta,
|
||||
ResponseAudioDone,
|
||||
ResponseDone,
|
||||
ResponseTextDone,
|
||||
SessionCreated,
|
||||
SessionStart,
|
||||
StatusUpdate,
|
||||
TranscriptDone,
|
||||
)
|
||||
|
||||
|
||||
class TestMessageModels:
    """Tests for WebSocket message serialization.

    Exercises the Pydantic models' default values, custom construction,
    and JSON round-tripping for each protocol event type.
    """

    def test_session_start_defaults(self):
        """Test SessionStart with default values."""
        msg = SessionStart()
        data = msg.model_dump()
        assert data["type"] == "session.start"
        assert data["client_id"] == ""
        # Default audio config matches the gateway's 16 kHz input rate.
        assert data["audio_config"]["sample_rate"] == 16000

    def test_session_start_custom(self):
        """Test SessionStart with custom values."""
        msg = SessionStart(
            client_id="esp32-kitchen",
            audio_config=AudioConfig(sample_rate=24000),
        )
        assert msg.client_id == "esp32-kitchen"
        assert msg.audio_config.sample_rate == 24000

    def test_input_audio_buffer_append(self):
        """Test audio append message."""
        # Audio payload travels base64-encoded inside the JSON message.
        audio_b64 = encode_audio(b"\x00\x00\x01\x00")
        msg = InputAudioBufferAppend(audio=audio_b64)
        data = msg.model_dump()
        assert data["type"] == "input_audio_buffer.append"
        assert data["audio"] == audio_b64

    def test_session_created(self):
        """Test session created response."""
        msg = SessionCreated(session_id="test-uuid")
        # Round-trip through JSON to verify wire-format serialization.
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "session.created"
        assert data["session_id"] == "test-uuid"

    def test_status_update(self):
        """Test status update message."""
        # All four pipeline states must be accepted by the model.
        for state in ("listening", "transcribing", "thinking", "speaking"):
            msg = StatusUpdate(state=state)
            assert msg.state == state

    def test_transcript_done(self):
        """Test transcript done message."""
        msg = TranscriptDone(text="Hello world")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "transcript.done"
        assert data["text"] == "Hello world"

    def test_response_text_done(self):
        """Test response text done message."""
        msg = ResponseTextDone(text="I can help with that.")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.text.done"
        assert data["text"] == "I can help with that."

    def test_response_audio_delta(self):
        """Test audio delta message."""
        # "AAAA" is valid base64 (three zero bytes).
        msg = ResponseAudioDelta(delta="AAAA")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.audio.delta"
        assert data["delta"] == "AAAA"

    def test_response_audio_done(self):
        """Test audio done message."""
        msg = ResponseAudioDone()
        assert msg.type == "response.audio.done"

    def test_response_done(self):
        """Test response done message."""
        msg = ResponseDone()
        assert msg.type == "response.done"

    def test_error_event(self):
        """Test error event message."""
        msg = ErrorEvent(message="Something went wrong", code="test_error")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "error"
        assert data["message"] == "Something went wrong"
        assert data["code"] == "test_error"

    def test_error_event_default_code(self):
        """Test error event with default code."""
        msg = ErrorEvent(message="Oops")
        assert msg.code == "unknown_error"
|
||||
|
||||
|
||||
class TestWebSocketEndpoint:
    """Tests for the /api/v1/realtime WebSocket endpoint."""

    @pytest.fixture
    def client(self):
        """Create a test client with lifespan to populate app.state."""
        from stentor.main import app
        # Using TestClient as a context manager runs the app's lifespan
        # (startup/shutdown), which populates app.state.
        with TestClient(app) as c:
            yield c

    def test_health_live(self, client):
        """Test liveness endpoint."""
        response = client.get("/api/live/")
        assert response.status_code == 200
        assert response.json()["status"] == "ok"

    def test_api_info(self, client):
        """Test API info endpoint."""
        response = client.get("/api/v1/info")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "stentor-gateway"
        # NOTE(review): hard-coded version string will break on every
        # release bump — consider asserting against the package metadata.
        assert data["version"] == "0.1.0"
        assert "realtime" in data["endpoints"]

    def test_websocket_session_lifecycle(self, client):
        """Test basic WebSocket session start and close."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            # Send session.start
            ws.send_json({
                "type": "session.start",
                "client_id": "test-client",
            })

            # Receive session.created
            msg = ws.receive_json()
            assert msg["type"] == "session.created"
            assert "session_id" in msg

            # Receive initial status: listening
            msg = ws.receive_json()
            assert msg["type"] == "status"
            assert msg["state"] == "listening"

            # Send session.close
            ws.send_json({"type": "session.close"})

    def test_websocket_no_session_error(self, client):
        """Test sending audio without starting a session."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            # Audio before session.start must be rejected with no_session.
            ws.send_json({
                "type": "input_audio_buffer.append",
                "audio": "AAAA",
            })

            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "no_session"

    def test_websocket_invalid_json(self, client):
        """Test sending invalid JSON."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            # Malformed payloads get a structured error, not a disconnect.
            ws.send_text("not valid json{{{")

            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "invalid_json"

    def test_websocket_unknown_event(self, client):
        """Test sending unknown event type."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            ws.send_json({"type": "scooby.dooby.doo"})

            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "unknown_event"
|
||||
Reference in New Issue
Block a user