Files
palladium/core/tei_client/client.py
Robert Helewka a2420ed692 refactor: restructure repo into core/app modules with per-study folders
Reorganize Palladium codebase into a modular architecture with `core/`
shared logic and `app/` Streamlit UI, separating per-study assets into
`studies/YYYYMM_<Vendor>/` folders containing notebooks, seed data, and
configuration. Update README to reflect new structure, add `.gitignore`
entries for `.env` and study exports, and refresh component documentation.
2026-05-20 22:28:12 -04:00

564 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
TEI Client — Athena API wrapper for Palladium.

Endpoints (per Athena API.yaml, all under ``/api/v1/tei/``)::

    Reports (templates)
      GET     /reports/                                 list_reports
      GET     /reports/{public_id}/                     get_report
      GET     /reports/{public_id}/fields/              list_fields
      PATCH   /reports/{public_id}/fields/reorder/      reorder_fields

    Tools (instances)
      GET     /tools/                                   list_tools
      POST    /tools/                                   create_tool
      GET     /tools/{public_id}/                       get_tool
      PATCH   /tools/{public_id}/                       update_tool
      DELETE  /tools/{public_id}/                       delete_tool

    Values (data entry)
      GET     /tools/{public_id}/values/                get_values
      PUT     /tools/{public_id}/values/                update_values (bulk)
      PATCH   /tools/{public_id}/values/{field_key}/    patch_value (single)

    Calculation & summary
      POST    /tools/{public_id}/calculate/             calculate
      GET     /tools/{public_id}/summary/               get_summary
      GET     /summary/                                 aggregate_summary

    Versions
      GET     /tools/{public_id}/versions/              list_versions
      POST    /tools/{public_id}/versions/              save_version
      GET     /tools/{public_id}/versions/{n}/          get_version

    Export
      GET     /tools/{public_id}/export/                export

The client additionally wraps field-definition management under the same
prefix, which is not part of the list above (``create_field`` is documented
as admin-only)::

      POST    /reports/{public_id}/fields/              create_field
      PATCH   /reports/{public_id}/fields/{field_id}/   update_field
      DELETE  /reports/{public_id}/fields/{field_id}/   delete_field

Authentication uses the ``Authorization: Api-Key {key}`` header.

HTTP error statuses (>= 400), connection failures and timeouts are raised
as ``AthenaAPIError``.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime
from typing import Any
import requests
from dotenv import load_dotenv
# Module-wide logger; this module adds no handlers or levels of its own.
logger: logging.Logger = logging.getLogger(name=__name__)

# Every TEI endpoint lives under this path prefix on the Athena host.
API_PREFIX: str = "/api/v1/tei"

# Load a local .env file into os.environ at import time, so TEIClient() can
# resolve ATHENA_BASE_URL / ATHENA_API_KEY without explicit arguments.
load_dotenv()
class AthenaAPIError(Exception):
    """
    Raised when Athena returns a non-success response.

    Attributes:
        status_code: HTTP status of the failed response. ``TEIClient`` also
            uses synthetic codes for transport failures (0 for a connection
            error, 408 for a timeout).
        detail: Human-readable error text extracted from the response.
        url: The request URL that failed.
    """

    def __init__(self, status_code: int, detail: str, url: str):
        self.status_code = status_code
        self.detail = detail
        self.url = url
        super().__init__(f"Athena API {status_code} at {url}: {detail}")

    def __reduce__(self):
        # BaseException reconstructs instances from ``self.args``, which here
        # holds only the formatted message. That does not match __init__'s
        # three-argument signature, so pickle/copy would raise TypeError.
        # Rebuild from the real constructor arguments instead, and carry any
        # extra instance attributes along as state.
        return (
            type(self),
            (self.status_code, self.detail, self.url),
            self.__dict__,
        )
class TEIClient:
"""
Client for Athena's TEI Calculator API.
Wraps every TEI endpoint and provides a few convenience helpers used by
the Palladium notebooks and Streamlit app.
Environment variables (read via python-dotenv):
ATHENA_BASE_URL e.g. https://athena.nttdata.com
ATHENA_API_KEY Api-Key value (admin-issued)
Example::
from core.tei_client import TEIClient
client = TEIClient()
client.test_connection()
for r in client.list_reports():
print(r["name"])
"""
def __init__(
self,
base_url: str | None = None,
api_key: str | None = None,
timeout: int = 30,
):
self.base_url = (base_url or os.getenv("ATHENA_BASE_URL", "")).rstrip("/")
self.api_key = api_key or os.getenv("ATHENA_API_KEY", "")
self.timeout = timeout
if not self.base_url:
raise ValueError(
"ATHENA_BASE_URL is required. Set it in .env or pass base_url."
)
if not self.api_key:
raise ValueError(
"ATHENA_API_KEY is required. Set it in .env or pass api_key."
)
self.session = requests.Session()
self.session.headers.update(
{
"Authorization": f"Api-Key {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
)
logger.info("TEIClient initialised for %s", self.base_url)
# ─────────────────────────────────────────────
# Internal HTTP helpers
# ─────────────────────────────────────────────
def _url(self, path: str) -> str:
if not path.startswith("/"):
path = f"/{path}"
return f"{self.base_url}{path}"
def _request(
self,
method: str,
path: str,
params: dict | None = None,
json_data: Any | None = None,
) -> Any:
url = self._url(path)
logger.debug("%s %s", method.upper(), url)
try:
response = self.session.request(
method=method,
url=url,
params=params,
json=json_data,
timeout=self.timeout,
)
except requests.ConnectionError as e:
raise AthenaAPIError(0, f"Connection failed: {e}", url) from e
except requests.Timeout as e:
raise AthenaAPIError(408, "Request timed out", url) from e
if response.status_code >= 400:
try:
payload = response.json()
detail = payload.get("detail") or json.dumps(payload)
except (json.JSONDecodeError, ValueError, AttributeError):
detail = response.text
raise AthenaAPIError(response.status_code, detail, url)
if response.status_code == 204 or not response.content:
return {}
return response.json()
def _get(self, path: str, params: dict | None = None) -> Any:
return self._request("GET", path, params=params)
def _post(self, path: str, data: Any | None = None) -> Any:
return self._request("POST", path, json_data=data)
def _put(self, path: str, data: Any | None = None) -> Any:
return self._request("PUT", path, json_data=data)
def _patch(self, path: str, data: Any | None = None) -> Any:
return self._request("PATCH", path, json_data=data)
def _delete(self, path: str) -> Any:
return self._request("DELETE", path)
def _paginated(self, path: str, params: dict | None = None) -> list[dict]:
"""
Fetch all pages of a paginated list endpoint.
Athena uses the standard DRF page/results envelope::
{"count": N, "next": url|None, "previous": ..., "results": [...]}
"""
out: list[dict] = []
result = self._get(path, params=params)
while True:
if isinstance(result, list):
out.extend(result)
return out
if not isinstance(result, dict):
return out
out.extend(result.get("results", []) or [])
next_url = result.get("next")
if not next_url:
return out
# Follow absolute next URL
try:
result = self.session.get(next_url, timeout=self.timeout).json()
except Exception: # pragma: no cover defensive
return out
# ─────────────────────────────────────────────
# Connection test
# ─────────────────────────────────────────────
def test_connection(self) -> dict:
"""Verify API connectivity and authentication."""
try:
result = self._get(f"{API_PREFIX}/reports/")
count = (
result.get("count", len(result.get("results", [])))
if isinstance(result, dict)
else len(result)
)
return {
"status": "ok",
"base_url": self.base_url,
"authenticated": True,
"reports_found": count,
"timestamp": datetime.now().isoformat(),
}
except AthenaAPIError as e:
return {
"status": "error",
"base_url": self.base_url,
"authenticated": e.status_code != 401,
"error_code": e.status_code,
"detail": e.detail,
"timestamp": datetime.now().isoformat(),
}
# ─────────────────────────────────────────────
# Reports (templates)
# ─────────────────────────────────────────────
def list_reports(self) -> list[dict]:
"""List all TEI report templates (auto-paginated)."""
return self._paginated(f"{API_PREFIX}/reports/")
def get_report(self, public_id: str) -> dict:
"""Get a TEI report template by its public_id."""
return self._get(f"{API_PREFIX}/reports/{public_id}/")
def list_fields(
self,
report_public_id: str,
table: str | None = None,
) -> list[dict]:
"""
Get field definitions for a report.
Args:
report_public_id: The report template's public_id (12-char short UUID).
table: Optional filter — ``'benefits'`` or ``'costs'``.
Returns a list of field-definition dicts. See ``TEIField.from_dict``
for the expected shape.
"""
params = {"table": table} if table else None
rows = self._paginated(
f"{API_PREFIX}/reports/{report_public_id}/fields/", params=params
)
# Defensive — server-side filter may not be implemented; filter locally.
if table:
rows = [r for r in rows if r.get("table") == table]
rows.sort(key=lambda r: (r.get("table", ""), r.get("sort_order") or 0))
return rows
def create_field(self, report_public_id: str, field: dict) -> dict:
"""Create a new field definition under a report (admin only)."""
return self._post(
f"{API_PREFIX}/reports/{report_public_id}/fields/", data=field
)
def update_field(self, report_public_id: str, field_id: int, **changes) -> dict:
"""Patch one field definition by its integer id."""
return self._patch(
f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/",
data=changes,
)
def delete_field(self, report_public_id: str, field_id: int) -> dict:
return self._delete(
f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/"
)
def reorder_fields(self, report_public_id: str, field_ids: list[int]) -> dict:
"""Bulk-reorder fields. Spec: PATCH /reports/{id}/fields/reorder/."""
return self._patch(
f"{API_PREFIX}/reports/{report_public_id}/fields/reorder/",
data={"field_ids": field_ids},
)
# ─────────────────────────────────────────────
# Tools (instances)
# ─────────────────────────────────────────────
def list_tools(self) -> list[dict]:
"""List TEI tool instances owned by the current API key."""
return self._paginated(f"{API_PREFIX}/tools/")
def get_tool(self, public_id: str) -> dict:
"""Get a TEI tool instance by public_id."""
return self._get(f"{API_PREFIX}/tools/{public_id}/")
def create_tool(
self,
report_public_id: str,
proposal: int | None = None,
engagement: int | None = None,
name: str | None = None,
status: str = "draft",
) -> dict:
"""
Create a new TEI tool instance from a report template.
Athena scopes a TEI tool to a *Proposal* (which itself belongs to an
Opportunity) and/or an *Engagement*. Pass the integer PK of either or
both to link the tool.
"""
data: dict[str, Any] = {"report": report_public_id, "status": status}
if proposal is not None:
data["proposal"] = proposal
if engagement is not None:
data["engagement"] = engagement
if name:
data["name"] = name
return self._post(f"{API_PREFIX}/tools/", data=data)
def update_tool(
self,
public_id: str,
name: str | None = None,
status: str | None = None,
) -> dict:
"""Update tool metadata. Only ``name`` and ``status`` are mutable."""
data: dict[str, Any] = {}
if name is not None:
data["name"] = name
if status is not None:
data["status"] = status
return self._patch(f"{API_PREFIX}/tools/{public_id}/", data=data)
def delete_tool(self, public_id: str) -> dict:
return self._delete(f"{API_PREFIX}/tools/{public_id}/")
# ─────────────────────────────────────────────
# Values (data entry)
# ─────────────────────────────────────────────
@staticmethod
def _normalize_value(value: dict) -> dict:
"""
Normalize a value-row dict into the shape the API expects.
Accepts any of the following input forms and produces a uniform
wire-format dict::
# annual fields
{"field_key": "A1", "year_1": 100, "year_2": 200, "year_3": 300, ...}
{"field_key": "A1", "year_values": {"1": 100, "2": 200, "3": 300}, ...}
# non-annual scalars
{"field_key": "rate", "value": 0.10, ...}
Returns a dict like::
{"field_key": "A1",
"year_values": {"1": 100.0, "2": 200.0, "3": 300.0},
"risk_adjustment": 0.15,
"notes": ""}
"""
out: dict[str, Any] = {}
if "field_key" in value:
out["field_key"] = value["field_key"]
elif "field" in value:
out["field_key"] = value["field"]
# Collect annual year_N keys into year_values
year_values: dict[str, float] = {}
if "year_values" in value and isinstance(value["year_values"], dict):
for k, v in value["year_values"].items():
year_values[str(k)] = float(v) if v is not None else 0.0
for key, raw in value.items():
if key.startswith("year_"):
try:
n = int(key.split("_", 1)[1])
except ValueError:
continue
year_values[str(n)] = float(raw) if raw is not None else 0.0
if year_values:
out["year_values"] = year_values
if "value" in value and value["value"] is not None and not year_values:
out["value"] = value["value"]
if value.get("initial") is not None:
out["initial"] = float(value["initial"])
if value.get("risk_adjustment") is not None:
out["risk_adjustment"] = float(value["risk_adjustment"])
if value.get("notes"):
out["notes"] = str(value["notes"])
return out
def get_values(self, public_id: str) -> list[dict]:
"""Get all current field values for a TEI tool instance."""
result = self._get(f"{API_PREFIX}/tools/{public_id}/values/")
if isinstance(result, dict):
# Could be {"values": [...]} envelope, the TEITool wrapper, or a page
if "values" in result and isinstance(result["values"], list):
return result["values"]
if "results" in result and isinstance(result["results"], list):
return result["results"]
return []
if isinstance(result, list):
return result
return []
def update_values(self, public_id: str, values: list[dict]) -> dict:
"""
Bulk-update field values. See ``_normalize_value`` for accepted shapes.
"""
payload = {"values": [self._normalize_value(v) for v in values]}
return self._put(f"{API_PREFIX}/tools/{public_id}/values/", data=payload)
def patch_value(self, public_id: str, field_key: str, **changes) -> dict:
"""
Patch a single field value by its ``field_key``.
Accepts the same shorthand as ``update_values`` (``year_1=…``, etc).
"""
body = self._normalize_value({"field_key": field_key, **changes})
body.pop("field_key", None) # carried in URL
return self._patch(
f"{API_PREFIX}/tools/{public_id}/values/{field_key}/", data=body
)
# ─────────────────────────────────────────────
# Calculation & summary
# ─────────────────────────────────────────────
def calculate(self, public_id: str) -> dict:
"""Trigger server-side calculation; returns the updated summary."""
return self._post(f"{API_PREFIX}/tools/{public_id}/calculate/")
def get_summary(self, public_id: str) -> dict:
"""Return the most-recent summary (404 if never calculated)."""
return self._get(f"{API_PREFIX}/tools/{public_id}/summary/")
def aggregate_summary(self) -> dict:
"""Aggregate NPV across all tools owned by the current API key."""
return self._get(f"{API_PREFIX}/summary/")
# ─────────────────────────────────────────────
# Versions
# ─────────────────────────────────────────────
def list_versions(self, public_id: str) -> list[dict]:
"""List all saved version snapshots for a TEI tool."""
result = self._get(f"{API_PREFIX}/tools/{public_id}/versions/")
if isinstance(result, list):
return result
if isinstance(result, dict):
if "results" in result and isinstance(result["results"], list):
return result["results"]
if "versions" in result and isinstance(result["versions"], list):
return result["versions"]
return []
def save_version(self, public_id: str, note: str = "") -> dict:
"""Snapshot current values + summary as a new version."""
return self._post(
f"{API_PREFIX}/tools/{public_id}/versions/",
data={"note": note},
)
def get_version(self, public_id: str, version_number: int) -> dict:
"""Get a single version's full snapshot."""
return self._get(
f"{API_PREFIX}/tools/{public_id}/versions/{int(version_number)}/"
)
# ─────────────────────────────────────────────
# Export
# ─────────────────────────────────────────────
def export(self, public_id: str) -> dict:
"""
Return the LLM-ready export payload for the report pipeline.
The shape is determined by Athena and consumed by Peitho /
html2docx; Palladium's ``core.export.report_data`` builds on this.
"""
return self._get(f"{API_PREFIX}/tools/{public_id}/export/")
# ─────────────────────────────────────────────
# Convenience
# ─────────────────────────────────────────────
def get_benefits(self, public_id: str) -> list[dict]:
"""Return only benefit-table values (table='benefits')."""
return [v for v in self.get_values(public_id) if v.get("table") == "benefits"]
def get_costs(self, public_id: str) -> list[dict]:
"""Return only cost-table values (table='costs')."""
return [v for v in self.get_values(public_id) if v.get("table") == "costs"]
def get_tool_with_data(self, public_id: str) -> dict:
"""
Bundle a tool, its field definitions, current values, and summary.
Convenience for notebook initialisation. The summary is allowed to
404 (returned as ``None``) when the tool has never been calculated.
"""
tool = self.get_tool(public_id)
report_pid = tool.get("report")
if isinstance(report_pid, dict):
report_pid = report_pid.get("id") or report_pid.get("public_id")
fields = self.list_fields(report_pid) if report_pid else []
values = self.get_values(public_id)
try:
summary = self.get_summary(public_id)
except AthenaAPIError as e:
if e.status_code == 404:
summary = None
else:
raise
return {
"tool": tool,
"fields": fields,
"values": values,
"summary": summary,
}
# ─────────────────────────────────────────────
# Display
# ─────────────────────────────────────────────
def __repr__(self) -> str: # pragma: no cover
return f"TEIClient(base_url='{self.base_url}')"
def print_summary(self, public_id: str) -> None:
"""Pretty-print a financial summary block for notebooks/REPL."""
s = self.get_summary(public_id)
def _f(v: Any, default: float = 0.0) -> float:
try:
return float(v) if v is not None else default
except (TypeError, ValueError):
return default
print("" * 56)
print(" TEI Financial Summary")
print("" * 56)
print(f" Total Benefits (PV): ${_f(s.get('total_benefits_pv')):>16,.0f}")
print(f" Total Costs (PV): ${_f(s.get('total_costs_pv')):>16,.0f}")
print("" * 56)
print(f" Net Present Value: ${_f(s.get('npv')):>16,.0f}")
print(f" ROI: {_f(s.get('roi')):>15,.0f}%")
payback = s.get("payback_months")
payback_str = f"{_f(payback):.1f} months" if payback is not None else "N/A"
print(f" Payback: {payback_str:>17}")
print("" * 56)