"""Tests for the WebSocket endpoint and message models."""

import json

import pytest
from fastapi.testclient import TestClient

from stentor.audio import encode_audio
from stentor.models import (
    AudioConfig,
    ErrorEvent,
    InputAudioBufferAppend,
    ResponseAudioDelta,
    ResponseAudioDone,
    ResponseDone,
    ResponseTextDone,
    SessionCreated,
    SessionStart,
    StatusUpdate,
    TranscriptDone,
)


class TestMessageModels:
    """Tests for WebSocket message serialization."""

    def test_session_start_defaults(self):
        """Test SessionStart with default values."""
        msg = SessionStart()
        data = msg.model_dump()
        assert data["type"] == "session.start"
        assert data["client_id"] == ""
        assert data["audio_config"]["sample_rate"] == 16000

    def test_session_start_custom(self):
        """Test SessionStart with custom values."""
        msg = SessionStart(
            client_id="esp32-kitchen",
            audio_config=AudioConfig(sample_rate=24000),
        )
        assert msg.client_id == "esp32-kitchen"
        assert msg.audio_config.sample_rate == 24000

    def test_input_audio_buffer_append(self):
        """Test audio append message."""
        audio_b64 = encode_audio(b"\x00\x00\x01\x00")
        msg = InputAudioBufferAppend(audio=audio_b64)
        data = msg.model_dump()
        assert data["type"] == "input_audio_buffer.append"
        assert data["audio"] == audio_b64

    def test_session_created(self):
        """Test session created response."""
        msg = SessionCreated(session_id="test-uuid")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "session.created"
        assert data["session_id"] == "test-uuid"

    # Parametrized so each state is reported as its own test case instead of
    # the loop aborting at the first failing state.
    @pytest.mark.parametrize(
        "state", ["listening", "transcribing", "thinking", "speaking"]
    )
    def test_status_update(self, state):
        """Test status update message for each supported state."""
        msg = StatusUpdate(state=state)
        assert msg.state == state
        # Also verify the state round-trips through JSON, like the other
        # message tests do.
        data = json.loads(msg.model_dump_json())
        assert data["state"] == state

    def test_transcript_done(self):
        """Test transcript done message."""
        msg = TranscriptDone(text="Hello world")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "transcript.done"
        assert data["text"] == "Hello world"

    def test_response_text_done(self):
        """Test response text done message."""
        msg = ResponseTextDone(text="I can help with that.")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.text.done"
        assert data["text"] == "I can help with that."

    def test_response_audio_delta(self):
        """Test audio delta message."""
        msg = ResponseAudioDelta(delta="AAAA")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.audio.delta"
        assert data["delta"] == "AAAA"

    def test_response_audio_done(self):
        """Test audio done message, both as a model and as serialized JSON."""
        msg = ResponseAudioDone()
        assert msg.type == "response.audio.done"
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.audio.done"

    def test_response_done(self):
        """Test response done message, both as a model and as serialized JSON."""
        msg = ResponseDone()
        assert msg.type == "response.done"
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "response.done"

    def test_error_event(self):
        """Test error event message."""
        msg = ErrorEvent(message="Something went wrong", code="test_error")
        data = json.loads(msg.model_dump_json())
        assert data["type"] == "error"
        assert data["message"] == "Something went wrong"
        assert data["code"] == "test_error"

    def test_error_event_default_code(self):
        """Test error event with default code."""
        msg = ErrorEvent(message="Oops")
        assert msg.code == "unknown_error"


class TestWebSocketEndpoint:
    """Tests for the /api/v1/realtime WebSocket endpoint."""

    @pytest.fixture
    def client(self):
        """Create a test client with lifespan to populate app.state."""
        # NOTE(review): imported here rather than at module top, presumably so
        # the app is only built when an endpoint test actually runs.
        from stentor.main import app

        # Entering the TestClient context runs the app's lifespan.
        with TestClient(app) as c:
            yield c

    def test_health_live(self, client):
        """Test liveness endpoint."""
        response = client.get("/api/live/")
        assert response.status_code == 200
        assert response.json()["status"] == "ok"

    def test_api_info(self, client):
        """Test API info endpoint."""
        response = client.get("/api/v1/info")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "stentor-gateway"
        assert data["version"] == "0.1.0"
        assert "realtime" in data["endpoints"]

    def test_websocket_session_lifecycle(self, client):
        """Test basic WebSocket session start and close."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            # Send session.start
            ws.send_json({
                "type": "session.start",
                "client_id": "test-client",
            })

            # Receive session.created
            msg = ws.receive_json()
            assert msg["type"] == "session.created"
            assert "session_id" in msg

            # Receive initial status: listening
            msg = ws.receive_json()
            assert msg["type"] == "status"
            assert msg["state"] == "listening"

            # Send session.close (no reply is asserted after this)
            ws.send_json({"type": "session.close"})

    def test_websocket_no_session_error(self, client):
        """Test sending audio without starting a session."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            ws.send_json({
                "type": "input_audio_buffer.append",
                "audio": "AAAA",
            })
            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "no_session"

    def test_websocket_invalid_json(self, client):
        """Test sending invalid JSON."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            ws.send_text("not valid json{{{")
            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "invalid_json"

    def test_websocket_unknown_event(self, client):
        """Test sending unknown event type."""
        with client.websocket_connect("/api/v1/realtime") as ws:
            ws.send_json({"type": "scooby.dooby.doo"})
            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert msg["code"] == "unknown_event"