""" TEI client tests with mocked HTTP. We mock ``requests.Session.request`` so tests do not require network access or a live Athena instance. """ from __future__ import annotations import json from unittest.mock import MagicMock import pytest from core.tei_client import AthenaAPIError, TEIClient def _mock_response(status: int, body=None) -> MagicMock: resp = MagicMock() resp.status_code = status resp.content = b"{}" if body is None else json.dumps(body).encode() resp.json.return_value = body if body is not None else {} resp.text = json.dumps(body or {}) return resp @pytest.fixture def client(monkeypatch) -> TEIClient: c = TEIClient() c.session = MagicMock() return c class TestConfig: def test_requires_base_url(self, monkeypatch): monkeypatch.delenv("ATHENA_BASE_URL", raising=False) with pytest.raises(ValueError, match="ATHENA_BASE_URL"): TEIClient(api_key="x") def test_requires_api_key(self, monkeypatch): monkeypatch.delenv("ATHENA_API_KEY", raising=False) with pytest.raises(ValueError, match="ATHENA_API_KEY"): TEIClient(base_url="https://example.com") def test_authorization_header(self): c = TEIClient(base_url="https://example.com", api_key="abc123") assert c.session.headers["Authorization"] == "Api-Key abc123" class TestPaths: """Verify each endpoint targets the documented URL.""" def _last_call_url(self, client: TEIClient) -> str: return client.session.request.call_args.kwargs["url"] def test_list_reports_path(self, client): client.session.request.return_value = _mock_response( 200, {"results": [], "next": None} ) client.list_reports() assert self._last_call_url(client) == "https://athena.test/api/v1/tei/reports/" def test_get_tool_path(self, client): client.session.request.return_value = _mock_response(200, {"id": "abc"}) client.get_tool("abc123") assert self._last_call_url(client).endswith("/api/v1/tei/tools/abc123/") def test_calculate_path(self, client): client.session.request.return_value = _mock_response(200, {}) client.calculate("abc") assert self._last_call_url(client).endswith("/api/v1/tei/tools/abc/calculate/") assert client.session.request.call_args.kwargs["method"] == "POST" def test_export_path(self, client): client.session.request.return_value = _mock_response(200, {}) client.export("abc") assert self._last_call_url(client).endswith("/api/v1/tei/tools/abc/export/") def test_aggregate_summary_path(self, client): client.session.request.return_value = _mock_response(200, {}) client.aggregate_summary() assert self._last_call_url(client).endswith("/api/v1/tei/summary/") def test_save_version_path(self, client): client.session.request.return_value = _mock_response(201, {"version_number": 1}) client.save_version("abc", note="initial", date="2026-06-10") url = self._last_call_url(client) assert url.endswith("/api/v1/tei/tools/abc/versions/") body = client.session.request.call_args.kwargs["json"] assert body == {"date": "2026-06-10", "note": "initial"} def test_save_version_defaults_date_to_today(self, client): client.session.request.return_value = _mock_response(201, {"version_number": 1}) client.save_version("abc", note="x") body = client.session.request.call_args.kwargs["json"] assert body["note"] == "x" assert len(body["date"]) == 10 # YYYY-MM-DD def test_patch_value_year_in_query(self, client): client.session.request.return_value = _mock_response(200, {}) client.patch_value("abc", "fkey", year=2, value=100) url = self._last_call_url(client) assert url.endswith("/api/v1/tei/tools/abc/values/fkey/?year=2") body = client.session.request.call_args.kwargs["json"] assert body == {"value": "100"} def test_list_clients_path(self, client): client.session.request.return_value = _mock_response( 200, {"results": [], "next": None} ) client.list_clients(search="acme") assert self._last_call_url(client).endswith("/api/v1/orbit/clients/") assert client.session.request.call_args.kwargs["params"] == {"search": "acme"} def test_list_proposals_filters_by_opportunity(self, client): client.session.request.return_value = _mock_response( 200, {"results": [], "next": None} ) client.list_proposals(opportunity_id=42) assert self._last_call_url(client).endswith("/api/v1/orbit/proposals/") assert client.session.request.call_args.kwargs["params"] == { "opportunity_id": 42 } def test_list_engagements_path(self, client): client.session.request.return_value = _mock_response( 200, {"results": [], "next": None} ) client.list_engagements() assert self._last_call_url(client).endswith( "/api/v1/engagement/engagements/" ) def test_create_proposal_body(self, client): client.session.request.return_value = _mock_response(201, {"id": 7}) client.create_proposal("Acme TEI", opportunity_id=42) body = client.session.request.call_args.kwargs["json"] assert body == {"name": "Acme TEI", "opportunity_id": 42, "status": "Draft"} def test_reorder_fields_body(self, client): client.session.request.return_value = _mock_response(200, {}) client.reorder_fields("rep", [7, 3, 9]) body = client.session.request.call_args.kwargs["json"] assert body == { "field_order": [ {"id": 7, "sort_order": 1}, {"id": 3, "sort_order": 2}, {"id": 9, "sort_order": 3}, ] } class TestErrorHandling: def test_404_raises_athena_error(self, client): client.session.request.return_value = _mock_response( 404, {"detail": "Not found"} ) with pytest.raises(AthenaAPIError) as ei: client.get_tool("missing") assert ei.value.status_code == 404 assert "Not found" in ei.value.detail def test_test_connection_returns_error_dict(self, client): client.session.request.return_value = _mock_response( 401, {"detail": "Invalid token"} ) result = client.test_connection() assert result["status"] == "error" assert result["authenticated"] is False assert result["error_code"] == 401 class TestPagination: def test_walks_next_links(self, client): # First page returns one item with a `next` URL; second page returns # one more item and no next. page1 = _mock_response( 200, { "results": [{"id": 1}], "next": "https://athena.test/api/v1/tei/reports/?page=2", }, ) page2 = _mock_response(200, {"results": [{"id": 2}], "next": None}) client.session.request.return_value = page1 client.session.get.return_value = page2 # follow next via session.get out = client.list_reports() assert [r["id"] for r in out] == [1, 2] class TestRowsFromValue: """_rows_from_value expands friendly dicts into documented wire rows.""" def test_year_underscore_keys(self): rows = TEIClient._rows_from_value( {"field_key": "x", "year_1": 100, "year_2": 200, "risk_adjustment": 0.1} ) assert [(r["field_key"], r["year"], r["value"]) for r in rows] == [ ("x", 1, "100.0"), ("x", 2, "200.0"), ] assert all(r["risk_adjustment"] == "0.1" for r in rows) def test_year_values_dict(self): rows = TEIClient._rows_from_value( {"field_key": "x", "year_values": {"1": 50, "3": 75}, "notes": "hi"} ) assert [(r["year"], r["value"]) for r in rows] == [(1, "50.0"), (3, "75.0")] # Notes land on the first year row only. assert rows[0]["notes"] == "hi" assert rows[1]["notes"] is None def test_initial_becomes_companion_row(self): rows = TEIClient._rows_from_value( {"field_key": "x", "initial": 1000, "year_1": 5} ) companion = [r for r in rows if r["field_key"] == "x_initial"] assert len(companion) == 1 assert companion[0]["year"] is None assert companion[0]["value"] == "1000.0" def test_scalar_value(self): rows = TEIClient._rows_from_value({"field_key": "rate", "value": 0.10}) assert rows == [ { "field_key": "rate", "year": None, "value": "0.1", "risk_adjustment": None, "notes": None, } ] class TestUpdateValuesPayload: def test_flat_rows_in_envelope(self, client): client.session.request.return_value = _mock_response(200, {}) client.update_values( "abc", [{"field_key": "x", "year_1": 100}, {"field_key": "y", "year_1": 200}], ) body = client.session.request.call_args.kwargs["json"] assert "values" in body assert len(body["values"]) == 2 # one row per field/year assert body["values"][0] == { "field_key": "x", "year": 1, "value": "100.0", "risk_adjustment": None, "notes": None, } class TestGetValuesFriendlyShape: def test_documented_years_shape_is_flattened(self, client): client.session.request.return_value = _mock_response( 200, { "tool_id": "abc", "values": [ { "field_key": "ben", "table": "benefits", "is_annual": True, "risk_adjustment": "0.15", "years": { "1": {"value": "100.00", "risk_adjustment": None, "notes": ""}, "2": {"value": "200.00", "risk_adjustment": None, "notes": ""}, }, }, { "field_key": "cost", "table": "costs", "is_annual": True, "years": {"1": {"value": "10.00"}}, }, { "field_key": "cost_initial", "table": "costs", "is_annual": False, "value": "500.00", }, ], }, ) rows = client.get_values("abc") by_key = {r["field_key"]: r for r in rows} assert by_key["ben"]["year_values"] == {"1": 100.0, "2": 200.0} assert by_key["ben"]["risk_adjustment"] == 0.15 # companion *_initial folded into parent, not standalone assert "cost_initial" not in by_key assert by_key["cost"]["initial"] == 500.0