refactor: restructure repo into core/app modules with per-study folders
Reorganize Palladium codebase into a modular architecture with `core/` shared logic and `app/` Streamlit UI, separating per-study assets into `studies/YYYYMM_<Vendor>/` folders containing notebooks, seed data, and configuration. Update README to reflect new structure, add `.gitignore` entries for `.env` and study exports, and refresh component documentation.
This commit is contained in:
563
core/tei_client/client.py
Normal file
563
core/tei_client/client.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""
|
||||
TEI Client — Athena API wrapper for Palladium.
|
||||
|
||||
Endpoints (per Athena API.yaml, all under ``/api/v1/tei/``):
|
||||
|
||||
Reports (templates)
|
||||
GET /reports/ list_reports
|
||||
GET /reports/{public_id}/ get_report
|
||||
GET /reports/{public_id}/fields/ list_fields
|
||||
PATCH /reports/{public_id}/fields/reorder/ reorder_fields
|
||||
|
||||
Tools (instances)
|
||||
GET /tools/ list_tools
|
||||
POST /tools/ create_tool
|
||||
GET /tools/{public_id}/ get_tool
|
||||
PATCH /tools/{public_id}/ update_tool
|
||||
DELETE /tools/{public_id}/ delete_tool
|
||||
|
||||
Values (data entry)
|
||||
GET /tools/{public_id}/values/ get_values
|
||||
PUT /tools/{public_id}/values/ update_values (bulk)
|
||||
PATCH /tools/{public_id}/values/{field_key}/ patch_value (single)
|
||||
|
||||
Calculation & summary
|
||||
POST /tools/{public_id}/calculate/ calculate
|
||||
GET /tools/{public_id}/summary/ get_summary
|
||||
GET /summary/ aggregate_summary
|
||||
|
||||
Versions
|
||||
GET /tools/{public_id}/versions/ list_versions
|
||||
POST /tools/{public_id}/versions/ save_version
|
||||
GET /tools/{public_id}/versions/{n}/ get_version
|
||||
|
||||
Export
|
||||
GET /tools/{public_id}/export/ export
|
||||
|
||||
Authentication uses the ``Authorization: Api-Key {key}`` header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
API_PREFIX = "/api/v1/tei"
|
||||
|
||||
|
||||
class AthenaAPIError(Exception):
|
||||
"""Raised when Athena returns a non-success response."""
|
||||
|
||||
def __init__(self, status_code: int, detail: str, url: str):
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.url = url
|
||||
super().__init__(f"Athena API {status_code} at {url}: {detail}")
|
||||
|
||||
|
||||
class TEIClient:
|
||||
"""
|
||||
Client for Athena's TEI Calculator API.
|
||||
|
||||
Wraps every TEI endpoint and provides a few convenience helpers used by
|
||||
the Palladium notebooks and Streamlit app.
|
||||
|
||||
Environment variables (read via python-dotenv):
|
||||
ATHENA_BASE_URL e.g. https://athena.nttdata.com
|
||||
ATHENA_API_KEY Api-Key value (admin-issued)
|
||||
|
||||
Example::
|
||||
|
||||
from core.tei_client import TEIClient
|
||||
client = TEIClient()
|
||||
client.test_connection()
|
||||
for r in client.list_reports():
|
||||
print(r["name"])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = 30,
|
||||
):
|
||||
self.base_url = (base_url or os.getenv("ATHENA_BASE_URL", "")).rstrip("/")
|
||||
self.api_key = api_key or os.getenv("ATHENA_API_KEY", "")
|
||||
self.timeout = timeout
|
||||
|
||||
if not self.base_url:
|
||||
raise ValueError(
|
||||
"ATHENA_BASE_URL is required. Set it in .env or pass base_url."
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"ATHENA_API_KEY is required. Set it in .env or pass api_key."
|
||||
)
|
||||
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(
|
||||
{
|
||||
"Authorization": f"Api-Key {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("TEIClient initialised for %s", self.base_url)
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Internal HTTP helpers
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def _url(self, path: str) -> str:
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
return f"{self.base_url}{path}"
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict | None = None,
|
||||
json_data: Any | None = None,
|
||||
) -> Any:
|
||||
url = self._url(path)
|
||||
logger.debug("%s %s", method.upper(), url)
|
||||
|
||||
try:
|
||||
response = self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
json=json_data,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except requests.ConnectionError as e:
|
||||
raise AthenaAPIError(0, f"Connection failed: {e}", url) from e
|
||||
except requests.Timeout as e:
|
||||
raise AthenaAPIError(408, "Request timed out", url) from e
|
||||
|
||||
if response.status_code >= 400:
|
||||
try:
|
||||
payload = response.json()
|
||||
detail = payload.get("detail") or json.dumps(payload)
|
||||
except (json.JSONDecodeError, ValueError, AttributeError):
|
||||
detail = response.text
|
||||
raise AthenaAPIError(response.status_code, detail, url)
|
||||
|
||||
if response.status_code == 204 or not response.content:
|
||||
return {}
|
||||
return response.json()
|
||||
|
||||
def _get(self, path: str, params: dict | None = None) -> Any:
|
||||
return self._request("GET", path, params=params)
|
||||
|
||||
def _post(self, path: str, data: Any | None = None) -> Any:
|
||||
return self._request("POST", path, json_data=data)
|
||||
|
||||
def _put(self, path: str, data: Any | None = None) -> Any:
|
||||
return self._request("PUT", path, json_data=data)
|
||||
|
||||
def _patch(self, path: str, data: Any | None = None) -> Any:
|
||||
return self._request("PATCH", path, json_data=data)
|
||||
|
||||
def _delete(self, path: str) -> Any:
|
||||
return self._request("DELETE", path)
|
||||
|
||||
def _paginated(self, path: str, params: dict | None = None) -> list[dict]:
|
||||
"""
|
||||
Fetch all pages of a paginated list endpoint.
|
||||
|
||||
Athena uses the standard DRF page/results envelope::
|
||||
|
||||
{"count": N, "next": url|None, "previous": ..., "results": [...]}
|
||||
"""
|
||||
out: list[dict] = []
|
||||
result = self._get(path, params=params)
|
||||
while True:
|
||||
if isinstance(result, list):
|
||||
out.extend(result)
|
||||
return out
|
||||
if not isinstance(result, dict):
|
||||
return out
|
||||
out.extend(result.get("results", []) or [])
|
||||
next_url = result.get("next")
|
||||
if not next_url:
|
||||
return out
|
||||
# Follow absolute next URL
|
||||
try:
|
||||
result = self.session.get(next_url, timeout=self.timeout).json()
|
||||
except Exception: # pragma: no cover – defensive
|
||||
return out
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Connection test
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def test_connection(self) -> dict:
|
||||
"""Verify API connectivity and authentication."""
|
||||
try:
|
||||
result = self._get(f"{API_PREFIX}/reports/")
|
||||
count = (
|
||||
result.get("count", len(result.get("results", [])))
|
||||
if isinstance(result, dict)
|
||||
else len(result)
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"base_url": self.base_url,
|
||||
"authenticated": True,
|
||||
"reports_found": count,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
except AthenaAPIError as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"base_url": self.base_url,
|
||||
"authenticated": e.status_code != 401,
|
||||
"error_code": e.status_code,
|
||||
"detail": e.detail,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Reports (templates)
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def list_reports(self) -> list[dict]:
|
||||
"""List all TEI report templates (auto-paginated)."""
|
||||
return self._paginated(f"{API_PREFIX}/reports/")
|
||||
|
||||
def get_report(self, public_id: str) -> dict:
|
||||
"""Get a TEI report template by its public_id."""
|
||||
return self._get(f"{API_PREFIX}/reports/{public_id}/")
|
||||
|
||||
def list_fields(
|
||||
self,
|
||||
report_public_id: str,
|
||||
table: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get field definitions for a report.
|
||||
|
||||
Args:
|
||||
report_public_id: The report template's public_id (12-char short UUID).
|
||||
table: Optional filter — ``'benefits'`` or ``'costs'``.
|
||||
|
||||
Returns a list of field-definition dicts. See ``TEIField.from_dict``
|
||||
for the expected shape.
|
||||
"""
|
||||
params = {"table": table} if table else None
|
||||
rows = self._paginated(
|
||||
f"{API_PREFIX}/reports/{report_public_id}/fields/", params=params
|
||||
)
|
||||
# Defensive — server-side filter may not be implemented; filter locally.
|
||||
if table:
|
||||
rows = [r for r in rows if r.get("table") == table]
|
||||
rows.sort(key=lambda r: (r.get("table", ""), r.get("sort_order") or 0))
|
||||
return rows
|
||||
|
||||
def create_field(self, report_public_id: str, field: dict) -> dict:
|
||||
"""Create a new field definition under a report (admin only)."""
|
||||
return self._post(
|
||||
f"{API_PREFIX}/reports/{report_public_id}/fields/", data=field
|
||||
)
|
||||
|
||||
def update_field(self, report_public_id: str, field_id: int, **changes) -> dict:
|
||||
"""Patch one field definition by its integer id."""
|
||||
return self._patch(
|
||||
f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/",
|
||||
data=changes,
|
||||
)
|
||||
|
||||
def delete_field(self, report_public_id: str, field_id: int) -> dict:
|
||||
return self._delete(
|
||||
f"{API_PREFIX}/reports/{report_public_id}/fields/{field_id}/"
|
||||
)
|
||||
|
||||
def reorder_fields(self, report_public_id: str, field_ids: list[int]) -> dict:
|
||||
"""Bulk-reorder fields. Spec: PATCH /reports/{id}/fields/reorder/."""
|
||||
return self._patch(
|
||||
f"{API_PREFIX}/reports/{report_public_id}/fields/reorder/",
|
||||
data={"field_ids": field_ids},
|
||||
)
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Tools (instances)
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def list_tools(self) -> list[dict]:
|
||||
"""List TEI tool instances owned by the current API key."""
|
||||
return self._paginated(f"{API_PREFIX}/tools/")
|
||||
|
||||
def get_tool(self, public_id: str) -> dict:
|
||||
"""Get a TEI tool instance by public_id."""
|
||||
return self._get(f"{API_PREFIX}/tools/{public_id}/")
|
||||
|
||||
def create_tool(
|
||||
self,
|
||||
report_public_id: str,
|
||||
proposal: int | None = None,
|
||||
engagement: int | None = None,
|
||||
name: str | None = None,
|
||||
status: str = "draft",
|
||||
) -> dict:
|
||||
"""
|
||||
Create a new TEI tool instance from a report template.
|
||||
|
||||
Athena scopes a TEI tool to a *Proposal* (which itself belongs to an
|
||||
Opportunity) and/or an *Engagement*. Pass the integer PK of either or
|
||||
both to link the tool.
|
||||
"""
|
||||
data: dict[str, Any] = {"report": report_public_id, "status": status}
|
||||
if proposal is not None:
|
||||
data["proposal"] = proposal
|
||||
if engagement is not None:
|
||||
data["engagement"] = engagement
|
||||
if name:
|
||||
data["name"] = name
|
||||
return self._post(f"{API_PREFIX}/tools/", data=data)
|
||||
|
||||
def update_tool(
|
||||
self,
|
||||
public_id: str,
|
||||
name: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""Update tool metadata. Only ``name`` and ``status`` are mutable."""
|
||||
data: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if status is not None:
|
||||
data["status"] = status
|
||||
return self._patch(f"{API_PREFIX}/tools/{public_id}/", data=data)
|
||||
|
||||
def delete_tool(self, public_id: str) -> dict:
|
||||
return self._delete(f"{API_PREFIX}/tools/{public_id}/")
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Values (data entry)
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _normalize_value(value: dict) -> dict:
|
||||
"""
|
||||
Normalize a value-row dict into the shape the API expects.
|
||||
|
||||
Accepts any of the following input forms and produces a uniform
|
||||
wire-format dict::
|
||||
|
||||
# annual fields
|
||||
{"field_key": "A1", "year_1": 100, "year_2": 200, "year_3": 300, ...}
|
||||
{"field_key": "A1", "year_values": {"1": 100, "2": 200, "3": 300}, ...}
|
||||
|
||||
# non-annual scalars
|
||||
{"field_key": "rate", "value": 0.10, ...}
|
||||
|
||||
Returns a dict like::
|
||||
|
||||
{"field_key": "A1",
|
||||
"year_values": {"1": 100.0, "2": 200.0, "3": 300.0},
|
||||
"risk_adjustment": 0.15,
|
||||
"notes": "…"}
|
||||
"""
|
||||
out: dict[str, Any] = {}
|
||||
if "field_key" in value:
|
||||
out["field_key"] = value["field_key"]
|
||||
elif "field" in value:
|
||||
out["field_key"] = value["field"]
|
||||
|
||||
# Collect annual year_N keys into year_values
|
||||
year_values: dict[str, float] = {}
|
||||
if "year_values" in value and isinstance(value["year_values"], dict):
|
||||
for k, v in value["year_values"].items():
|
||||
year_values[str(k)] = float(v) if v is not None else 0.0
|
||||
for key, raw in value.items():
|
||||
if key.startswith("year_"):
|
||||
try:
|
||||
n = int(key.split("_", 1)[1])
|
||||
except ValueError:
|
||||
continue
|
||||
year_values[str(n)] = float(raw) if raw is not None else 0.0
|
||||
|
||||
if year_values:
|
||||
out["year_values"] = year_values
|
||||
if "value" in value and value["value"] is not None and not year_values:
|
||||
out["value"] = value["value"]
|
||||
if value.get("initial") is not None:
|
||||
out["initial"] = float(value["initial"])
|
||||
if value.get("risk_adjustment") is not None:
|
||||
out["risk_adjustment"] = float(value["risk_adjustment"])
|
||||
if value.get("notes"):
|
||||
out["notes"] = str(value["notes"])
|
||||
return out
|
||||
|
||||
def get_values(self, public_id: str) -> list[dict]:
|
||||
"""Get all current field values for a TEI tool instance."""
|
||||
result = self._get(f"{API_PREFIX}/tools/{public_id}/values/")
|
||||
if isinstance(result, dict):
|
||||
# Could be {"values": [...]} envelope, the TEITool wrapper, or a page
|
||||
if "values" in result and isinstance(result["values"], list):
|
||||
return result["values"]
|
||||
if "results" in result and isinstance(result["results"], list):
|
||||
return result["results"]
|
||||
return []
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
return []
|
||||
|
||||
def update_values(self, public_id: str, values: list[dict]) -> dict:
|
||||
"""
|
||||
Bulk-update field values. See ``_normalize_value`` for accepted shapes.
|
||||
"""
|
||||
payload = {"values": [self._normalize_value(v) for v in values]}
|
||||
return self._put(f"{API_PREFIX}/tools/{public_id}/values/", data=payload)
|
||||
|
||||
def patch_value(self, public_id: str, field_key: str, **changes) -> dict:
|
||||
"""
|
||||
Patch a single field value by its ``field_key``.
|
||||
|
||||
Accepts the same shorthand as ``update_values`` (``year_1=…``, etc).
|
||||
"""
|
||||
body = self._normalize_value({"field_key": field_key, **changes})
|
||||
body.pop("field_key", None) # carried in URL
|
||||
return self._patch(
|
||||
f"{API_PREFIX}/tools/{public_id}/values/{field_key}/", data=body
|
||||
)
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Calculation & summary
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def calculate(self, public_id: str) -> dict:
|
||||
"""Trigger server-side calculation; returns the updated summary."""
|
||||
return self._post(f"{API_PREFIX}/tools/{public_id}/calculate/")
|
||||
|
||||
def get_summary(self, public_id: str) -> dict:
|
||||
"""Return the most-recent summary (404 if never calculated)."""
|
||||
return self._get(f"{API_PREFIX}/tools/{public_id}/summary/")
|
||||
|
||||
def aggregate_summary(self) -> dict:
|
||||
"""Aggregate NPV across all tools owned by the current API key."""
|
||||
return self._get(f"{API_PREFIX}/summary/")
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Versions
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def list_versions(self, public_id: str) -> list[dict]:
|
||||
"""List all saved version snapshots for a TEI tool."""
|
||||
result = self._get(f"{API_PREFIX}/tools/{public_id}/versions/")
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
if "results" in result and isinstance(result["results"], list):
|
||||
return result["results"]
|
||||
if "versions" in result and isinstance(result["versions"], list):
|
||||
return result["versions"]
|
||||
return []
|
||||
|
||||
def save_version(self, public_id: str, note: str = "") -> dict:
|
||||
"""Snapshot current values + summary as a new version."""
|
||||
return self._post(
|
||||
f"{API_PREFIX}/tools/{public_id}/versions/",
|
||||
data={"note": note},
|
||||
)
|
||||
|
||||
def get_version(self, public_id: str, version_number: int) -> dict:
|
||||
"""Get a single version's full snapshot."""
|
||||
return self._get(
|
||||
f"{API_PREFIX}/tools/{public_id}/versions/{int(version_number)}/"
|
||||
)
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Export
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def export(self, public_id: str) -> dict:
|
||||
"""
|
||||
Return the LLM-ready export payload for the report pipeline.
|
||||
|
||||
The shape is determined by Athena and consumed by Peitho /
|
||||
html2docx; Palladium's ``core.export.report_data`` builds on this.
|
||||
"""
|
||||
return self._get(f"{API_PREFIX}/tools/{public_id}/export/")
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Convenience
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def get_benefits(self, public_id: str) -> list[dict]:
|
||||
"""Return only benefit-table values (table='benefits')."""
|
||||
return [v for v in self.get_values(public_id) if v.get("table") == "benefits"]
|
||||
|
||||
def get_costs(self, public_id: str) -> list[dict]:
|
||||
"""Return only cost-table values (table='costs')."""
|
||||
return [v for v in self.get_values(public_id) if v.get("table") == "costs"]
|
||||
|
||||
def get_tool_with_data(self, public_id: str) -> dict:
|
||||
"""
|
||||
Bundle a tool, its field definitions, current values, and summary.
|
||||
|
||||
Convenience for notebook initialisation. The summary is allowed to
|
||||
404 (returned as ``None``) when the tool has never been calculated.
|
||||
"""
|
||||
tool = self.get_tool(public_id)
|
||||
report_pid = tool.get("report")
|
||||
if isinstance(report_pid, dict):
|
||||
report_pid = report_pid.get("id") or report_pid.get("public_id")
|
||||
fields = self.list_fields(report_pid) if report_pid else []
|
||||
values = self.get_values(public_id)
|
||||
try:
|
||||
summary = self.get_summary(public_id)
|
||||
except AthenaAPIError as e:
|
||||
if e.status_code == 404:
|
||||
summary = None
|
||||
else:
|
||||
raise
|
||||
return {
|
||||
"tool": tool,
|
||||
"fields": fields,
|
||||
"values": values,
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# Display
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return f"TEIClient(base_url='{self.base_url}')"
|
||||
|
||||
def print_summary(self, public_id: str) -> None:
|
||||
"""Pretty-print a financial summary block for notebooks/REPL."""
|
||||
s = self.get_summary(public_id)
|
||||
|
||||
def _f(v: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(v) if v is not None else default
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
print("═" * 56)
|
||||
print(" TEI Financial Summary")
|
||||
print("═" * 56)
|
||||
print(f" Total Benefits (PV): ${_f(s.get('total_benefits_pv')):>16,.0f}")
|
||||
print(f" Total Costs (PV): ${_f(s.get('total_costs_pv')):>16,.0f}")
|
||||
print("─" * 56)
|
||||
print(f" Net Present Value: ${_f(s.get('npv')):>16,.0f}")
|
||||
print(f" ROI: {_f(s.get('roi')):>15,.0f}%")
|
||||
payback = s.get("payback_months")
|
||||
payback_str = f"{_f(payback):.1f} months" if payback is not None else "N/A"
|
||||
print(f" Payback: {payback_str:>17}")
|
||||
print("═" * 56)
|
||||
Reference in New Issue
Block a user