feat: scaffold stentor-gateway with FastAPI voice pipeline
Initialize the stentor-gateway project with WebSocket-based voice pipeline orchestrating STT → Agent → TTS via OpenAI-compatible APIs. - Add FastAPI app with WebSocket endpoint for audio streaming - Add pipeline orchestration (stt_client, tts_client, agent_client) - Add Pydantic Settings configuration and message models - Add audio utilities for PCM/WAV conversion and resampling - Add health check endpoints - Add Dockerfile and pyproject.toml with dependencies - Add initial test suite (pipeline, STT, TTS, WebSocket) - Add comprehensive README covering gateway and ESP32 ear design - Clean up .gitignore for Python/uv project
This commit is contained in:
1
stentor-gateway/tests/__init__.py
Normal file
1
stentor-gateway/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Stentor Gateway tests."""
|
||||
34
stentor-gateway/tests/conftest.py
Normal file
34
stentor-gateway/tests/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Shared test fixtures for Stentor Gateway tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
from stentor.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
"""Test settings with localhost endpoints."""
|
||||
return Settings(
|
||||
stt_url="http://localhost:9001",
|
||||
tts_url="http://localhost:9002",
|
||||
agent_url="http://localhost:9003",
|
||||
log_level="DEBUG",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pcm() -> bytes:
|
||||
"""Generate 1 second of silent PCM audio (16kHz, mono, 16-bit)."""
|
||||
import struct
|
||||
|
||||
num_samples = 16000 # 1 second at 16kHz
|
||||
return struct.pack(f"<{num_samples}h", *([0] * num_samples))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pcm_short() -> bytes:
|
||||
"""Generate 100ms of silent PCM audio."""
|
||||
import struct
|
||||
|
||||
num_samples = 1600 # 100ms at 16kHz
|
||||
return struct.pack(f"<{num_samples}h", *([0] * num_samples))
|
||||
120
stentor-gateway/tests/test_pipeline.py
Normal file
120
stentor-gateway/tests/test_pipeline.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for the voice pipeline orchestrator."""
|
||||
|
||||
import struct
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from stentor.agent_client import AgentClient
|
||||
from stentor.pipeline import Pipeline, PipelineState
|
||||
from stentor.stt_client import STTClient
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
|
||||
class TestPipeline:
|
||||
"""Tests for the Pipeline orchestrator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stt(self):
|
||||
"""Create a mock STT client."""
|
||||
stt = AsyncMock(spec=STTClient)
|
||||
stt.transcribe.return_value = "What is the weather?"
|
||||
return stt
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts(self):
|
||||
"""Create a mock TTS client."""
|
||||
tts = AsyncMock(spec=TTSClient)
|
||||
# Return 100 samples of silence as PCM (at 24kHz, will be resampled)
|
||||
tts.synthesize.return_value = struct.pack("<100h", *([0] * 100))
|
||||
return tts
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(self):
|
||||
"""Create a mock agent client."""
|
||||
agent = AsyncMock(spec=AgentClient)
|
||||
agent.send_message.return_value = "I don't have weather tools yet."
|
||||
return agent
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self, settings, mock_stt, mock_tts, mock_agent):
|
||||
"""Create a pipeline with mock clients."""
|
||||
state = PipelineState()
|
||||
return Pipeline(settings, mock_stt, mock_tts, mock_agent, state)
|
||||
|
||||
async def test_full_pipeline(self, pipeline, sample_pcm, mock_stt, mock_tts, mock_agent):
|
||||
"""Test the complete pipeline produces expected event sequence."""
|
||||
events = []
|
||||
async for event in pipeline.process(sample_pcm):
|
||||
events.append(event)
|
||||
|
||||
# Verify event sequence
|
||||
event_types = [e.type for e in events]
|
||||
|
||||
assert "status" in event_types # transcribing status
|
||||
assert "transcript.done" in event_types
|
||||
assert "response.text.done" in event_types
|
||||
assert "response.audio.delta" in event_types or "response.audio.done" in event_types
|
||||
assert "response.done" in event_types
|
||||
|
||||
# Verify services were called
|
||||
mock_stt.transcribe.assert_called_once()
|
||||
mock_agent.send_message.assert_called_once_with("What is the weather?")
|
||||
mock_tts.synthesize.assert_called_once_with("I don't have weather tools yet.")
|
||||
|
||||
async def test_pipeline_empty_transcript(self, settings, mock_tts, mock_agent):
|
||||
"""Test pipeline handles empty transcript gracefully."""
|
||||
mock_stt = AsyncMock(spec=STTClient)
|
||||
mock_stt.transcribe.return_value = ""
|
||||
|
||||
state = PipelineState()
|
||||
pipeline = Pipeline(settings, mock_stt, mock_tts, mock_agent, state)
|
||||
|
||||
events = []
|
||||
sample_pcm = struct.pack("<100h", *([0] * 100))
|
||||
async for event in pipeline.process(sample_pcm):
|
||||
events.append(event)
|
||||
|
||||
event_types = [e.type for e in events]
|
||||
assert "error" in event_types
|
||||
|
||||
# Agent and TTS should NOT have been called
|
||||
mock_agent.send_message.assert_not_called()
|
||||
mock_tts.synthesize.assert_not_called()
|
||||
|
||||
async def test_pipeline_empty_agent_response(self, settings, mock_stt, mock_tts):
|
||||
"""Test pipeline handles empty agent response."""
|
||||
mock_agent = AsyncMock(spec=AgentClient)
|
||||
mock_agent.send_message.return_value = ""
|
||||
|
||||
state = PipelineState()
|
||||
pipeline = Pipeline(settings, mock_stt, mock_tts, mock_agent, state)
|
||||
|
||||
events = []
|
||||
sample_pcm = struct.pack("<100h", *([0] * 100))
|
||||
async for event in pipeline.process(sample_pcm):
|
||||
events.append(event)
|
||||
|
||||
event_types = [e.type for e in events]
|
||||
assert "error" in event_types
|
||||
mock_tts.synthesize.assert_not_called()
|
||||
|
||||
async def test_pipeline_metrics_recorded(self, pipeline, sample_pcm):
|
||||
"""Test that pipeline metrics are recorded."""
|
||||
state = pipeline._state
|
||||
|
||||
assert state.total_transcriptions == 0
|
||||
|
||||
events = []
|
||||
async for event in pipeline.process(sample_pcm):
|
||||
events.append(event)
|
||||
|
||||
assert state.total_transcriptions == 1
|
||||
assert state.total_agent_requests == 1
|
||||
assert state.total_tts_requests == 1
|
||||
assert len(state.recent_metrics) == 1
|
||||
|
||||
last = state.recent_metrics[-1]
|
||||
assert last.total_duration > 0
|
||||
assert last.transcript == "What is the weather?"
|
||||
assert last.response_text == "I don't have weather tools yet."
|
||||
89
stentor-gateway/tests/test_stt_client.py
Normal file
89
stentor-gateway/tests/test_stt_client.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests for the Speaches STT client."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from stentor.audio import pcm_to_wav
|
||||
from stentor.stt_client import STTClient
|
||||
|
||||
|
||||
class TestSTTClient:
|
||||
"""Tests for STTClient."""
|
||||
|
||||
@pytest.fixture
|
||||
def stt_client(self, settings):
|
||||
"""Create an STT client with a mock HTTP client."""
|
||||
http_client = httpx.AsyncClient()
|
||||
return STTClient(settings, http_client)
|
||||
|
||||
async def test_transcribe_success(self, settings, sample_pcm, httpx_mock):
|
||||
"""Test successful transcription."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.stt_url}/v1/audio/transcriptions",
|
||||
method="POST",
|
||||
json={"text": "Hello world"},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = STTClient(settings, http_client)
|
||||
wav_data = pcm_to_wav(sample_pcm)
|
||||
result = await client.transcribe(wav_data)
|
||||
|
||||
assert result == "Hello world"
|
||||
|
||||
async def test_transcribe_with_language(self, settings, sample_pcm, httpx_mock):
|
||||
"""Test transcription with explicit language."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.stt_url}/v1/audio/transcriptions",
|
||||
method="POST",
|
||||
json={"text": "Bonjour le monde"},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = STTClient(settings, http_client)
|
||||
wav_data = pcm_to_wav(sample_pcm)
|
||||
result = await client.transcribe(wav_data, language="fr")
|
||||
|
||||
assert result == "Bonjour le monde"
|
||||
|
||||
async def test_transcribe_empty_result(self, settings, sample_pcm, httpx_mock):
|
||||
"""Test transcription returning empty text."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.stt_url}/v1/audio/transcriptions",
|
||||
method="POST",
|
||||
json={"text": " "},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = STTClient(settings, http_client)
|
||||
wav_data = pcm_to_wav(sample_pcm)
|
||||
result = await client.transcribe(wav_data)
|
||||
|
||||
assert result == ""
|
||||
|
||||
async def test_is_available_success(self, settings, httpx_mock):
|
||||
"""Test availability check when service is up."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.stt_url}/v1/models",
|
||||
method="GET",
|
||||
json={"models": []},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = STTClient(settings, http_client)
|
||||
available = await client.is_available()
|
||||
|
||||
assert available is True
|
||||
|
||||
async def test_is_available_failure(self, settings, httpx_mock):
|
||||
"""Test availability check when service is down."""
|
||||
httpx_mock.add_exception(
|
||||
httpx.ConnectError("Connection refused"),
|
||||
url=f"{settings.stt_url}/v1/models",
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = STTClient(settings, http_client)
|
||||
available = await client.is_available()
|
||||
|
||||
assert available is False
|
||||
78
stentor-gateway/tests/test_tts_client.py
Normal file
78
stentor-gateway/tests/test_tts_client.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Tests for the Speaches TTS client."""
|
||||
|
||||
import struct
|
||||
|
||||
import httpx
|
||||
|
||||
from stentor.tts_client import TTSClient
|
||||
|
||||
|
||||
class TestTTSClient:
|
||||
"""Tests for TTSClient."""
|
||||
|
||||
async def test_synthesize_success(self, settings, httpx_mock):
|
||||
"""Test successful TTS synthesis."""
|
||||
# Generate fake PCM audio (100 samples of silence)
|
||||
fake_pcm = struct.pack("<100h", *([0] * 100))
|
||||
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.tts_url}/v1/audio/speech",
|
||||
method="POST",
|
||||
content=fake_pcm,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = TTSClient(settings, http_client)
|
||||
result = await client.synthesize("Hello world")
|
||||
|
||||
assert result == fake_pcm
|
||||
|
||||
async def test_synthesize_uses_correct_params(self, settings, httpx_mock):
|
||||
"""Test that TTS requests include correct parameters."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.tts_url}/v1/audio/speech",
|
||||
method="POST",
|
||||
content=b"\x00\x00",
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = TTSClient(settings, http_client)
|
||||
await client.synthesize("Test text")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request is not None
|
||||
|
||||
import json
|
||||
body = json.loads(request.content)
|
||||
assert body["model"] == settings.tts_model
|
||||
assert body["voice"] == settings.tts_voice
|
||||
assert body["input"] == "Test text"
|
||||
assert body["response_format"] == "pcm"
|
||||
assert body["speed"] == 1.0
|
||||
|
||||
async def test_is_available_success(self, settings, httpx_mock):
|
||||
"""Test availability check when service is up."""
|
||||
httpx_mock.add_response(
|
||||
url=f"{settings.tts_url}/v1/models",
|
||||
method="GET",
|
||||
json={"models": []},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = TTSClient(settings, http_client)
|
||||
available = await client.is_available()
|
||||
|
||||
assert available is True
|
||||
|
||||
async def test_is_available_failure(self, settings, httpx_mock):
|
||||
"""Test availability check when service is down."""
|
||||
httpx_mock.add_exception(
|
||||
httpx.ConnectError("Connection refused"),
|
||||
url=f"{settings.tts_url}/v1/models",
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
client = TTSClient(settings, http_client)
|
||||
available = await client.is_available()
|
||||
|
||||
assert available is False
|
||||
185
stentor-gateway/tests/test_websocket.py
Normal file
185
stentor-gateway/tests/test_websocket.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Tests for the WebSocket endpoint and message models."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from stentor.audio import encode_audio
|
||||
from stentor.models import (
|
||||
AudioConfig,
|
||||
ErrorEvent,
|
||||
InputAudioBufferAppend,
|
||||
ResponseAudioDelta,
|
||||
ResponseAudioDone,
|
||||
ResponseDone,
|
||||
ResponseTextDone,
|
||||
SessionCreated,
|
||||
SessionStart,
|
||||
StatusUpdate,
|
||||
TranscriptDone,
|
||||
)
|
||||
|
||||
|
||||
class TestMessageModels:
|
||||
"""Tests for WebSocket message serialization."""
|
||||
|
||||
def test_session_start_defaults(self):
|
||||
"""Test SessionStart with default values."""
|
||||
msg = SessionStart()
|
||||
data = msg.model_dump()
|
||||
assert data["type"] == "session.start"
|
||||
assert data["client_id"] == ""
|
||||
assert data["audio_config"]["sample_rate"] == 16000
|
||||
|
||||
def test_session_start_custom(self):
|
||||
"""Test SessionStart with custom values."""
|
||||
msg = SessionStart(
|
||||
client_id="esp32-kitchen",
|
||||
audio_config=AudioConfig(sample_rate=24000),
|
||||
)
|
||||
assert msg.client_id == "esp32-kitchen"
|
||||
assert msg.audio_config.sample_rate == 24000
|
||||
|
||||
def test_input_audio_buffer_append(self):
|
||||
"""Test audio append message."""
|
||||
audio_b64 = encode_audio(b"\x00\x00\x01\x00")
|
||||
msg = InputAudioBufferAppend(audio=audio_b64)
|
||||
data = msg.model_dump()
|
||||
assert data["type"] == "input_audio_buffer.append"
|
||||
assert data["audio"] == audio_b64
|
||||
|
||||
def test_session_created(self):
|
||||
"""Test session created response."""
|
||||
msg = SessionCreated(session_id="test-uuid")
|
||||
data = json.loads(msg.model_dump_json())
|
||||
assert data["type"] == "session.created"
|
||||
assert data["session_id"] == "test-uuid"
|
||||
|
||||
def test_status_update(self):
|
||||
"""Test status update message."""
|
||||
for state in ("listening", "transcribing", "thinking", "speaking"):
|
||||
msg = StatusUpdate(state=state)
|
||||
assert msg.state == state
|
||||
|
||||
def test_transcript_done(self):
|
||||
"""Test transcript done message."""
|
||||
msg = TranscriptDone(text="Hello world")
|
||||
data = json.loads(msg.model_dump_json())
|
||||
assert data["type"] == "transcript.done"
|
||||
assert data["text"] == "Hello world"
|
||||
|
||||
def test_response_text_done(self):
|
||||
"""Test response text done message."""
|
||||
msg = ResponseTextDone(text="I can help with that.")
|
||||
data = json.loads(msg.model_dump_json())
|
||||
assert data["type"] == "response.text.done"
|
||||
assert data["text"] == "I can help with that."
|
||||
|
||||
def test_response_audio_delta(self):
|
||||
"""Test audio delta message."""
|
||||
msg = ResponseAudioDelta(delta="AAAA")
|
||||
data = json.loads(msg.model_dump_json())
|
||||
assert data["type"] == "response.audio.delta"
|
||||
assert data["delta"] == "AAAA"
|
||||
|
||||
def test_response_audio_done(self):
|
||||
"""Test audio done message."""
|
||||
msg = ResponseAudioDone()
|
||||
assert msg.type == "response.audio.done"
|
||||
|
||||
def test_response_done(self):
|
||||
"""Test response done message."""
|
||||
msg = ResponseDone()
|
||||
assert msg.type == "response.done"
|
||||
|
||||
def test_error_event(self):
|
||||
"""Test error event message."""
|
||||
msg = ErrorEvent(message="Something went wrong", code="test_error")
|
||||
data = json.loads(msg.model_dump_json())
|
||||
assert data["type"] == "error"
|
||||
assert data["message"] == "Something went wrong"
|
||||
assert data["code"] == "test_error"
|
||||
|
||||
def test_error_event_default_code(self):
|
||||
"""Test error event with default code."""
|
||||
msg = ErrorEvent(message="Oops")
|
||||
assert msg.code == "unknown_error"
|
||||
|
||||
|
||||
class TestWebSocketEndpoint:
|
||||
"""Tests for the /api/v1/realtime WebSocket endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create a test client with lifespan to populate app.state."""
|
||||
from stentor.main import app
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
def test_health_live(self, client):
|
||||
"""Test liveness endpoint."""
|
||||
response = client.get("/api/live/")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
def test_api_info(self, client):
|
||||
"""Test API info endpoint."""
|
||||
response = client.get("/api/v1/info")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "stentor-gateway"
|
||||
assert data["version"] == "0.1.0"
|
||||
assert "realtime" in data["endpoints"]
|
||||
|
||||
def test_websocket_session_lifecycle(self, client):
|
||||
"""Test basic WebSocket session start and close."""
|
||||
with client.websocket_connect("/api/v1/realtime") as ws:
|
||||
# Send session.start
|
||||
ws.send_json({
|
||||
"type": "session.start",
|
||||
"client_id": "test-client",
|
||||
})
|
||||
|
||||
# Receive session.created
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "session.created"
|
||||
assert "session_id" in msg
|
||||
|
||||
# Receive initial status: listening
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "status"
|
||||
assert msg["state"] == "listening"
|
||||
|
||||
# Send session.close
|
||||
ws.send_json({"type": "session.close"})
|
||||
|
||||
def test_websocket_no_session_error(self, client):
|
||||
"""Test sending audio without starting a session."""
|
||||
with client.websocket_connect("/api/v1/realtime") as ws:
|
||||
ws.send_json({
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": "AAAA",
|
||||
})
|
||||
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert msg["code"] == "no_session"
|
||||
|
||||
def test_websocket_invalid_json(self, client):
|
||||
"""Test sending invalid JSON."""
|
||||
with client.websocket_connect("/api/v1/realtime") as ws:
|
||||
ws.send_text("not valid json{{{")
|
||||
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert msg["code"] == "invalid_json"
|
||||
|
||||
def test_websocket_unknown_event(self, client):
|
||||
"""Test sending unknown event type."""
|
||||
with client.websocket_connect("/api/v1/realtime") as ws:
|
||||
ws.send_json({"type": "scooby.dooby.doo"})
|
||||
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert msg["code"] == "unknown_event"
|
||||
Reference in New Issue
Block a user