Reorganize Palladium codebase into a modular architecture with `core/` shared logic and `app/` Streamlit UI, separating per-study assets into `studies/YYYYMM_<Vendor>/` folders containing notebooks, seed data, and configuration. Update README to reflect new structure, add `.gitignore` entries for `.env` and study exports, and refresh component documentation.
564 lines
22 KiB
Python
564 lines
22 KiB
Python
"""
|
||
TEI Client — Athena API wrapper for Palladium.
|
||
|
||
Endpoints (per Athena API.yaml, all under ``/api/v1/tei/``):
|
||
|
||
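Reports (templates)
GET /reports/ list_reports
GET /reports/{public_id}/ get_report
GET /reports/{public_id}/fields/ list_fields
POST /reports/{public_id}/fields/ create_field (admin)
PATCH /reports/{public_id}/fields/{field_id}/ update_field
DELETE /reports/{public_id}/fields/{field_id}/ delete_field
PATCH /reports/{public_id}/fields/reorder/ reorder_fields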
|
||
|
||
Tools (instances)
|
||
GET /tools/ list_tools
|
||
POST /tools/ create_tool
|
||
GET /tools/{public_id}/ get_tool
|
||
PATCH /tools/{public_id}/ update_tool
|
||
DELETE /tools/{public_id}/ delete_tool
|
||
|
||
Values (data entry)
|
||
GET /tools/{public_id}/values/ get_values
|
||
PUT /tools/{public_id}/values/ update_values (bulk)
|
||
PATCH /tools/{public_id}/values/{field_key}/ patch_value (single)
|
||
|
||
Calculation & summary
|
||
POST /tools/{public_id}/calculate/ calculate
|
||
GET /tools/{public_id}/summary/ get_summary
|
||
GET /summary/ aggregate_summary
|
||
|
||
Versions
|
||
GET /tools/{public_id}/versions/ list_versions
|
||
POST /tools/{public_id}/versions/ save_version
|
||
GET /tools/{public_id}/versions/{n}/ get_version
|
||
|
||
Export
|
||
GET /tools/{public_id}/export/ export
|
||
|
||
Authentication uses the ``Authorization: Api-Key {key}`` header.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
import requests
|
||
from dotenv import load_dotenv
|
||
|
||
# Path prefix shared by every TEI endpoint on the Athena server.
API_PREFIX = "/api/v1/tei"

# Module-level logger; handlers/levels are left to the host application.
logger = logging.getLogger(__name__)

# Import-time side effect: load variables from a .env file into the process
# environment so TEIClient() can read ATHENA_BASE_URL / ATHENA_API_KEY via
# os.getenv without explicit arguments.
load_dotenv()
|
||
|
||
|
||
class AthenaAPIError(Exception):
    """
    Raised when Athena returns a non-success response.

    TEIClient also raises it for transport failures, using pseudo status
    codes (``0`` when the connection failed, ``408`` on a client-side
    timeout).

    Attributes:
        status_code: HTTP status of the failed response, or a pseudo code.
        detail: Human-readable error message.
        url: The URL that was requested.
    """

    def __init__(self, status_code: int, detail: str, url: str):
        self.status_code = status_code
        self.detail = detail
        self.url = url
        super().__init__(f"Athena API {status_code} at {url}: {detail}")

    def __reduce__(self):
        # BaseException unpickles as ``cls(*self.args)``, but ``args`` only
        # holds the formatted message, which does not fit this __init__
        # (TypeError on load). Rebuild from the original fields instead so
        # the error survives pickling, e.g. across a multiprocessing /
        # ProcessPoolExecutor worker boundary.
        return (type(self), (self.status_code, self.detail, self.url))
|
||
|
||
|
||
class TEIClient:
    """
    Client for Athena's TEI Calculator API.

    Wraps every TEI endpoint and provides a few convenience helpers used by
    the Palladium notebooks and Streamlit app.

    Environment variables (read via python-dotenv):
        ATHENA_BASE_URL   e.g. https://athena.nttdata.com
        ATHENA_API_KEY    Api-Key value (admin-issued)

    HTTP errors and transport failures surface as ``AthenaAPIError``. The
    client owns a ``requests.Session``; call :meth:`close` when finished, or
    use the client as a context manager.

    Example::

        from core.tei_client import TEIClient

        with TEIClient() as client:
            client.test_connection()
            for r in client.list_reports():
                print(r["name"])
    """

    # Error codes for which test_connection() reports ``authenticated=False``.
    # 0 and 408 are produced by _request itself when no response arrived
    # (connection failure / timeout), so the key was never checked. 401 and
    # 403 are server rejections: DRF answers 403 instead of 401 when the auth
    # scheme sends no WWW-Authenticate challenge — confirm which one Athena
    # uses for a bad Api-Key.
    _UNVERIFIED_AUTH_CODES = frozenset({0, 401, 403, 408})

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        timeout: int = 30,
    ):
        """
        Args:
            base_url: Athena root URL; defaults to ``$ATHENA_BASE_URL``.
                Any trailing ``/`` is stripped.
            api_key: Api-Key value; defaults to ``$ATHENA_API_KEY``.
            timeout: Per-request timeout in seconds, passed to ``requests``.

        Raises:
            ValueError: if no base URL or no API key is available.
        """
        self.base_url = (base_url or os.getenv("ATHENA_BASE_URL", "")).rstrip("/")
        self.api_key = api_key or os.getenv("ATHENA_API_KEY", "")
        self.timeout = timeout

        if not self.base_url:
            raise ValueError(
                "ATHENA_BASE_URL is required. Set it in .env or pass base_url."
            )
        if not self.api_key:
            raise ValueError(
                "ATHENA_API_KEY is required. Set it in .env or pass api_key."
            )

        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Api-Key {self.api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
        )

        logger.info("TEIClient initialised for %s", self.base_url)

    # ─────────────────────────────────────────────
    # Lifecycle
    # ─────────────────────────────────────────────

    def close(self) -> None:
        """Close the underlying HTTP session and release pooled connections."""
        self.session.close()

    def __enter__(self) -> "TEIClient":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Never suppresses the in-flight exception (returns None).
        self.close()

    # ─────────────────────────────────────────────
    # Internal HTTP helpers
    # ─────────────────────────────────────────────

    def _url(self, path: str) -> str:
        """
        Build the request URL for *path*.

        Relative paths are joined onto ``base_url`` (a leading ``/`` is added
        if missing). Absolute ``http(s)://`` URLs, such as the ``next`` links
        of paginated responses, are returned unchanged.
        """
        if path.startswith(("http://", "https://")):
            return path
        if not path.startswith("/"):
            path = f"/{path}"
        return f"{self.base_url}{path}"

    @staticmethod
    def _error_detail(response: Any) -> str:
        """
        Extract a readable message from an error response.

        Prefers the JSON ``detail`` field, then the whole JSON payload
        re-serialised, then the raw response text when the body is not a
        JSON object.
        """
        try:
            payload = response.json()
            detail = payload.get("detail") or json.dumps(payload)
        except (json.JSONDecodeError, ValueError, AttributeError):
            # Non-JSON body, or a JSON list (no .get) -> fall back to raw text.
            return response.text
        # ``detail`` may itself be a dict/list; keep AthenaAPIError.detail a str.
        return detail if isinstance(detail, str) else json.dumps(detail)

    def _request(
        self,
        method: str,
        path: str,
        params: dict | None = None,
        json_data: Any | None = None,
    ) -> Any:
        """
        Send one request and decode its JSON response.

        Args:
            method: HTTP verb.
            path: API path (joined onto ``base_url``) or absolute URL.
            params: Optional query parameters.
            json_data: Optional JSON-serialisable request body.

        Returns:
            The decoded JSON body, or ``{}`` for a 204 / empty response.

        Raises:
            AthenaAPIError: for HTTP status >= 400 (``status_code`` is the
                HTTP status), connection failure (``0``), timeout (``408``),
                or any other ``requests`` transport error (``0``).
        """
        url = self._url(path)
        logger.debug("%s %s", method.upper(), url)

        try:
            response = self.session.request(
                method=method,
                url=url,
                params=params,
                json=json_data,
                timeout=self.timeout,
            )
        except requests.ConnectionError as e:
            raise AthenaAPIError(0, f"Connection failed: {e}", url) from e
        except requests.Timeout as e:
            raise AthenaAPIError(408, "Request timed out", url) from e
        except requests.RequestException as e:
            # TooManyRedirects, InvalidURL, ... — keep callers on one
            # exception type instead of leaking raw requests errors.
            raise AthenaAPIError(0, f"Request failed: {e}", url) from e

        if response.status_code >= 400:
            raise AthenaAPIError(
                response.status_code, self._error_detail(response), url
            )

        if response.status_code == 204 or not response.content:
            return {}
        return response.json()

    def _get(self, path: str, params: dict | None = None) -> Any:
        return self._request("GET", path, params=params)

    def _post(self, path: str, data: Any | None = None) -> Any:
        return self._request("POST", path, json_data=data)

    def _put(self, path: str, data: Any | None = None) -> Any:
        return self._request("PUT", path, json_data=data)

    def _patch(self, path: str, data: Any | None = None) -> Any:
        return self._request("PATCH", path, json_data=data)

    def _delete(self, path: str) -> Any:
        return self._request("DELETE", path)

    def _paginated(self, path: str, params: dict | None = None) -> list[dict]:
        """
        Fetch all pages of a paginated list endpoint.

        Athena uses the standard DRF page/results envelope::

            {"count": N, "next": url|None, "previous": ..., "results": [...]}

        A bare JSON list response is returned as-is. ``params`` apply to the
        first request only; ``next`` links already carry the query string.

        Errors on the first page propagate as ``AthenaAPIError``. Later pages
        are best-effort: if one fails, or ``next`` points back to a page
        already fetched, a warning is logged and the items collected so far
        are returned.
        """
        out: list[dict] = []
        seen: set[str] = set()
        result = self._get(path, params=params)
        while True:
            if isinstance(result, list):
                out.extend(result)
                return out
            if not isinstance(result, dict):
                return out
            out.extend(result.get("results") or [])

            next_url = result.get("next")
            if not next_url:
                return out
            if next_url in seen:
                logger.warning(
                    "Pagination loop at %s; stopping after %d item(s)",
                    next_url,
                    len(out),
                )
                return out
            seen.add(next_url)

            # Go through _request (not a bare session.get) so a non-2xx page
            # is detected instead of being decoded as data. NOTE(review): the
            # session's Api-Key header is sent to whatever host `next` names.
            try:
                result = self._get(next_url)
            except (AthenaAPIError, ValueError) as e:
                logger.warning(
                    "Pagination of %s stopped early after %d item(s): %s",
                    path,
                    len(out),
                    e,
                )
                return out

    @staticmethod
    def _unwrap_list(result: Any, *envelope_keys: str) -> list[dict]:
        """
        Return *result* if it is a list, else the first list found under
        *envelope_keys* (checked in order) of a dict result, else ``[]``.
        """
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            for key in envelope_keys:
                items = result.get(key)
                if isinstance(items, list):
                    return items
        return []

    # ─────────────────────────────────────────────
    # Connection test
    # ─────────────────────────────────────────────

    def test_connection(self) -> dict:
        """
        Verify API connectivity and authentication.

        Never raises ``AthenaAPIError``; returns a status dict instead. On
        failure ``authenticated`` is False when the key was rejected (401/403)
        or when the server could not be reached (codes 0 / 408), since the key
        was never verified in that case.
        """
        try:
            result = self._get(f"{API_PREFIX}/reports/")
        except AthenaAPIError as e:
            return {
                "status": "error",
                "base_url": self.base_url,
                "authenticated": e.status_code not in self._UNVERIFIED_AUTH_CODES,
                "error_code": e.status_code,
                "detail": e.detail,
                "timestamp": datetime.now().isoformat(),
            }

        if isinstance(result, dict):
            count = result.get("count", len(result.get("results") or []))
        else:
            count = len(result)
        return {
            "status": "ok",
            "base_url": self.base_url,
            "authenticated": True,
            "reports_found": count,
            "timestamp": datetime.now().isoformat(),
        }

    # ─────────────────────────────────────────────
    # Reports (templates)
    # ─────────────────────────────────────────────

    def list_reports(self) -> list[dict]:
        """List all TEI report templates (auto-paginated)."""
        return self._paginated(f"{API_PREFIX}/reports/")

    def get_report(self, public_id: str) -> dict:
        """Get a TEI report template by its public_id."""
        return self._get(f"{API_PREFIX}/reports/{public_id}/")

    def list_fields(
        self,
        report_public_id: str,
        table: str | None = None,
    ) -> list[dict]:
        """
        Get field definitions for a report.

        Args:
            report_public_id: The report template's public_id (12-char short UUID).
            table: Optional filter — ``'benefits'`` or ``'costs'``.

        Returns a list of field-definition dicts sorted by (table, sort_order).
        See ``TEIField.from_dict`` for the expected shape.
        """
        params = {"table": table} if table else None
        rows = self._paginated(
            f"{API_PREFIX}/reports/{report_public_id}/fields/", params=params
        )
        # Defensive — server-side filter may not be implemented; filter locally.
        if table:
            rows = [r for r in rows if r.get("table") == table]
        # `or` (not a .get default) so an explicit null table/sort_order
        # still yields a comparable key instead of a None-vs-str TypeError.
        rows.sort(key=lambda r: (r.get("table") or "", r.get("sort_order") or 0))
        return rows

    def create_field(self, report_public_id: str, field: dict) -> dict:
        """Create a new field definition under a report (admin only)."""
        return self._post(
            f"{API_PREFIX}/reports/{report_public_id}/fields/", data=field
        )

    def update_field(self, report_public_id: str, field_id: int, **changes) -> dict:
        """Patch one field definition by its integer id."""
        return self._patch(
            f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/",
            data=changes,
        )

    def delete_field(self, report_public_id: str, field_id: int) -> dict:
        """Delete one field definition by its integer id."""
        return self._delete(
            f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/"
        )

    def reorder_fields(self, report_public_id: str, field_ids: list[int]) -> dict:
        """Bulk-reorder fields. Spec: PATCH /reports/{id}/fields/reorder/."""
        return self._patch(
            f"{API_PREFIX}/reports/{report_public_id}/fields/reorder/",
            data={"field_ids": field_ids},
        )

    # ─────────────────────────────────────────────
    # Tools (instances)
    # ─────────────────────────────────────────────

    def list_tools(self) -> list[dict]:
        """List TEI tool instances owned by the current API key."""
        return self._paginated(f"{API_PREFIX}/tools/")

    def get_tool(self, public_id: str) -> dict:
        """Get a TEI tool instance by public_id."""
        return self._get(f"{API_PREFIX}/tools/{public_id}/")

    def create_tool(
        self,
        report_public_id: str,
        proposal: int | None = None,
        engagement: int | None = None,
        name: str | None = None,
        status: str = "draft",
    ) -> dict:
        """
        Create a new TEI tool instance from a report template.

        Athena scopes a TEI tool to a *Proposal* (which itself belongs to an
        Opportunity) and/or an *Engagement*. Pass the integer PK of either or
        both to link the tool.
        """
        data: dict[str, Any] = {"report": report_public_id, "status": status}
        if proposal is not None:
            data["proposal"] = proposal
        if engagement is not None:
            data["engagement"] = engagement
        if name:
            data["name"] = name
        return self._post(f"{API_PREFIX}/tools/", data=data)

    def update_tool(
        self,
        public_id: str,
        name: str | None = None,
        status: str | None = None,
    ) -> dict:
        """Update tool metadata. Only ``name`` and ``status`` are mutable."""
        data: dict[str, Any] = {}
        if name is not None:
            data["name"] = name
        if status is not None:
            data["status"] = status
        return self._patch(f"{API_PREFIX}/tools/{public_id}/", data=data)

    def delete_tool(self, public_id: str) -> dict:
        """Delete a TEI tool instance by public_id."""
        return self._delete(f"{API_PREFIX}/tools/{public_id}/")

    # ─────────────────────────────────────────────
    # Values (data entry)
    # ─────────────────────────────────────────────

    @staticmethod
    def _normalize_value(value: dict) -> dict:
        """
        Normalize a value-row dict into the shape the API expects.

        Accepts any of the following input forms and produces a uniform
        wire-format dict::

            # annual fields
            {"field_key": "A1", "year_1": 100, "year_2": 200, "year_3": 300, ...}
            {"field_key": "A1", "year_values": {"1": 100, "2": 200, "3": 300}, ...}

            # non-annual scalars
            {"field_key": "rate", "value": 0.10, ...}

        Returns a dict like::

            {"field_key": "A1",
             "year_values": {"1": 100.0, "2": 200.0, "3": 300.0},
             "risk_adjustment": 0.15,
             "notes": "…"}

        ``field`` is accepted as an alias of ``field_key``. Explicit ``year_N``
        keys override the same year in ``year_values``; ``None`` years become
        0.0. ``value`` is dropped when any year values are present.
        """
        out: dict[str, Any] = {}
        if "field_key" in value:
            out["field_key"] = value["field_key"]
        elif "field" in value:
            out["field_key"] = value["field"]

        # Collect annual year_N keys into year_values
        year_values: dict[str, float] = {}
        if "year_values" in value and isinstance(value["year_values"], dict):
            for k, v in value["year_values"].items():
                year_values[str(k)] = float(v) if v is not None else 0.0
        for key, raw in value.items():
            if key.startswith("year_"):
                try:
                    n = int(key.split("_", 1)[1])
                except ValueError:
                    # e.g. "year_values" itself — not a year_N key.
                    continue
                year_values[str(n)] = float(raw) if raw is not None else 0.0

        if year_values:
            out["year_values"] = year_values
        if "value" in value and value["value"] is not None and not year_values:
            out["value"] = value["value"]
        if value.get("initial") is not None:
            out["initial"] = float(value["initial"])
        if value.get("risk_adjustment") is not None:
            out["risk_adjustment"] = float(value["risk_adjustment"])
        if value.get("notes"):
            out["notes"] = str(value["notes"])
        return out

    def get_values(self, public_id: str) -> list[dict]:
        """Get all current field values for a TEI tool instance."""
        result = self._get(f"{API_PREFIX}/tools/{public_id}/values/")
        # Could be a bare list, a {"values": [...]} envelope, or a DRF page.
        return self._unwrap_list(result, "values", "results")

    def update_values(self, public_id: str, values: list[dict]) -> dict:
        """
        Bulk-update field values. See ``_normalize_value`` for accepted shapes.
        """
        payload = {"values": [self._normalize_value(v) for v in values]}
        return self._put(f"{API_PREFIX}/tools/{public_id}/values/", data=payload)

    def patch_value(self, public_id: str, field_key: str, **changes) -> dict:
        """
        Patch a single field value by its ``field_key``.

        Accepts the same shorthand as ``update_values`` (``year_1=…``, etc).
        """
        body = self._normalize_value({"field_key": field_key, **changes})
        body.pop("field_key", None)  # carried in URL
        return self._patch(
            f"{API_PREFIX}/tools/{public_id}/values/{field_key}/", data=body
        )

    # ─────────────────────────────────────────────
    # Calculation & summary
    # ─────────────────────────────────────────────

    def calculate(self, public_id: str) -> dict:
        """Trigger server-side calculation; returns the updated summary."""
        return self._post(f"{API_PREFIX}/tools/{public_id}/calculate/")

    def get_summary(self, public_id: str) -> dict:
        """Return the most-recent summary (404 if never calculated)."""
        return self._get(f"{API_PREFIX}/tools/{public_id}/summary/")

    def aggregate_summary(self) -> dict:
        """Aggregate NPV across all tools owned by the current API key."""
        return self._get(f"{API_PREFIX}/summary/")

    # ─────────────────────────────────────────────
    # Versions
    # ─────────────────────────────────────────────

    def list_versions(self, public_id: str) -> list[dict]:
        """List all saved version snapshots for a TEI tool."""
        result = self._get(f"{API_PREFIX}/tools/{public_id}/versions/")
        return self._unwrap_list(result, "results", "versions")

    def save_version(self, public_id: str, note: str = "") -> dict:
        """Snapshot current values + summary as a new version."""
        return self._post(
            f"{API_PREFIX}/tools/{public_id}/versions/",
            data={"note": note},
        )

    def get_version(self, public_id: str, version_number: int) -> dict:
        """Get a single version's full snapshot."""
        return self._get(
            f"{API_PREFIX}/tools/{public_id}/versions/{int(version_number)}/"
        )

    # ─────────────────────────────────────────────
    # Export
    # ─────────────────────────────────────────────

    def export(self, public_id: str) -> dict:
        """
        Return the LLM-ready export payload for the report pipeline.

        The shape is determined by Athena and consumed by Peitho /
        html2docx; Palladium's ``core.export.report_data`` builds on this.
        """
        return self._get(f"{API_PREFIX}/tools/{public_id}/export/")

    # ─────────────────────────────────────────────
    # Convenience
    # ─────────────────────────────────────────────

    def get_benefits(self, public_id: str) -> list[dict]:
        """Return only benefit-table values (table='benefits')."""
        return [v for v in self.get_values(public_id) if v.get("table") == "benefits"]

    def get_costs(self, public_id: str) -> list[dict]:
        """Return only cost-table values (table='costs')."""
        return [v for v in self.get_values(public_id) if v.get("table") == "costs"]

    def get_tool_with_data(self, public_id: str) -> dict:
        """
        Bundle a tool, its field definitions, current values, and summary.

        Convenience for notebook initialisation. The summary is allowed to
        404 (returned as ``None``) when the tool has never been calculated;
        other API errors propagate.
        """
        tool = self.get_tool(public_id)
        report_pid = tool.get("report")
        if isinstance(report_pid, dict):
            # Prefer public_id: list_fields() addresses reports by public_id,
            # while a nested "id" may be the integer primary key.
            report_pid = report_pid.get("public_id") or report_pid.get("id")
        fields = self.list_fields(report_pid) if report_pid else []
        values = self.get_values(public_id)
        try:
            summary = self.get_summary(public_id)
        except AthenaAPIError as e:
            if e.status_code != 404:
                raise
            summary = None
        return {
            "tool": tool,
            "fields": fields,
            "values": values,
            "summary": summary,
        }

    # ─────────────────────────────────────────────
    # Display
    # ─────────────────────────────────────────────

    def __repr__(self) -> str:  # pragma: no cover
        return f"TEIClient(base_url='{self.base_url}')"

    def print_summary(self, public_id: str) -> None:
        """Pretty-print a financial summary block for notebooks/REPL."""
        s = self.get_summary(public_id)

        def _f(v: Any, default: float = 0.0) -> float:
            # Tolerate missing / non-numeric summary fields.
            try:
                return float(v) if v is not None else default
            except (TypeError, ValueError):
                return default

        print("═" * 56)
        print("  TEI Financial Summary")
        print("═" * 56)
        print(f"  Total Benefits (PV):   ${_f(s.get('total_benefits_pv')):>16,.0f}")
        print(f"  Total Costs (PV):      ${_f(s.get('total_costs_pv')):>16,.0f}")
        print("─" * 56)
        print(f"  Net Present Value:     ${_f(s.get('npv')):>16,.0f}")
        print(f"  ROI:                    {_f(s.get('roi')):>15,.0f}%")
        payback = s.get("payback_months")
        payback_str = f"{_f(payback):.1f} months" if payback is not None else "N/A"
        print(f"  Payback:                {payback_str:>17}")
        print("═" * 56)
|