from __future__ import annotations

import json
import logging
import math
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from tre.utils.converters import as_bool, as_float, as_int
# Logger for this module; ModelProfileStore.reload() emits JSON-encoded
# events (json.dumps payloads) on it.
logger = logging.getLogger("tre.controller")
# ------------------------------------------------------------------
# Profile schema: the keys ModelProfileStore.validate() recognizes.
# Keys outside these sets are reported as warnings, not errors.
# ------------------------------------------------------------------
# Schema versions accepted by validate(); any other explicit "version"
# value is reported as a validation error.
SUPPORTED_SCHEMA_VERSIONS = {1}
# Keys recognized at the top level of the profile JSON document.
KNOWN_TOP_LEVEL_KEYS = {"version", "defaults", "models"}
# Keys recognized inside the "defaults" section and inside each
# "models.<name>" override section.
KNOWN_MODEL_SECTIONS = {
    "slo_class",
    "latency_slo_ms",
    "theta_reliability_target",
    "weights",
    "control",
    "safescale",
    "bootstrap",
}
# Recognized keys of the "latency_slo_ms" sub-section.
KNOWN_LATENCY_FIELDS = {"ttft_p95", "tpot_p95", "e2e_p95"}
# Recognized keys of the "weights" sub-section.
KNOWN_WEIGHT_FIELDS = {"w_p", "w_d"}
# Recognized keys of the "control" sub-section.
KNOWN_CONTROL_FIELDS = {
    "qmin",
    "qsat",
    "epsat",
    "Hsat",
    "delta_crit",
    "delta_high",
    "rho_rescue",
    "rho_fairness",
    "delta_release",
    "delta_transfer_safe",
    "max_zero_load_releases_per_loop",
    "receiver_thrashing_trs",
    "receiver_thrashing_eff",
    "receiver_congested_trs",
    "donor_waste_trs",
    "donor_waste_eff",
    "donor_surplus_trs",
    "fairness_rich_ratio",
    "fairness_poor_ratio",
    "shrink_safe_trs",
    "shrink_relaxed_trs",
    "shrink_relaxed_trigger_trs",
    "shrink_relaxed_max_pods",
    "proactive_release_min_trs",
}
# Recognized keys of the "safescale" sub-section.
KNOWN_SAFESCALE_FIELDS = {
    "cdec",
    "HQ",
    "ttft_p95_slo_ms",
    "tpot_p95_slo_ms",
    "default_probe_window_ms",
    "min_probe_ms",
    "max_probe_ms",
    "cw2_fallback_ms",
}
# Recognized keys of the "bootstrap" sub-section.
KNOWN_BOOTSTRAP_FIELDS = {
    "min_support",
    "min_confidence",
    "allow_aggressive_down",
}
# NOTE: despite the name, validation does not require these keys to be
# present. It only checks that any value that *is* supplied converts to a
# positive float; missing values fall back to built-in defaults when the
# profile is normalized.
REQUIRED_LATENCY_FIELDS = {"ttft_p95", "tpot_p95", "e2e_p95"}
@dataclass
class ProfileValidationResult:
    """Outcome of validating a model profile configuration.

    Messages accumulate in two lists. Recording any error flips ``valid``
    to False; warnings are informational and never affect ``valid``.
    """

    valid: bool = True  # False as soon as one error has been recorded
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def add_error(self, msg: str) -> None:
        """Mark the result invalid and remember *msg* as an error."""
        self.valid = False
        self.errors.append(msg)

    def add_warning(self, msg: str) -> None:
        """Remember *msg* as a non-fatal warning."""
        warnings = self.warnings
        warnings.append(msg)
def _check_unknown_keys(
section: Dict[str, Any],
known: Set[str],
path: str,
result: ProfileValidationResult,
) -> None:
unknown = set(section.keys()) - known
for key in sorted(unknown):
result.add_warning(f"{path}: unknown field '{key}'")
def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
merged = dict(base)
for key, value in override.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge(merged[key], value)
else:
merged[key] = value
return merged
@dataclass(frozen=True)
class ModelProfile:
    """Immutable, fully-resolved static SLO profile for one model.

    Built by ModelProfileStore from the built-in defaults, the file's
    "defaults" section and the model's own override section. The
    dataclass is frozen, but the dict-valued fields themselves are plain
    (mutable) dicts.
    """

    model_name: str
    slo_class: str
    # keys: ttft_p95, tpot_p95, e2e_p95
    latency_slo_ms: Dict[str, float]
    theta_reliability_target: float
    # keys: w_p, w_d
    weights: Dict[str, float]
    # controller tuning knobs; values normalized to float
    control: Dict[str, float]
    # safescale parameters; values normalized to float
    safescale: Dict[str, float]
    # mixed types: min_support (int), min_confidence (float),
    # allow_aggressive_down (bool)
    bootstrap: Dict[str, Any]
@dataclass(frozen=True)
class ModelRuntimeConfig:
    """Immutable per-model configuration handed to the controller at runtime.

    It joins the static profile (SLOs, weights, control, safescale and
    bootstrap) with optional calibration data read from a state store.
    The calibration-derived fields are None when no calibration value is
    available.
    """

    model_name: str
    profile_version: int  # version of the profile file that was loaded
    slo_class: str
    latency_slo_ms: Dict[str, float]
    theta_reliability_target: float
    # ---- calibration-derived fields (None / 0.0 when absent) ----
    theta_m: Optional[float]
    theta_confidence: float
    theta_support: Optional[int]
    theta_coverage: Optional[Dict[str, Any]]
    calibration_run_id: Optional[str]
    calibration_generated_at: Optional[str]
    calibration_updated_at: Optional[str]
    calibration_profile_version: Optional[int]
    # ---- static profile sections ----
    weights: Dict[str, float]
    control: Dict[str, float]
    safescale: Dict[str, float]
    bootstrap: Dict[str, Any]
class ModelProfileStore:
    """Load, validate and resolve per-model SLO profiles from a JSON file.

    Path resolution: the ``TRE_MODEL_SLO_PROFILE_PATH`` environment variable
    wins, then *config_path*, then ``configs/model_slo_profiles.json`` next
    to this module. A missing, unreadable or malformed file is never fatal:
    the store falls back to the built-in defaults below and logs a
    structured JSON event.
    """

    # Built-in fallbacks. They seed the effective "defaults" section and are
    # reused by _normalize_profile for missing values, so the two sources of
    # default values cannot drift apart.
    _DEFAULT_SLO_CLASS = "interactive"
    _DEFAULT_THETA_RELIABILITY_TARGET = 0.99
    _LATENCY_DEFAULTS: Dict[str, float] = {
        "ttft_p95": 1200.0,
        "tpot_p95": 100.0,
        "e2e_p95": 10000.0,
    }
    _WEIGHT_DEFAULTS: Dict[str, float] = {"w_p": 0.04, "w_d": 1.0}
    # Insertion order here is the key order of the normalized "control" dict.
    _CONTROL_DEFAULTS: Dict[str, float] = {
        "qmin": 1.0,
        "qsat": 4.0,
        "epsat": 0.05,
        "Hsat": 3,
        "delta_crit": 0.2,
        "delta_high": 0.25,
        "rho_rescue": 0.5,
        "rho_fairness": 0.25,
        "delta_release": 0.15,
        "delta_transfer_safe": 0.35,
        "max_zero_load_releases_per_loop": 2.0,
        "receiver_thrashing_trs": 1000.0,
        "receiver_thrashing_eff": 200.0,
        "receiver_congested_trs": 1500.0,
        "donor_waste_trs": 2500.0,
        "donor_waste_eff": 300.0,
        "donor_surplus_trs": 2000.0,
        "fairness_rich_ratio": 1.1,
        "fairness_poor_ratio": 0.9,
        "shrink_safe_trs": 1500.0,
        "shrink_relaxed_trs": 1000.0,
        "shrink_relaxed_trigger_trs": 1950.0,
        "shrink_relaxed_max_pods": 3.0,
        "proactive_release_min_trs": 2000.0,
    }
    # ttft_p95_slo_ms / tpot_p95_slo_ms / cw2_fallback_ms are intentionally
    # absent. Unless configured explicitly, _normalize_profile derives them
    # per model from the effective latency SLOs and max_probe_ms. Listing
    # them here would shadow that derivation, so a model-level latency
    # override would not reach safescale.
    _SAFESCALE_DEFAULTS: Dict[str, float] = {
        "cdec": 2.0,
        "HQ": 0.25,
        "default_probe_window_ms": 60000.0,
        "min_probe_ms": 15000.0,
        "max_probe_ms": 300000.0,
    }
    _BOOTSTRAP_DEFAULTS: Dict[str, Any] = {
        "min_support": 300,
        "min_confidence": 0.8,
        "allow_aggressive_down": False,
    }

    def __init__(self, config_path: Optional[str] = None) -> None:
        """Resolve the profile path and perform the initial load.

        Args:
            config_path: Explicit profile path; ignored when the
                ``TRE_MODEL_SLO_PROFILE_PATH`` environment variable is set.
        """
        default_path = (
            Path(__file__).resolve().parent / "configs" / "model_slo_profiles.json"
        )
        # NOTE: the environment variable deliberately takes precedence.
        self.config_path = Path(
            os.getenv("TRE_MODEL_SLO_PROFILE_PATH", config_path or str(default_path))
        )
        self._version = 1
        self._defaults: Dict[str, Any] = {}
        self._models: Dict[str, Any] = {}
        self._last_validation = ProfileValidationResult()
        self.reload()

    @property
    def version(self) -> int:
        """Schema version of the currently loaded profile file (default 1)."""
        return self._version

    def reload(self) -> None:
        """Re-read the profile file and replace the in-memory configuration.

        Never raises for bad input. A missing file, unreadable or invalid
        JSON, or a top-level value that is not a JSON object all fall back
        to the built-in defaults. Schema problems are logged and stored in
        :attr:`last_validation`; they do not block loading.
        """
        raw = self._read_raw()
        # Only a JSON object can carry configuration. Anything else (list,
        # null, number, string) is reported by validate() and otherwise
        # treated as an empty document.
        data: Dict[str, Any] = raw if isinstance(raw, dict) else {}
        defaults = data.get("defaults")
        models = data.get("models")
        self._version = as_int(data.get("version"), 1)
        self._defaults = self._normalize_defaults(
            defaults if isinstance(defaults, dict) else {}
        )
        self._models = models if isinstance(models, dict) else {}

        validation = self.validate(raw)
        if not validation.valid:
            for err in validation.errors:
                logger.error(
                    json.dumps(
                        {"event": "model_profile_validation_error", "error": err}
                    )
                )
        for warn in validation.warnings:
            logger.warning(
                json.dumps(
                    {"event": "model_profile_validation_warning", "warning": warn}
                )
            )
        self._last_validation = validation

    def _read_raw(self) -> Any:
        """Return the parsed JSON document, or ``{}`` if missing/unreadable."""
        if not self.config_path.exists():
            logger.warning(
                json.dumps(
                    {
                        "event": "model_profile_missing",
                        "path": str(self.config_path),
                        "fallback": "builtin_defaults",
                    }
                )
            )
            return {}
        try:
            return json.loads(self.config_path.read_text(encoding="utf-8"))
        except Exception as e:  # best-effort: a bad file must not crash startup
            logger.error(
                json.dumps(
                    {
                        "event": "model_profile_load_error",
                        "path": str(self.config_path),
                        "error": str(e),
                    }
                )
            )
            return {}

    @property
    def last_validation(self) -> ProfileValidationResult:
        """Validation result from the most recent :meth:`reload`."""
        return getattr(self, "_last_validation", ProfileValidationResult())

    def validate(self, raw: Any) -> ProfileValidationResult:
        """Validate the schema of a parsed profile document.

        Never raises on malformed input; every problem is recorded on the
        returned result. Unknown keys produce warnings. Wrong types,
        unsupported versions and bad latency values produce errors.

        Args:
            raw: The parsed JSON document (expected to be a dict).

        Returns:
            A ProfileValidationResult with accumulated errors and warnings.
        """
        result = ProfileValidationResult()
        if not isinstance(raw, dict):
            result.add_error(
                f"top-level: expected JSON object, got {type(raw).__name__}"
            )
            return result
        if not raw:
            return result

        version = raw.get("version")
        if version is not None:
            try:
                version_number = int(version)
            except (TypeError, ValueError, OverflowError):
                # e.g. "abc", a list/dict, or JSON Infinity/NaN
                result.add_error(f"version: cannot convert to int: {version!r}")
            else:
                if version_number not in SUPPORTED_SCHEMA_VERSIONS:
                    result.add_error(
                        f"unsupported profile schema version {version}; "
                        f"supported: {sorted(SUPPORTED_SCHEMA_VERSIONS)}"
                    )

        _check_unknown_keys(raw, KNOWN_TOP_LEVEL_KEYS, "top-level", result)

        defaults = raw.get("defaults")
        if isinstance(defaults, dict):
            self._validate_model_section(defaults, "defaults", result)
        elif defaults is not None:
            result.add_error(
                f"defaults: expected dict, got {type(defaults).__name__}"
            )

        models = raw.get("models")
        if isinstance(models, dict):
            for model_name, model_cfg in models.items():
                if not isinstance(model_cfg, dict):
                    result.add_error(
                        f"models.{model_name}: expected dict, got {type(model_cfg).__name__}"
                    )
                    continue
                self._validate_model_section(model_cfg, f"models.{model_name}", result)
        elif models is not None:
            result.add_error(f"models: expected dict, got {type(models).__name__}")
        return result

    def _validate_model_section(
        self,
        section: Dict[str, Any],
        path: str,
        result: ProfileValidationResult,
    ) -> None:
        """Validate a ``defaults`` or per-model override section.

        Reports unknown keys as warnings. Reports as errors any sub-section
        that is not a JSON object, and any latency SLO value that is not a
        positive finite number.
        """
        _check_unknown_keys(section, KNOWN_MODEL_SECTIONS, path, result)

        subsections = (
            ("latency_slo_ms", KNOWN_LATENCY_FIELDS),
            ("weights", KNOWN_WEIGHT_FIELDS),
            ("control", KNOWN_CONTROL_FIELDS),
            ("safescale", KNOWN_SAFESCALE_FIELDS),
            ("bootstrap", KNOWN_BOOTSTRAP_FIELDS),
        )
        for name, known in subsections:
            value = section.get(name)
            if value is None:
                continue
            if not isinstance(value, dict):
                # Previously ignored silently; normalization would then fall
                # back to defaults without telling anyone.
                result.add_error(
                    f"{path}.{name}: expected dict, got {type(value).__name__}"
                )
                continue
            _check_unknown_keys(value, known, f"{path}.{name}", result)

        latency = section.get("latency_slo_ms")
        if not isinstance(latency, dict):
            return
        # sorted(): set iteration order varies between runs.
        for field_name in sorted(REQUIRED_LATENCY_FIELDS):
            val = latency.get(field_name)
            if val is None:
                continue
            try:
                fval = float(val)
            except (TypeError, ValueError, OverflowError):
                result.add_error(
                    f"{path}.latency_slo_ms.{field_name}: cannot convert to float: {val!r}"
                )
                continue
            # json.loads accepts NaN/Infinity; NaN would also slip past "<= 0".
            if not math.isfinite(fval):
                result.add_error(
                    f"{path}.latency_slo_ms.{field_name}: must be finite, got {fval}"
                )
            elif fval <= 0:
                result.add_error(
                    f"{path}.latency_slo_ms.{field_name}: must be positive, got {fval}"
                )

    def list_models(self) -> List[str]:
        """Return the names of models with an explicit override, sorted."""
        return sorted(self._models)

    def get_model_profile(self, model_name: str) -> ModelProfile:
        """Resolve the profile for *model_name*.

        The model's override section (if any) is deep-merged over the
        effective defaults. Unknown models get the defaults unchanged.
        """
        model_override = self._models.get(model_name, {})
        if not isinstance(model_override, dict):
            model_override = {}
        merged = _deep_merge(self._defaults, model_override)
        return self._normalize_profile(model_name, merged)

    def _normalize_defaults(self, defaults: Dict[str, Any]) -> Dict[str, Any]:
        """Overlay the file's ``defaults`` section on the built-in defaults.

        Fresh copies of the class-level tables are used so that the result
        never aliases shared class state. A non-dict *defaults* is treated
        as empty.
        """
        builtin_defaults: Dict[str, Any] = {
            "slo_class": self._DEFAULT_SLO_CLASS,
            "latency_slo_ms": dict(self._LATENCY_DEFAULTS),
            "theta_reliability_target": self._DEFAULT_THETA_RELIABILITY_TARGET,
            "weights": dict(self._WEIGHT_DEFAULTS),
            "control": dict(self._CONTROL_DEFAULTS),
            "safescale": dict(self._SAFESCALE_DEFAULTS),
            "bootstrap": dict(self._BOOTSTRAP_DEFAULTS),
        }
        if not isinstance(defaults, dict):
            defaults = {}
        return _deep_merge(builtin_defaults, defaults)

    @staticmethod
    def _section(merged: Dict[str, Any], key: str) -> Dict[str, Any]:
        """Return ``merged[key]`` if it is a dict, otherwise an empty dict."""
        value = merged.get(key)
        return value if isinstance(value, dict) else {}

    def _normalize_profile(
        self, model_name: str, merged: Dict[str, Any]
    ) -> ModelProfile:
        """Coerce a merged profile dict into a typed ModelProfile.

        Each value is converted with the project converters. Missing values
        fall back to the built-in tables. ``safescale.ttft_p95_slo_ms`` and
        ``tpot_p95_slo_ms`` default to this model's latency SLOs, and
        ``cw2_fallback_ms`` defaults to its ``max_probe_ms``.
        """
        latency = self._section(merged, "latency_slo_ms")
        weights = self._section(merged, "weights")
        control = self._section(merged, "control")
        safescale = self._section(merged, "safescale")
        bootstrap = self._section(merged, "bootstrap")

        normalized_latency = {
            key: as_float(latency.get(key), fallback)
            for key, fallback in self._LATENCY_DEFAULTS.items()
        }
        normalized_weights = {
            key: as_float(weights.get(key), fallback)
            for key, fallback in self._WEIGHT_DEFAULTS.items()
        }
        normalized_control = {
            # Hsat goes through as_int (integer-valued) but is stored as a
            # float like every other control value.
            key: (
                float(as_int(control.get(key), fallback))
                if key == "Hsat"
                else as_float(control.get(key), fallback)
            )
            for key, fallback in self._CONTROL_DEFAULTS.items()
        }

        sd = self._SAFESCALE_DEFAULTS
        max_probe_ms = as_float(safescale.get("max_probe_ms"), sd["max_probe_ms"])
        normalized_safescale = {
            "cdec": as_float(safescale.get("cdec"), sd["cdec"]),
            "HQ": as_float(safescale.get("HQ"), sd["HQ"]),
            "ttft_p95_slo_ms": as_float(
                safescale.get("ttft_p95_slo_ms"), normalized_latency["ttft_p95"]
            ),
            "tpot_p95_slo_ms": as_float(
                safescale.get("tpot_p95_slo_ms"), normalized_latency["tpot_p95"]
            ),
            "default_probe_window_ms": as_float(
                safescale.get("default_probe_window_ms"), sd["default_probe_window_ms"]
            ),
            "min_probe_ms": as_float(safescale.get("min_probe_ms"), sd["min_probe_ms"]),
            "max_probe_ms": max_probe_ms,
            "cw2_fallback_ms": as_float(safescale.get("cw2_fallback_ms"), max_probe_ms),
        }

        bd = self._BOOTSTRAP_DEFAULTS
        normalized_bootstrap = {
            "min_support": as_int(bootstrap.get("min_support"), bd["min_support"]),
            "min_confidence": as_float(
                bootstrap.get("min_confidence"), bd["min_confidence"]
            ),
            "allow_aggressive_down": as_bool(
                bootstrap.get("allow_aggressive_down"), bd["allow_aggressive_down"]
            ),
        }

        # A JSON null must not turn into the literal string "None".
        slo_class = merged.get("slo_class")
        if slo_class is None:
            slo_class = self._DEFAULT_SLO_CLASS

        return ModelProfile(
            model_name=model_name,
            slo_class=str(slo_class),
            latency_slo_ms=normalized_latency,
            theta_reliability_target=as_float(
                merged.get("theta_reliability_target"),
                self._DEFAULT_THETA_RELIABILITY_TARGET,
            ),
            weights=normalized_weights,
            control=normalized_control,
            safescale=normalized_safescale,
            bootstrap=normalized_bootstrap,
        )
def build_runtime_config(
    model_name: str,
    profile_store: ModelProfileStore,
    state_store: Any,
) -> ModelRuntimeConfig:
    """Combine a model's static SLO profile with its stored calibration.

    Args:
        model_name: Model whose profile and calibration are looked up.
        profile_store: Source of the resolved static profile and of the
            profile schema version.
        state_store: Object exposing ``load_model_calibration(model_name)``.
            When it is falsy, no calibration is loaded.

    Returns:
        A ModelRuntimeConfig. Each calibration-derived field stays ``None``
        (``theta_confidence`` stays ``0.0``) unless the calibration record
        is a dict with a non-None value under the corresponding key.
    """
    profile = profile_store.get_model_profile(model_name)
    calibration = None
    if state_store:
        calibration = state_store.load_model_calibration(model_name)

    has_record = isinstance(calibration, dict)
    record: Dict[str, Any] = calibration if has_record else {}

    def optional_str(key: str) -> Optional[str]:
        # A key whose value is None is treated exactly like a missing key.
        value = record.get(key)
        return None if value is None else str(value)

    raw_theta = record.get("theta_m")
    raw_support = record.get("support")
    raw_coverage = record.get("coverage")
    raw_profile_version = record.get("profile_version")

    theta_m = None if raw_theta is None else as_float(raw_theta, 0.0)
    theta_confidence = (
        as_float(record.get("confidence"), 0.0) if has_record else 0.0
    )
    theta_support = None if raw_support is None else as_int(raw_support, 0)
    theta_coverage = raw_coverage if isinstance(raw_coverage, dict) else None
    run_id = optional_str("run_id")
    generated_at = optional_str("generated_at")
    updated_at = optional_str("updated_at")
    calibration_profile_version = (
        None
        if raw_profile_version is None
        else as_int(raw_profile_version, profile_store.version)
    )

    return ModelRuntimeConfig(
        model_name=model_name,
        profile_version=profile_store.version,
        slo_class=profile.slo_class,
        latency_slo_ms=profile.latency_slo_ms,
        theta_reliability_target=profile.theta_reliability_target,
        theta_m=theta_m,
        theta_confidence=theta_confidence,
        theta_support=theta_support,
        theta_coverage=theta_coverage,
        calibration_run_id=run_id,
        calibration_generated_at=generated_at,
        calibration_updated_at=updated_at,
        calibration_profile_version=calibration_profile_version,
        weights=profile.weights,
        control=profile.control,
        safescale=profile.safescale,
        bootstrap=profile.bootstrap,
    )