feat: add initial Hold Slayer AI telephony gateway implementation
Complete project scaffolding and core implementation of an AI-powered telephony system that calls companies, navigates IVR menus, waits on hold, and transfers to the user when a human answers. Key components: - FastAPI server with REST API, WebSocket, and MCP (SSE) interfaces - SIP/VoIP call management via PJSUA2 with RTP audio streaming - LLM-powered IVR navigation using OpenAI/Anthropic with tool calling - Hold detection service combining audio analysis and silence detection - Real-time STT (Whisper/Deepgram) and TTS (OpenAI/Piper) pipelines - Call recording with per-channel and mixed audio capture - Event bus (asyncio pub/sub) for real-time client updates - Web dashboard with live call monitoring - SQLite persistence via SQLAlchemy with call history and analytics - Notification support (email, SMS, webhook, desktop) - Docker Compose deployment with Opal VoIP and Opal Media containers - Comprehensive test suite with unit, integration, and E2E tests - Simplified .gitignore and full project documentation in README
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Hold Slayer tests."""
|
||||
253
tests/test_audio_classifier.py
Normal file
253
tests/test_audio_classifier.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Tests for the audio classifier.
|
||||
|
||||
Tests spectral analysis, DTMF detection, and classification logic.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from config import ClassifierSettings
|
||||
from models.call import AudioClassification
|
||||
from services.audio_classifier import AudioClassifier, SAMPLE_RATE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def classifier():
|
||||
"""Create a classifier with default settings."""
|
||||
settings = ClassifierSettings()
|
||||
return AudioClassifier(settings)
|
||||
|
||||
|
||||
def generate_silence(duration_seconds: float = 1.0) -> bytes:
|
||||
"""Generate silent audio (near-zero amplitude)."""
|
||||
samples = int(SAMPLE_RATE * duration_seconds)
|
||||
data = np.zeros(samples, dtype=np.int16)
|
||||
return data.tobytes()
|
||||
|
||||
|
||||
def generate_tone(frequency: float, duration_seconds: float = 1.0, amplitude: float = 0.5) -> bytes:
|
||||
"""Generate a pure sine tone."""
|
||||
samples = int(SAMPLE_RATE * duration_seconds)
|
||||
t = np.linspace(0, duration_seconds, samples, endpoint=False)
|
||||
signal = (amplitude * 32767 * np.sin(2 * np.pi * frequency * t)).astype(np.int16)
|
||||
return signal.tobytes()
|
||||
|
||||
|
||||
def generate_dtmf(digit: str, duration_seconds: float = 0.5) -> bytes:
|
||||
"""Generate a DTMF tone for a digit."""
|
||||
dtmf_freqs = {
|
||||
"1": (697, 1209), "2": (697, 1336), "3": (697, 1477),
|
||||
"4": (770, 1209), "5": (770, 1336), "6": (770, 1477),
|
||||
"7": (852, 1209), "8": (852, 1336), "9": (852, 1477),
|
||||
"*": (941, 1209), "0": (941, 1336), "#": (941, 1477),
|
||||
}
|
||||
low_freq, high_freq = dtmf_freqs[digit]
|
||||
samples = int(SAMPLE_RATE * duration_seconds)
|
||||
t = np.linspace(0, duration_seconds, samples, endpoint=False)
|
||||
signal = 0.5 * (np.sin(2 * np.pi * low_freq * t) + np.sin(2 * np.pi * high_freq * t))
|
||||
signal = (signal * 16383).astype(np.int16)
|
||||
return signal.tobytes()
|
||||
|
||||
|
||||
def generate_noise(duration_seconds: float = 1.0, amplitude: float = 0.3) -> bytes:
|
||||
"""Generate white noise."""
|
||||
samples = int(SAMPLE_RATE * duration_seconds)
|
||||
noise = np.random.normal(0, amplitude * 32767, samples).astype(np.int16)
|
||||
return noise.tobytes()
|
||||
|
||||
|
||||
def generate_speech_like(duration_seconds: float = 1.0) -> bytes:
|
||||
"""
|
||||
Generate a rough approximation of speech.
|
||||
Mix of formant-like frequencies with amplitude modulation.
|
||||
"""
|
||||
samples = int(SAMPLE_RATE * duration_seconds)
|
||||
t = np.linspace(0, duration_seconds, samples, endpoint=False)
|
||||
|
||||
# Fundamental frequency (pitch) with vibrato
|
||||
f0 = 150 + 10 * np.sin(2 * np.pi * 5 * t)
|
||||
fundamental = np.sin(2 * np.pi * f0 * t)
|
||||
|
||||
# Formants (vowel-like)
|
||||
f1 = np.sin(2 * np.pi * 730 * t) * 0.5
|
||||
f2 = np.sin(2 * np.pi * 1090 * t) * 0.3
|
||||
f3 = np.sin(2 * np.pi * 2440 * t) * 0.1
|
||||
|
||||
# Amplitude modulation (syllable-like rhythm)
|
||||
envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
|
||||
|
||||
signal = envelope * (fundamental + f1 + f2 + f3)
|
||||
signal = (signal * 8000).astype(np.int16)
|
||||
return signal.tobytes()
|
||||
|
||||
|
||||
class TestSilenceDetection:
|
||||
"""Test silence classification."""
|
||||
|
||||
def test_pure_silence(self, classifier):
|
||||
result = classifier.classify_chunk(generate_silence())
|
||||
assert result.audio_type == AudioClassification.SILENCE
|
||||
assert result.confidence > 0.5
|
||||
|
||||
def test_very_quiet(self, classifier):
|
||||
# Near-silent audio
|
||||
quiet = generate_tone(440, amplitude=0.001)
|
||||
result = classifier.classify_chunk(quiet)
|
||||
assert result.audio_type == AudioClassification.SILENCE
|
||||
|
||||
def test_empty_audio(self, classifier):
|
||||
result = classifier.classify_chunk(b"")
|
||||
assert result.audio_type == AudioClassification.SILENCE
|
||||
|
||||
|
||||
class TestToneDetection:
|
||||
"""Test tonal audio classification."""
|
||||
|
||||
def test_440hz_ringback(self, classifier):
|
||||
"""440Hz is North American ring-back tone frequency."""
|
||||
tone = generate_tone(440, amplitude=0.3)
|
||||
result = classifier.classify_chunk(tone)
|
||||
# Should be detected as ringing (440Hz is in the ring-back range)
|
||||
assert result.audio_type in (
|
||||
AudioClassification.RINGING,
|
||||
AudioClassification.MUSIC,
|
||||
)
|
||||
assert result.confidence > 0.5
|
||||
|
||||
def test_1000hz_tone(self, classifier):
|
||||
"""1000Hz tone — not ring-back, should be music or unknown."""
|
||||
tone = generate_tone(1000, amplitude=0.3)
|
||||
result = classifier.classify_chunk(tone)
|
||||
assert result.audio_type != AudioClassification.SILENCE
|
||||
|
||||
|
||||
class TestDTMFDetection:
|
||||
"""Test DTMF tone detection."""
|
||||
|
||||
def test_dtmf_digit_5(self, classifier):
|
||||
dtmf = generate_dtmf("5", duration_seconds=0.5)
|
||||
result = classifier.classify_chunk(dtmf)
|
||||
# DTMF detection should catch this
|
||||
if result.audio_type == AudioClassification.DTMF:
|
||||
assert result.details.get("dtmf_digit") == "5"
|
||||
|
||||
def test_dtmf_digit_0(self, classifier):
|
||||
dtmf = generate_dtmf("0", duration_seconds=0.5)
|
||||
result = classifier.classify_chunk(dtmf)
|
||||
if result.audio_type == AudioClassification.DTMF:
|
||||
assert result.details.get("dtmf_digit") == "0"
|
||||
|
||||
|
||||
class TestMusicDetection:
|
||||
"""Test hold music detection."""
|
||||
|
||||
def test_complex_tone_as_music(self, classifier):
|
||||
"""Multiple frequencies together = more music-like."""
|
||||
samples = int(SAMPLE_RATE * 2)
|
||||
t = np.linspace(0, 2, samples, endpoint=False)
|
||||
|
||||
# Chord: C major (C4 + E4 + G4)
|
||||
signal = (
|
||||
np.sin(2 * np.pi * 261.6 * t)
|
||||
+ np.sin(2 * np.pi * 329.6 * t) * 0.8
|
||||
+ np.sin(2 * np.pi * 392.0 * t) * 0.6
|
||||
)
|
||||
signal = (signal * 6000).astype(np.int16)
|
||||
|
||||
result = classifier.classify_chunk(signal.tobytes())
|
||||
assert result.audio_type in (
|
||||
AudioClassification.MUSIC,
|
||||
AudioClassification.RINGING,
|
||||
AudioClassification.UNKNOWN,
|
||||
)
|
||||
assert result.confidence > 0.3
|
||||
|
||||
|
||||
class TestSpeechDetection:
|
||||
"""Test speech-like audio classification."""
|
||||
|
||||
def test_speech_like_audio(self, classifier):
|
||||
speech = generate_speech_like(2.0)
|
||||
result = classifier.classify_chunk(speech)
|
||||
assert result.audio_type in (
|
||||
AudioClassification.IVR_PROMPT,
|
||||
AudioClassification.LIVE_HUMAN,
|
||||
AudioClassification.MUSIC, # Speech-like can be ambiguous
|
||||
AudioClassification.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
class TestClassificationHistory:
|
||||
"""Test history-based transition detection."""
|
||||
|
||||
def test_hold_to_human_transition(self, classifier):
|
||||
"""Detect the music → speech transition."""
|
||||
# Simulate being on hold
|
||||
for _ in range(10):
|
||||
classifier.update_history(AudioClassification.MUSIC)
|
||||
|
||||
# Now speech appears
|
||||
classifier.update_history(AudioClassification.LIVE_HUMAN)
|
||||
classifier.update_history(AudioClassification.LIVE_HUMAN)
|
||||
classifier.update_history(AudioClassification.LIVE_HUMAN)
|
||||
|
||||
assert classifier.detect_hold_to_human_transition()
|
||||
|
||||
def test_no_transition_during_ivr(self, classifier):
|
||||
"""IVR prompt after silence is not a hold→human transition."""
|
||||
for _ in range(5):
|
||||
classifier.update_history(AudioClassification.SILENCE)
|
||||
|
||||
classifier.update_history(AudioClassification.IVR_PROMPT)
|
||||
classifier.update_history(AudioClassification.IVR_PROMPT)
|
||||
classifier.update_history(AudioClassification.IVR_PROMPT)
|
||||
|
||||
# No music in history, so no hold→human transition
|
||||
assert not classifier.detect_hold_to_human_transition()
|
||||
|
||||
def test_not_enough_history(self, classifier):
|
||||
"""Not enough data to detect transition."""
|
||||
classifier.update_history(AudioClassification.MUSIC)
|
||||
classifier.update_history(AudioClassification.LIVE_HUMAN)
|
||||
assert not classifier.detect_hold_to_human_transition()
|
||||
|
||||
|
||||
class TestFeatureExtraction:
|
||||
"""Test individual feature extractors."""
|
||||
|
||||
def test_rms_silence(self, classifier):
|
||||
samples = np.zeros(1000, dtype=np.float32)
|
||||
rms = classifier._compute_rms(samples)
|
||||
assert rms == 0.0
|
||||
|
||||
def test_rms_loud(self, classifier):
|
||||
samples = np.ones(1000, dtype=np.float32) * 0.5
|
||||
rms = classifier._compute_rms(samples)
|
||||
assert rms == pytest.approx(0.5, abs=0.01)
|
||||
|
||||
def test_zcr_silence(self, classifier):
|
||||
samples = np.zeros(1000, dtype=np.float32)
|
||||
zcr = classifier._compute_zero_crossing_rate(samples)
|
||||
assert zcr == 0.0
|
||||
|
||||
def test_zcr_high_freq(self, classifier):
|
||||
"""High frequency signal should have high ZCR."""
|
||||
t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
|
||||
samples = np.sin(2 * np.pi * 4000 * t).astype(np.float32)
|
||||
zcr = classifier._compute_zero_crossing_rate(samples)
|
||||
assert zcr > 0.1
|
||||
|
||||
def test_spectral_flatness_tone(self, classifier):
|
||||
"""Pure tone should have low spectral flatness."""
|
||||
t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
|
||||
samples = np.sin(2 * np.pi * 440 * t).astype(np.float32)
|
||||
flatness = classifier._compute_spectral_flatness(samples)
|
||||
assert flatness < 0.3
|
||||
|
||||
def test_dominant_frequency(self, classifier):
|
||||
"""Should find the dominant frequency of a pure tone."""
|
||||
t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
|
||||
samples = np.sin(2 * np.pi * 1000 * t).astype(np.float32)
|
||||
freq = classifier._compute_dominant_frequency(samples)
|
||||
assert abs(freq - 1000) < 50 # Within 50Hz
|
||||
173
tests/test_call_flows.py
Normal file
173
tests/test_call_flows.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Tests for call flow models and serialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from models.call_flow import ActionType, CallFlow, CallFlowCreate, CallFlowStep, CallFlowSummary
|
||||
|
||||
|
||||
class TestCallFlowStep:
|
||||
"""Test CallFlowStep model."""
|
||||
|
||||
def test_basic_dtmf_step(self):
|
||||
step = CallFlowStep(
|
||||
id="press_1",
|
||||
description="Press 1 for English",
|
||||
action=ActionType.DTMF,
|
||||
action_value="1",
|
||||
expect="for english|para español",
|
||||
next_step="main_menu",
|
||||
)
|
||||
assert step.id == "press_1"
|
||||
assert step.action == ActionType.DTMF
|
||||
assert step.action_value == "1"
|
||||
assert step.timeout == 30 # default
|
||||
|
||||
def test_hold_step(self):
|
||||
step = CallFlowStep(
|
||||
id="hold_queue",
|
||||
description="On hold waiting for agent",
|
||||
action=ActionType.HOLD,
|
||||
timeout=7200,
|
||||
next_step="agent_connected",
|
||||
notes="Average hold: 25-45 min. Plays Vivaldi. Kill me.",
|
||||
)
|
||||
assert step.action == ActionType.HOLD
|
||||
assert step.timeout == 7200
|
||||
assert "Vivaldi" in step.notes
|
||||
|
||||
def test_transfer_step(self):
|
||||
step = CallFlowStep(
|
||||
id="connected",
|
||||
description="Agent picked up!",
|
||||
action=ActionType.TRANSFER,
|
||||
action_value="sip_phone",
|
||||
)
|
||||
assert step.action == ActionType.TRANSFER
|
||||
|
||||
|
||||
class TestCallFlow:
|
||||
"""Test CallFlow model."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_flow(self):
|
||||
return CallFlow(
|
||||
id="test-bank",
|
||||
name="Test Bank - Main Line",
|
||||
phone_number="+18005551234",
|
||||
description="Test bank IVR",
|
||||
steps=[
|
||||
CallFlowStep(
|
||||
id="greeting",
|
||||
description="Language selection",
|
||||
action=ActionType.DTMF,
|
||||
action_value="1",
|
||||
expect="for english",
|
||||
next_step="main_menu",
|
||||
),
|
||||
CallFlowStep(
|
||||
id="main_menu",
|
||||
description="Main menu",
|
||||
action=ActionType.LISTEN,
|
||||
next_step="agent_request",
|
||||
fallback_step="agent_request",
|
||||
),
|
||||
CallFlowStep(
|
||||
id="agent_request",
|
||||
description="Request agent",
|
||||
action=ActionType.DTMF,
|
||||
action_value="0",
|
||||
next_step="hold_queue",
|
||||
),
|
||||
CallFlowStep(
|
||||
id="hold_queue",
|
||||
description="Hold queue",
|
||||
action=ActionType.HOLD,
|
||||
timeout=3600,
|
||||
next_step="agent_connected",
|
||||
),
|
||||
CallFlowStep(
|
||||
id="agent_connected",
|
||||
description="Agent connected",
|
||||
action=ActionType.TRANSFER,
|
||||
action_value="sip_phone",
|
||||
),
|
||||
],
|
||||
tags=["bank", "personal"],
|
||||
avg_hold_time=2100,
|
||||
success_rate=0.92,
|
||||
)
|
||||
|
||||
def test_step_count(self, sample_flow):
|
||||
assert len(sample_flow.steps) == 5
|
||||
|
||||
def test_get_step(self, sample_flow):
|
||||
step = sample_flow.get_step("hold_queue")
|
||||
assert step is not None
|
||||
assert step.action == ActionType.HOLD
|
||||
assert step.timeout == 3600
|
||||
|
||||
def test_get_step_not_found(self, sample_flow):
|
||||
assert sample_flow.get_step("nonexistent") is None
|
||||
|
||||
def test_first_step(self, sample_flow):
|
||||
first = sample_flow.first_step()
|
||||
assert first is not None
|
||||
assert first.id == "greeting"
|
||||
|
||||
def test_steps_by_id(self, sample_flow):
|
||||
steps = sample_flow.steps_by_id()
|
||||
assert len(steps) == 5
|
||||
assert "greeting" in steps
|
||||
assert "agent_connected" in steps
|
||||
assert steps["agent_connected"].action == ActionType.TRANSFER
|
||||
|
||||
def test_serialization_roundtrip(self, sample_flow):
|
||||
"""Test JSON serialization and deserialization."""
|
||||
json_str = sample_flow.model_dump_json()
|
||||
restored = CallFlow.model_validate_json(json_str)
|
||||
assert restored.id == sample_flow.id
|
||||
assert len(restored.steps) == len(sample_flow.steps)
|
||||
assert restored.steps[0].id == "greeting"
|
||||
assert restored.avg_hold_time == 2100
|
||||
|
||||
|
||||
class TestCallFlowCreate:
|
||||
"""Test call flow creation model."""
|
||||
|
||||
def test_minimal_create(self):
|
||||
create = CallFlowCreate(
|
||||
name="My Bank",
|
||||
phone_number="+18005551234",
|
||||
steps=[
|
||||
CallFlowStep(
|
||||
id="start",
|
||||
description="Start",
|
||||
action=ActionType.HOLD,
|
||||
next_step="end",
|
||||
),
|
||||
],
|
||||
)
|
||||
assert create.name == "My Bank"
|
||||
assert len(create.steps) == 1
|
||||
assert create.tags == []
|
||||
assert create.notes is None
|
||||
|
||||
|
||||
class TestCallFlowSummary:
|
||||
"""Test lightweight summary model."""
|
||||
|
||||
def test_summary(self):
|
||||
summary = CallFlowSummary(
|
||||
id="chase-bank-main",
|
||||
name="Chase Bank - Main",
|
||||
phone_number="+18005551234",
|
||||
step_count=6,
|
||||
avg_hold_time=2100,
|
||||
success_rate=0.92,
|
||||
times_used=15,
|
||||
tags=["bank"],
|
||||
)
|
||||
assert summary.step_count == 6
|
||||
assert summary.success_rate == 0.92
|
||||
265
tests/test_hold_slayer.py
Normal file
265
tests/test_hold_slayer.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Tests for the Hold Slayer service.
|
||||
|
||||
Uses MockSIPEngine to test the state machine without real SIP.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from config import Settings
|
||||
from core.call_manager import CallManager
|
||||
from core.event_bus import EventBus
|
||||
from core.sip_engine import MockSIPEngine
|
||||
from models.call import ActiveCall, AudioClassification, CallMode, CallStatus
|
||||
from models.call_flow import ActionType, CallFlow, CallFlowStep
|
||||
from services.hold_slayer import HoldSlayerService
|
||||
|
||||
|
||||
class TestMenuNavigation:
|
||||
"""Test the IVR menu navigation logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def hold_slayer(self):
|
||||
"""Create a HoldSlayerService with mock dependencies."""
|
||||
from config import ClassifierSettings, SpeachesSettings
|
||||
from services.audio_classifier import AudioClassifier
|
||||
from services.transcription import TranscriptionService
|
||||
|
||||
settings = Settings()
|
||||
event_bus = EventBus()
|
||||
call_manager = CallManager(event_bus)
|
||||
sip_engine = MockSIPEngine()
|
||||
classifier = AudioClassifier(ClassifierSettings())
|
||||
transcription = TranscriptionService(SpeachesSettings())
|
||||
|
||||
return HoldSlayerService(
|
||||
gateway=None, # Not needed for menu tests
|
||||
call_manager=call_manager,
|
||||
sip_engine=sip_engine,
|
||||
classifier=classifier,
|
||||
transcription=transcription,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
def test_decide_cancel_card(self, hold_slayer):
|
||||
"""Should match 'cancel' intent to card cancellation option."""
|
||||
transcript = (
|
||||
"Press 1 for account balance, press 2 for recent transactions, "
|
||||
"press 3 to report a lost or stolen card, press 4 to cancel your card, "
|
||||
"press 0 to speak with a representative."
|
||||
)
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "cancel my credit card", None
|
||||
)
|
||||
assert result == "4"
|
||||
|
||||
def test_decide_dispute_charge(self, hold_slayer):
|
||||
"""Should match 'dispute' intent to billing option."""
|
||||
transcript = (
|
||||
"Press 1 for account balance, press 2 for billing and disputes, "
|
||||
"press 3 for payments, press 0 for agent."
|
||||
)
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "dispute a charge on my statement", None
|
||||
)
|
||||
assert result == "2"
|
||||
|
||||
def test_decide_agent_fallback(self, hold_slayer):
|
||||
"""Should fall back to agent option when no match."""
|
||||
transcript = (
|
||||
"Press 1 for mortgage, press 2 for auto loans, "
|
||||
"press 3 for investments, press 0 to speak with a representative."
|
||||
)
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "cancel my credit card", None
|
||||
)
|
||||
# Should choose representative since no direct match
|
||||
assert result == "0"
|
||||
|
||||
def test_decide_no_options_found(self, hold_slayer):
|
||||
"""Return None when transcript has no recognizable menu."""
|
||||
transcript = "Please hold while we transfer your call."
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "cancel my card", None
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_decide_alternate_pattern(self, hold_slayer):
|
||||
"""Handle 'for X, press N' pattern."""
|
||||
transcript = (
|
||||
"For account balance, press 1. For billing inquiries, press 2. "
|
||||
"For card cancellation, press 3."
|
||||
)
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "cancel my card", None
|
||||
)
|
||||
# Should match card cancellation
|
||||
assert result == "3"
|
||||
|
||||
def test_decide_fraud_intent(self, hold_slayer):
|
||||
"""Match fraud-related intent."""
|
||||
transcript = (
|
||||
"Press 1 for balance, press 2 for payments, "
|
||||
"press 3 to report fraud or unauthorized transactions, "
|
||||
"press 0 for an agent."
|
||||
)
|
||||
result = hold_slayer._decide_menu_option(
|
||||
transcript, "report unauthorized charge on my card", None
|
||||
)
|
||||
assert result == "3"
|
||||
|
||||
|
||||
class TestEventBus:
|
||||
"""Test the event bus pub/sub system."""
|
||||
|
||||
@pytest.fixture
|
||||
def event_bus(self):
|
||||
return EventBus()
|
||||
|
||||
def test_subscribe(self, event_bus):
|
||||
sub = event_bus.subscribe()
|
||||
assert event_bus.subscriber_count == 1
|
||||
sub.close()
|
||||
assert event_bus.subscriber_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_receive(self, event_bus):
|
||||
from models.events import EventType, GatewayEvent
|
||||
|
||||
sub = event_bus.subscribe()
|
||||
|
||||
event = GatewayEvent(
|
||||
type=EventType.CALL_INITIATED,
|
||||
call_id="test_123",
|
||||
message="Test event",
|
||||
)
|
||||
await event_bus.publish(event)
|
||||
|
||||
received = await asyncio.wait_for(sub.__anext__(), timeout=1.0)
|
||||
assert received.type == EventType.CALL_INITIATED
|
||||
assert received.call_id == "test_123"
|
||||
sub.close()
|
||||
|
||||
def test_history(self, event_bus):
|
||||
assert len(event_bus.recent_events) == 0
|
||||
|
||||
|
||||
class TestCallManager:
|
||||
"""Test call manager state tracking."""
|
||||
|
||||
@pytest.fixture
|
||||
def call_manager(self):
|
||||
event_bus = EventBus()
|
||||
return CallManager(event_bus)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_call(self, call_manager):
|
||||
call = await call_manager.create_call(
|
||||
remote_number="+18005551234",
|
||||
mode=CallMode.HOLD_SLAYER,
|
||||
intent="cancel my card",
|
||||
)
|
||||
assert call.id.startswith("call_")
|
||||
assert call.remote_number == "+18005551234"
|
||||
assert call.mode == CallMode.HOLD_SLAYER
|
||||
assert call.intent == "cancel my card"
|
||||
assert call.status == CallStatus.INITIATING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status(self, call_manager):
|
||||
call = await call_manager.create_call(
|
||||
remote_number="+18005551234",
|
||||
mode=CallMode.DIRECT,
|
||||
)
|
||||
await call_manager.update_status(call.id, CallStatus.RINGING)
|
||||
|
||||
updated = call_manager.get_call(call.id)
|
||||
assert updated.status == CallStatus.RINGING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_call(self, call_manager):
|
||||
call = await call_manager.create_call(
|
||||
remote_number="+18005551234",
|
||||
mode=CallMode.DIRECT,
|
||||
)
|
||||
ended = await call_manager.end_call(call.id)
|
||||
assert ended is not None
|
||||
assert ended.status == CallStatus.COMPLETED
|
||||
assert call_manager.get_call(call.id) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_call_count(self, call_manager):
|
||||
assert call_manager.active_call_count == 0
|
||||
await call_manager.create_call("+18005551234", CallMode.DIRECT)
|
||||
assert call_manager.active_call_count == 1
|
||||
await call_manager.create_call("+18005559999", CallMode.HOLD_SLAYER)
|
||||
assert call_manager.active_call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_transcript(self, call_manager):
|
||||
call = await call_manager.create_call("+18005551234", CallMode.HOLD_SLAYER)
|
||||
await call_manager.add_transcript(call.id, "Press 1 for English")
|
||||
await call_manager.add_transcript(call.id, "Press 2 for French")
|
||||
|
||||
updated = call_manager.get_call(call.id)
|
||||
assert "Press 1 for English" in updated.transcript
|
||||
assert "Press 2 for French" in updated.transcript
|
||||
|
||||
|
||||
class TestMockSIPEngine:
|
||||
"""Test the mock SIP engine."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
return MockSIPEngine()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle(self, engine):
|
||||
assert not await engine.is_ready()
|
||||
await engine.start()
|
||||
assert await engine.is_ready()
|
||||
await engine.stop()
|
||||
assert not await engine.is_ready()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_call(self, engine):
|
||||
await engine.start()
|
||||
leg_id = await engine.make_call("+18005551234")
|
||||
assert leg_id.startswith("mock_leg_")
|
||||
assert leg_id in engine._active_legs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hangup(self, engine):
|
||||
await engine.start()
|
||||
leg_id = await engine.make_call("+18005551234")
|
||||
await engine.hangup(leg_id)
|
||||
assert leg_id not in engine._active_legs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dtmf(self, engine):
|
||||
await engine.start()
|
||||
leg_id = await engine.make_call("+18005551234")
|
||||
await engine.send_dtmf(leg_id, "1")
|
||||
await engine.send_dtmf(leg_id, "0")
|
||||
assert engine._active_legs[leg_id]["dtmf_sent"] == ["1", "0"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge(self, engine):
|
||||
await engine.start()
|
||||
leg_a = await engine.make_call("+18005551234")
|
||||
leg_b = await engine.make_call("+18005559999")
|
||||
bridge_id = await engine.bridge_calls(leg_a, leg_b)
|
||||
assert bridge_id in engine._bridges
|
||||
await engine.unbridge(bridge_id)
|
||||
assert bridge_id not in engine._bridges
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trunk_status(self, engine):
|
||||
status = await engine.get_trunk_status()
|
||||
assert status["registered"] is False
|
||||
|
||||
await engine.start()
|
||||
status = await engine.get_trunk_status()
|
||||
assert status["registered"] is True
|
||||
557
tests/test_services.py
Normal file
557
tests/test_services.py
Normal file
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
Tests for the intelligence layer services:
|
||||
- LLMClient
|
||||
- NotificationService
|
||||
- RecordingService
|
||||
- CallAnalytics
|
||||
- CallFlowLearner
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from config import Settings
|
||||
from core.event_bus import EventBus
|
||||
from models.events import EventType, GatewayEvent
|
||||
|
||||
|
||||
# ============================================================
|
||||
# LLM Client Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestLLMClient:
|
||||
"""Test the LLM client with mocked HTTP responses."""
|
||||
|
||||
def _make_client(self):
|
||||
from services.llm_client import LLMClient
|
||||
|
||||
return LLMClient(
|
||||
base_url="http://localhost:11434/v1",
|
||||
model="llama3",
|
||||
api_key="not-needed",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init(self):
|
||||
client = self._make_client()
|
||||
assert client.model == "llama3"
|
||||
assert client._total_requests == 0
|
||||
assert client._total_errors == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats(self):
|
||||
client = self._make_client()
|
||||
stats = client.stats
|
||||
assert stats["total_requests"] == 0
|
||||
assert stats["total_errors"] == 0
|
||||
assert stats["model"] == "llama3"
|
||||
assert stats["avg_latency_ms"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_request_format(self):
|
||||
"""Verify the HTTP request is formatted correctly."""
|
||||
client = self._make_client()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "Hello!"}}],
|
||||
"usage": {"total_tokens": 10},
|
||||
}
|
||||
|
||||
with patch.object(client._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await client.chat("Say hello", system="Hi")
|
||||
assert result == "Hello!"
|
||||
assert client._total_requests == 1
|
||||
|
||||
# Verify the request body
|
||||
call_args = mock_post.call_args
|
||||
body = call_args[1]["json"]
|
||||
assert body["model"] == "llama3"
|
||||
assert len(body["messages"]) == 2
|
||||
assert body["messages"][0]["role"] == "system"
|
||||
assert body["messages"][1]["role"] == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_json_parsing(self):
|
||||
"""Verify JSON response parsing works."""
|
||||
client = self._make_client()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": '{"action": "press_1", "confidence": 0.9}'}}],
|
||||
"usage": {"total_tokens": 20},
|
||||
}
|
||||
|
||||
with patch.object(client._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await client.chat_json("Analyze menu", system="Press 1 for billing")
|
||||
assert result is not None
|
||||
assert result["action"] == "press_1"
|
||||
assert result["confidence"] == 0.9
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_json_markdown_extraction(self):
|
||||
"""Verify JSON extraction from markdown code blocks."""
|
||||
client = self._make_client()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": 'Here is the result:\n```json\n{"key": "value"}\n```'
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"total_tokens": 15},
|
||||
}
|
||||
|
||||
with patch.object(client._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await client.chat_json("Parse this", system="test")
|
||||
assert result is not None
|
||||
assert result["key"] == "value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_http_error_returns_empty(self):
|
||||
"""Verify HTTP errors return empty string gracefully."""
|
||||
client = self._make_client()
|
||||
|
||||
with patch.object(client._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.side_effect = Exception("Connection refused")
|
||||
result = await client.chat("test", system="test")
|
||||
assert result == ""
|
||||
assert client._total_errors == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_ivr_menu(self):
|
||||
"""Verify IVR menu analysis formats correctly."""
|
||||
client = self._make_client()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": '{"action": "press_2", "digit": "2", "confidence": 0.85, "reason": "Option 2 is billing"}'
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"total_tokens": 30},
|
||||
}
|
||||
|
||||
with patch.object(client._client, "post", new_callable=AsyncMock) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
result = await client.analyze_ivr_menu(
|
||||
transcript="Press 1 for sales, press 2 for billing",
|
||||
intent="dispute a charge",
|
||||
previous_selections=["1"],
|
||||
)
|
||||
assert result is not None
|
||||
assert result["digit"] == "2"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Notification Service Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestNotificationService:
|
||||
"""Test notification routing and deduplication."""
|
||||
|
||||
def _make_service(self):
|
||||
from services.notification import NotificationService
|
||||
|
||||
event_bus = EventBus()
|
||||
settings = Settings()
|
||||
svc = NotificationService(event_bus, settings)
|
||||
return svc, event_bus
|
||||
|
||||
def test_init(self):
|
||||
svc, _ = self._make_service()
|
||||
assert svc._notified == {}
|
||||
|
||||
def test_event_to_notification_human_detected(self):
|
||||
from services.notification import NotificationPriority
|
||||
|
||||
svc, _ = self._make_service()
|
||||
event = GatewayEvent(
|
||||
type=EventType.HUMAN_DETECTED,
|
||||
call_id="call_123",
|
||||
data={"confidence": 0.95},
|
||||
message="Human detected!",
|
||||
)
|
||||
notification = svc._event_to_notification(event)
|
||||
assert notification is not None
|
||||
assert notification.priority == NotificationPriority.CRITICAL
|
||||
assert "Human" in notification.title
|
||||
|
||||
def test_event_to_notification_hold_detected(self):
|
||||
from services.notification import NotificationPriority
|
||||
|
||||
svc, _ = self._make_service()
|
||||
event = GatewayEvent(
|
||||
type=EventType.HOLD_DETECTED,
|
||||
call_id="call_123",
|
||||
data={},
|
||||
message="On hold",
|
||||
)
|
||||
notification = svc._event_to_notification(event)
|
||||
assert notification is not None
|
||||
assert notification.priority == NotificationPriority.NORMAL
|
||||
|
||||
def test_event_to_notification_skip_transcript(self):
|
||||
svc, _ = self._make_service()
|
||||
event = GatewayEvent(
|
||||
type=EventType.TRANSCRIPT_CHUNK,
|
||||
call_id="call_123",
|
||||
data={"text": "hello"},
|
||||
)
|
||||
notification = svc._event_to_notification(event)
|
||||
assert notification is None # Transcripts don't generate notifications
|
||||
|
||||
def test_event_to_notification_call_ended_cleanup(self):
|
||||
svc, _ = self._make_service()
|
||||
# Simulate some tracking data
|
||||
svc._notified["call_123"] = {"some_event"}
|
||||
|
||||
event = GatewayEvent(
|
||||
type=EventType.CALL_ENDED,
|
||||
call_id="call_123",
|
||||
data={},
|
||||
)
|
||||
notification = svc._event_to_notification(event)
|
||||
assert notification is not None
|
||||
assert "call_123" not in svc._notified # Cleaned up
|
||||
|
||||
def test_event_to_notification_call_failed(self):
|
||||
from services.notification import NotificationPriority
|
||||
|
||||
svc, _ = self._make_service()
|
||||
event = GatewayEvent(
|
||||
type=EventType.CALL_FAILED,
|
||||
call_id="call_123",
|
||||
data={},
|
||||
message="Connection timed out",
|
||||
)
|
||||
notification = svc._event_to_notification(event)
|
||||
assert notification is not None
|
||||
assert notification.priority == NotificationPriority.HIGH
|
||||
assert "Connection timed out" in notification.message
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Recording Service Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestRecordingService:
|
||||
"""Test recording lifecycle."""
|
||||
|
||||
def _make_service(self):
|
||||
from services.recording import RecordingService
|
||||
|
||||
return RecordingService(storage_dir="/tmp/test_recordings")
|
||||
|
||||
def test_init(self):
|
||||
svc = self._make_service()
|
||||
assert svc._active_recordings == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recording_path_generation(self):
|
||||
"""Verify recording paths are organized by date."""
|
||||
svc = self._make_service()
|
||||
await svc.start() # Creates storage dir
|
||||
|
||||
session = await svc.start_recording(call_id="call_abc123")
|
||||
assert "call_abc123" in session.filepath_mixed
|
||||
# Should include date-based directory
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
assert today in session.filepath_mixed
|
||||
|
||||
# Clean up
|
||||
await svc.stop_recording("call_abc123")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Call Analytics Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestCallAnalytics:
|
||||
"""Test analytics tracking."""
|
||||
|
||||
def _make_service(self):
|
||||
from services.call_analytics import CallAnalytics
|
||||
|
||||
return CallAnalytics(max_history=1000)
|
||||
|
||||
def test_init(self):
|
||||
svc = self._make_service()
|
||||
assert svc._call_records == []
|
||||
assert svc.total_calls_recorded == 0
|
||||
|
||||
def test_get_summary_empty(self):
|
||||
svc = self._make_service()
|
||||
summary = svc.get_summary(hours=24)
|
||||
assert summary["total_calls"] == 0
|
||||
assert summary["success_rate"] == 0.0
|
||||
|
||||
def test_get_company_stats_unknown(self):
|
||||
svc = self._make_service()
|
||||
stats = svc.get_company_stats("+18005551234")
|
||||
assert stats["total_calls"] == 0
|
||||
|
||||
def test_get_top_numbers_empty(self):
|
||||
svc = self._make_service()
|
||||
top = svc.get_top_numbers(limit=5)
|
||||
assert top == []
|
||||
|
||||
def test_get_hold_time_trend(self):
|
||||
svc = self._make_service()
|
||||
trend = svc.get_hold_time_trend(days=7)
|
||||
assert len(trend) == 7
|
||||
assert all(t["call_count"] == 0 for t in trend)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Call Flow Learner Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestCallFlowLearner:
|
||||
"""Test call flow learning from exploration data."""
|
||||
|
||||
def _make_learner(self):
|
||||
from services.call_flow_learner import CallFlowLearner
|
||||
|
||||
return CallFlowLearner(llm_client=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_flow_from_discoveries(self):
|
||||
"""Test building a call flow from exploration discoveries."""
|
||||
learner = self._make_learner()
|
||||
|
||||
discoveries = [
|
||||
{
|
||||
"audio_type": "ivr_prompt",
|
||||
"transcript": "Press 1 for billing, press 2 for sales",
|
||||
"action_taken": {"dtmf": "1"},
|
||||
},
|
||||
{
|
||||
"audio_type": "ivr_prompt",
|
||||
"transcript": "Press 3 to speak to an agent",
|
||||
"action_taken": {"dtmf": "3"},
|
||||
},
|
||||
{
|
||||
"audio_type": "music",
|
||||
"transcript": "",
|
||||
"action_taken": None,
|
||||
},
|
||||
{
|
||||
"audio_type": "live_human",
|
||||
"transcript": "Hi, thanks for calling. How can I help?",
|
||||
"action_taken": None,
|
||||
},
|
||||
]
|
||||
|
||||
flow = await learner.build_flow(
|
||||
phone_number="+18005551234",
|
||||
discovered_steps=discoveries,
|
||||
intent="cancel my card",
|
||||
company_name="Test Bank",
|
||||
)
|
||||
|
||||
assert flow is not None
|
||||
assert flow.phone_number == "+18005551234"
|
||||
assert "Test Bank" in flow.name
|
||||
assert len(flow.steps) == 4 # IVR, IVR, hold, human
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_flow_no_discoveries(self):
|
||||
"""Test that build_flow returns empty flow when no meaningful data."""
|
||||
learner = self._make_learner()
|
||||
flow = await learner.build_flow(
|
||||
phone_number="+18005551234",
|
||||
discovered_steps=[],
|
||||
)
|
||||
assert flow is not None
|
||||
assert len(flow.steps) == 0
|
||||
assert "empty" in [t.lower() for t in flow.tags]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_discoveries(self):
|
||||
"""Test merging new discoveries into existing flow."""
|
||||
learner = self._make_learner()
|
||||
|
||||
# Build initial flow
|
||||
initial_steps = [
|
||||
{
|
||||
"audio_type": "ivr_prompt",
|
||||
"transcript": "Press 1 for billing",
|
||||
"action_taken": {"dtmf": "1"},
|
||||
},
|
||||
{
|
||||
"audio_type": "music",
|
||||
"transcript": "",
|
||||
"action_taken": None,
|
||||
},
|
||||
]
|
||||
flow = await learner.build_flow(
|
||||
phone_number="+18005551234",
|
||||
discovered_steps=initial_steps,
|
||||
intent="billing inquiry",
|
||||
)
|
||||
original_step_count = len(flow.steps)
|
||||
assert original_step_count == 2
|
||||
|
||||
# Merge new discoveries
|
||||
new_steps = [
|
||||
{
|
||||
"audio_type": "ivr_prompt",
|
||||
"transcript": "Press 1 for billing",
|
||||
"action_taken": {"dtmf": "1"},
|
||||
},
|
||||
{
|
||||
"audio_type": "music",
|
||||
"transcript": "",
|
||||
"action_taken": None,
|
||||
},
|
||||
{
|
||||
"audio_type": "live_human",
|
||||
"transcript": "Hello, billing department",
|
||||
"action_taken": None,
|
||||
},
|
||||
]
|
||||
|
||||
merged = await learner.merge_discoveries(
|
||||
existing_flow=flow,
|
||||
new_steps=new_steps,
|
||||
intent="billing inquiry",
|
||||
)
|
||||
|
||||
assert merged is not None
|
||||
assert merged.times_used == 2 # Incremented
|
||||
assert merged.last_used is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discovery_to_step_types(self):
|
||||
"""Test that different audio types produce correct step actions."""
|
||||
from models.call_flow import ActionType
|
||||
|
||||
learner = self._make_learner()
|
||||
|
||||
# IVR prompt with DTMF
|
||||
step = learner._discovery_to_step(
|
||||
{"audio_type": "ivr_prompt", "transcript": "Press 1", "action_taken": {"dtmf": "1"}},
|
||||
0, [],
|
||||
)
|
||||
assert step is not None
|
||||
assert step.action == ActionType.DTMF
|
||||
assert step.action_value == "1"
|
||||
|
||||
# Hold music
|
||||
step = learner._discovery_to_step(
|
||||
{"audio_type": "music", "transcript": "", "action_taken": None},
|
||||
1, [],
|
||||
)
|
||||
assert step is not None
|
||||
assert step.action == ActionType.HOLD
|
||||
|
||||
# Live human
|
||||
step = learner._discovery_to_step(
|
||||
{"audio_type": "live_human", "transcript": "Hello", "action_taken": None},
|
||||
2, [],
|
||||
)
|
||||
assert step is not None
|
||||
assert step.action == ActionType.TRANSFER
|
||||
|
||||
|
||||
# ============================================================
|
||||
# EventBus Integration Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestEventBusIntegration:
|
||||
"""Test EventBus with real async producers/consumers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_subscribers(self):
|
||||
"""Multiple subscribers each get all events."""
|
||||
bus = EventBus()
|
||||
sub1 = bus.subscribe()
|
||||
sub2 = bus.subscribe()
|
||||
|
||||
event = GatewayEvent(
|
||||
type=EventType.CALL_INITIATED,
|
||||
call_id="call_1",
|
||||
data={},
|
||||
)
|
||||
await bus.publish(event)
|
||||
|
||||
e1 = await asyncio.wait_for(sub1.__anext__(), timeout=1.0)
|
||||
e2 = await asyncio.wait_for(sub2.__anext__(), timeout=1.0)
|
||||
|
||||
assert e1.call_id == "call_1"
|
||||
assert e2.call_id == "call_1"
|
||||
assert bus.subscriber_count == 2
|
||||
|
||||
# Unsubscribe using .close() which passes the internal entry tuple
|
||||
sub1.close()
|
||||
sub2.close()
|
||||
assert bus.subscriber_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_history_limit(self):
|
||||
"""Event history respects max size."""
|
||||
bus = EventBus(max_history=5)
|
||||
|
||||
for i in range(10):
|
||||
await bus.publish(
|
||||
GatewayEvent(
|
||||
type=EventType.IVR_STEP,
|
||||
call_id=f"call_{i}",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
|
||||
# recent_events is a property, not a method
|
||||
history = bus.recent_events
|
||||
assert len(history) == 5
|
||||
# Should have the most recent 5
|
||||
assert history[-1].call_id == "call_9"
|
||||
assert history[0].call_id == "call_5"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_type_filtering(self):
|
||||
"""Subscribers can filter by event type."""
|
||||
bus = EventBus()
|
||||
# Only subscribe to hold-related events
|
||||
sub = bus.subscribe(event_types={EventType.HOLD_DETECTED, EventType.HUMAN_DETECTED})
|
||||
|
||||
# Publish multiple event types
|
||||
await bus.publish(GatewayEvent(type=EventType.CALL_INITIATED, call_id="c1", data={}))
|
||||
await bus.publish(GatewayEvent(type=EventType.HOLD_DETECTED, call_id="c1", data={}))
|
||||
await bus.publish(GatewayEvent(type=EventType.IVR_STEP, call_id="c1", data={}))
|
||||
await bus.publish(GatewayEvent(type=EventType.HUMAN_DETECTED, call_id="c1", data={}))
|
||||
|
||||
# Should only receive the 2 matching events
|
||||
e1 = await asyncio.wait_for(sub.__anext__(), timeout=1.0)
|
||||
e2 = await asyncio.wait_for(sub.__anext__(), timeout=1.0)
|
||||
assert e1.type == EventType.HOLD_DETECTED
|
||||
assert e2.type == EventType.HUMAN_DETECTED
|
||||
|
||||
sub.close()
|
||||
Reference in New Issue
Block a user