""" TEI Client — Athena API wrapper for Palladium. Endpoints (per Athena API.yaml, all under ``/api/v1/tei/``): Reports (templates) GET /reports/ list_reports GET /reports/{public_id}/ get_report GET /reports/{public_id}/fields/ list_fields PATCH /reports/{public_id}/fields/reorder/ reorder_fields Tools (instances) GET /tools/ list_tools POST /tools/ create_tool GET /tools/{public_id}/ get_tool PATCH /tools/{public_id}/ update_tool DELETE /tools/{public_id}/ delete_tool Values (data entry) GET /tools/{public_id}/values/ get_values PUT /tools/{public_id}/values/ update_values (bulk) PATCH /tools/{public_id}/values/{field_key}/ patch_value (single) Calculation & summary POST /tools/{public_id}/calculate/ calculate GET /tools/{public_id}/summary/ get_summary GET /summary/ aggregate_summary Versions GET /tools/{public_id}/versions/ list_versions POST /tools/{public_id}/versions/ save_version GET /tools/{public_id}/versions/{n}/ get_version Export GET /tools/{public_id}/export/ export Authentication uses the ``Authorization: Api-Key {key}`` header. """ from __future__ import annotations import json import logging import os from datetime import datetime from typing import Any import requests from dotenv import load_dotenv load_dotenv() logger = logging.getLogger(__name__) API_PREFIX = "/api/v1/tei" class AthenaAPIError(Exception): """Raised when Athena returns a non-success response.""" def __init__(self, status_code: int, detail: str, url: str): self.status_code = status_code self.detail = detail self.url = url super().__init__(f"Athena API {status_code} at {url}: {detail}") class TEIClient: """ Client for Athena's TEI Calculator API. Wraps every TEI endpoint and provides a few convenience helpers used by the Palladium notebooks and Streamlit app. Environment variables (read via python-dotenv): ATHENA_BASE_URL e.g. https://athena.nttdata.com ATHENA_API_KEY Api-Key value (admin-issued) Example:: from core.tei_client import TEIClient client = TEIClient() client.test_connection() for r in client.list_reports(): print(r["name"]) """ def __init__( self, base_url: str | None = None, api_key: str | None = None, timeout: int = 30, ): self.base_url = (base_url or os.getenv("ATHENA_BASE_URL", "")).rstrip("/") self.api_key = api_key or os.getenv("ATHENA_API_KEY", "") self.timeout = timeout if not self.base_url: raise ValueError( "ATHENA_BASE_URL is required. Set it in .env or pass base_url." ) if not self.api_key: raise ValueError( "ATHENA_API_KEY is required. Set it in .env or pass api_key." ) self.session = requests.Session() self.session.headers.update( { "Authorization": f"Api-Key {self.api_key}", "Content-Type": "application/json", "Accept": "application/json", } ) logger.info("TEIClient initialised for %s", self.base_url) # ───────────────────────────────────────────── # Internal HTTP helpers # ───────────────────────────────────────────── def _url(self, path: str) -> str: if not path.startswith("/"): path = f"/{path}" return f"{self.base_url}{path}" def _request( self, method: str, path: str, params: dict | None = None, json_data: Any | None = None, ) -> Any: url = self._url(path) logger.debug("%s %s", method.upper(), url) try: response = self.session.request( method=method, url=url, params=params, json=json_data, timeout=self.timeout, ) except requests.ConnectionError as e: raise AthenaAPIError(0, f"Connection failed: {e}", url) from e except requests.Timeout as e: raise AthenaAPIError(408, "Request timed out", url) from e if response.status_code >= 400: try: payload = response.json() detail = payload.get("detail") or json.dumps(payload) except (json.JSONDecodeError, ValueError, AttributeError): detail = response.text raise AthenaAPIError(response.status_code, detail, url) if response.status_code == 204 or not response.content: return {} return response.json() def _get(self, path: str, params: dict | None = None) -> Any: return self._request("GET", path, params=params) def _post(self, path: str, data: Any | None = None) -> Any: return self._request("POST", path, json_data=data) def _put(self, path: str, data: Any | None = None) -> Any: return self._request("PUT", path, json_data=data) def _patch(self, path: str, data: Any | None = None) -> Any: return self._request("PATCH", path, json_data=data) def _delete(self, path: str) -> Any: return self._request("DELETE", path) def _paginated(self, path: str, params: dict | None = None) -> list[dict]: """ Fetch all pages of a paginated list endpoint. Athena uses the standard DRF page/results envelope:: {"count": N, "next": url|None, "previous": ..., "results": [...]} """ out: list[dict] = [] result = self._get(path, params=params) while True: if isinstance(result, list): out.extend(result) return out if not isinstance(result, dict): return out out.extend(result.get("results", []) or []) next_url = result.get("next") if not next_url: return out # Follow absolute next URL try: result = self.session.get(next_url, timeout=self.timeout).json() except Exception: # pragma: no cover – defensive return out # ───────────────────────────────────────────── # Connection test # ───────────────────────────────────────────── def test_connection(self) -> dict: """Verify API connectivity and authentication.""" try: result = self._get(f"{API_PREFIX}/reports/") count = ( result.get("count", len(result.get("results", []))) if isinstance(result, dict) else len(result) ) return { "status": "ok", "base_url": self.base_url, "authenticated": True, "reports_found": count, "timestamp": datetime.now().isoformat(), } except AthenaAPIError as e: return { "status": "error", "base_url": self.base_url, "authenticated": e.status_code != 401, "error_code": e.status_code, "detail": e.detail, "timestamp": datetime.now().isoformat(), } # ───────────────────────────────────────────── # Reports (templates) # ───────────────────────────────────────────── def list_reports(self) -> list[dict]: """List all TEI report templates (auto-paginated).""" return self._paginated(f"{API_PREFIX}/reports/") def get_report(self, public_id: str) -> dict: """Get a TEI report template by its public_id.""" return self._get(f"{API_PREFIX}/reports/{public_id}/") def list_fields( self, report_public_id: str, table: str | None = None, ) -> list[dict]: """ Get field definitions for a report. Args: report_public_id: The report template's public_id (12-char short UUID). table: Optional filter — ``'benefits'`` or ``'costs'``. Returns a list of field-definition dicts. See ``TEIField.from_dict`` for the expected shape. """ params = {"table": table} if table else None rows = self._paginated( f"{API_PREFIX}/reports/{report_public_id}/fields/", params=params ) # Defensive — server-side filter may not be implemented; filter locally. if table: rows = [r for r in rows if r.get("table") == table] rows.sort(key=lambda r: (r.get("table", ""), r.get("sort_order") or 0)) return rows def create_field(self, report_public_id: str, field: dict) -> dict: """Create a new field definition under a report (admin only).""" return self._post( f"{API_PREFIX}/reports/{report_public_id}/fields/", data=field ) def update_field(self, report_public_id: str, field_id: int, **changes) -> dict: """Patch one field definition by its integer id.""" return self._patch( f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/", data=changes, ) def delete_field(self, report_public_id: str, field_id: int) -> dict: return self._delete( f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/" ) def reorder_fields(self, report_public_id: str, field_ids: list[int]) -> dict: """Bulk-reorder fields. Spec: PATCH /reports/{id}/fields/reorder/.""" return self._patch( f"{API_PREFIX}/reports/{report_public_id}/fields/reorder/", data={"field_ids": field_ids}, ) # ───────────────────────────────────────────── # Tools (instances) # ───────────────────────────────────────────── def list_tools(self) -> list[dict]: """List TEI tool instances owned by the current API key.""" return self._paginated(f"{API_PREFIX}/tools/") def get_tool(self, public_id: str) -> dict: """Get a TEI tool instance by public_id.""" return self._get(f"{API_PREFIX}/tools/{public_id}/") def create_tool( self, report_public_id: str, proposal: int | None = None, engagement: int | None = None, name: str | None = None, status: str = "draft", ) -> dict: """ Create a new TEI tool instance from a report template. Athena scopes a TEI tool to a *Proposal* (which itself belongs to an Opportunity) and/or an *Engagement*. Pass the integer PK of either or both to link the tool. """ data: dict[str, Any] = {"report": report_public_id, "status": status} if proposal is not None: data["proposal"] = proposal if engagement is not None: data["engagement"] = engagement if name: data["name"] = name return self._post(f"{API_PREFIX}/tools/", data=data) def update_tool( self, public_id: str, name: str | None = None, status: str | None = None, ) -> dict: """Update tool metadata. Only ``name`` and ``status`` are mutable.""" data: dict[str, Any] = {} if name is not None: data["name"] = name if status is not None: data["status"] = status return self._patch(f"{API_PREFIX}/tools/{public_id}/", data=data) def delete_tool(self, public_id: str) -> dict: return self._delete(f"{API_PREFIX}/tools/{public_id}/") # ───────────────────────────────────────────── # Values (data entry) # ───────────────────────────────────────────── @staticmethod def _normalize_value(value: dict) -> dict: """ Normalize a value-row dict into the shape the API expects. Accepts any of the following input forms and produces a uniform wire-format dict:: # annual fields {"field_key": "A1", "year_1": 100, "year_2": 200, "year_3": 300, ...} {"field_key": "A1", "year_values": {"1": 100, "2": 200, "3": 300}, ...} # non-annual scalars {"field_key": "rate", "value": 0.10, ...} Returns a dict like:: {"field_key": "A1", "year_values": {"1": 100.0, "2": 200.0, "3": 300.0}, "risk_adjustment": 0.15, "notes": "…"} """ out: dict[str, Any] = {} if "field_key" in value: out["field_key"] = value["field_key"] elif "field" in value: out["field_key"] = value["field"] # Collect annual year_N keys into year_values year_values: dict[str, float] = {} if "year_values" in value and isinstance(value["year_values"], dict): for k, v in value["year_values"].items(): year_values[str(k)] = float(v) if v is not None else 0.0 for key, raw in value.items(): if key.startswith("year_"): try: n = int(key.split("_", 1)[1]) except ValueError: continue year_values[str(n)] = float(raw) if raw is not None else 0.0 if year_values: out["year_values"] = year_values if "value" in value and value["value"] is not None and not year_values: out["value"] = value["value"] if value.get("initial") is not None: out["initial"] = float(value["initial"]) if value.get("risk_adjustment") is not None: out["risk_adjustment"] = float(value["risk_adjustment"]) if value.get("notes"): out["notes"] = str(value["notes"]) return out def get_values(self, public_id: str) -> list[dict]: """Get all current field values for a TEI tool instance.""" result = self._get(f"{API_PREFIX}/tools/{public_id}/values/") if isinstance(result, dict): # Could be {"values": [...]} envelope, the TEITool wrapper, or a page if "values" in result and isinstance(result["values"], list): return result["values"] if "results" in result and isinstance(result["results"], list): return result["results"] return [] if isinstance(result, list): return result return [] def update_values(self, public_id: str, values: list[dict]) -> dict: """ Bulk-update field values. See ``_normalize_value`` for accepted shapes. """ payload = {"values": [self._normalize_value(v) for v in values]} return self._put(f"{API_PREFIX}/tools/{public_id}/values/", data=payload) def patch_value(self, public_id: str, field_key: str, **changes) -> dict: """ Patch a single field value by its ``field_key``. Accepts the same shorthand as ``update_values`` (``year_1=…``, etc). """ body = self._normalize_value({"field_key": field_key, **changes}) body.pop("field_key", None) # carried in URL return self._patch( f"{API_PREFIX}/tools/{public_id}/values/{field_key}/", data=body ) # ───────────────────────────────────────────── # Calculation & summary # ───────────────────────────────────────────── def calculate(self, public_id: str) -> dict: """Trigger server-side calculation; returns the updated summary.""" return self._post(f"{API_PREFIX}/tools/{public_id}/calculate/") def get_summary(self, public_id: str) -> dict: """Return the most-recent summary (404 if never calculated).""" return self._get(f"{API_PREFIX}/tools/{public_id}/summary/") def aggregate_summary(self) -> dict: """Aggregate NPV across all tools owned by the current API key.""" return self._get(f"{API_PREFIX}/summary/") # ───────────────────────────────────────────── # Versions # ───────────────────────────────────────────── def list_versions(self, public_id: str) -> list[dict]: """List all saved version snapshots for a TEI tool.""" result = self._get(f"{API_PREFIX}/tools/{public_id}/versions/") if isinstance(result, list): return result if isinstance(result, dict): if "results" in result and isinstance(result["results"], list): return result["results"] if "versions" in result and isinstance(result["versions"], list): return result["versions"] return [] def save_version(self, public_id: str, note: str = "") -> dict: """Snapshot current values + summary as a new version.""" return self._post( f"{API_PREFIX}/tools/{public_id}/versions/", data={"note": note}, ) def get_version(self, public_id: str, version_number: int) -> dict: """Get a single version's full snapshot.""" return self._get( f"{API_PREFIX}/tools/{public_id}/versions/{int(version_number)}/" ) # ───────────────────────────────────────────── # Export # ───────────────────────────────────────────── def export(self, public_id: str) -> dict: """ Return the LLM-ready export payload for the report pipeline. The shape is determined by Athena and consumed by Peitho / html2docx; Palladium's ``core.export.report_data`` builds on this. """ return self._get(f"{API_PREFIX}/tools/{public_id}/export/") # ───────────────────────────────────────────── # Convenience # ───────────────────────────────────────────── def get_benefits(self, public_id: str) -> list[dict]: """Return only benefit-table values (table='benefits').""" return [v for v in self.get_values(public_id) if v.get("table") == "benefits"] def get_costs(self, public_id: str) -> list[dict]: """Return only cost-table values (table='costs').""" return [v for v in self.get_values(public_id) if v.get("table") == "costs"] def get_tool_with_data(self, public_id: str) -> dict: """ Bundle a tool, its field definitions, current values, and summary. Convenience for notebook initialisation. The summary is allowed to 404 (returned as ``None``) when the tool has never been calculated. """ tool = self.get_tool(public_id) report_pid = tool.get("report") if isinstance(report_pid, dict): report_pid = report_pid.get("id") or report_pid.get("public_id") fields = self.list_fields(report_pid) if report_pid else [] values = self.get_values(public_id) try: summary = self.get_summary(public_id) except AthenaAPIError as e: if e.status_code == 404: summary = None else: raise return { "tool": tool, "fields": fields, "values": values, "summary": summary, } # ───────────────────────────────────────────── # Display # ───────────────────────────────────────────── def __repr__(self) -> str: # pragma: no cover return f"TEIClient(base_url='{self.base_url}')" def print_summary(self, public_id: str) -> None: """Pretty-print a financial summary block for notebooks/REPL.""" s = self.get_summary(public_id) def _f(v: Any, default: float = 0.0) -> float: try: return float(v) if v is not None else default except (TypeError, ValueError): return default print("═" * 56) print(" TEI Financial Summary") print("═" * 56) print(f" Total Benefits (PV): ${_f(s.get('total_benefits_pv')):>16,.0f}") print(f" Total Costs (PV): ${_f(s.get('total_costs_pv')):>16,.0f}") print("─" * 56) print(f" Net Present Value: ${_f(s.get('npv')):>16,.0f}") print(f" ROI: {_f(s.get('roi')):>15,.0f}%") payback = s.get("payback_months") payback_str = f"{_f(payback):.1f} months" if payback is not None else "N/A" print(f" Payback: {payback_str:>17}") print("═" * 56)