测试代码样式

from __future__ import annotations

import json
import logging
import math
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from tre.utils.converters import as_bool, as_float, as_int


logger = logging.getLogger("tre.controller")

# ------------------------------------------------------------------
# Profile schema used for validation.
#
# Each KNOWN_* set names the fields accepted in one part of the profile
# JSON document; keys outside these sets are reported as validation
# warnings rather than errors.
# ------------------------------------------------------------------

# Accepted values of the top-level "version" field.
SUPPORTED_SCHEMA_VERSIONS = {1}

# Keys accepted at the top level of the document.
KNOWN_TOP_LEVEL_KEYS = {"defaults", "models", "version"}

# Keys accepted inside "defaults" and inside every "models.<name>" entry.
KNOWN_MODEL_SECTIONS = {
    "bootstrap",
    "control",
    "latency_slo_ms",
    "safescale",
    "slo_class",
    "theta_reliability_target",
    "weights",
}

# Fields of the "latency_slo_ms" sub-section.
KNOWN_LATENCY_FIELDS = {"e2e_p95", "tpot_p95", "ttft_p95"}

# Fields of the "weights" sub-section.
KNOWN_WEIGHT_FIELDS = {"w_d", "w_p"}

# Fields of the "control" sub-section, grouped by name prefix.
KNOWN_CONTROL_FIELDS = {
    "qmin",
    "qsat",
    "epsat",
    "Hsat",
    # delta_* / rho_*
    "delta_crit",
    "delta_high",
    "delta_release",
    "delta_transfer_safe",
    "rho_rescue",
    "rho_fairness",
    # receiver_*
    "receiver_thrashing_trs",
    "receiver_thrashing_eff",
    "receiver_congested_trs",
    # donor_*
    "donor_waste_trs",
    "donor_waste_eff",
    "donor_surplus_trs",
    # fairness_*
    "fairness_rich_ratio",
    "fairness_poor_ratio",
    # shrink_*
    "shrink_safe_trs",
    "shrink_relaxed_trs",
    "shrink_relaxed_trigger_trs",
    "shrink_relaxed_max_pods",
    # release limits
    "max_zero_load_releases_per_loop",
    "proactive_release_min_trs",
}

# Fields of the "safescale" sub-section.
KNOWN_SAFESCALE_FIELDS = {
    "HQ",
    "cdec",
    "cw2_fallback_ms",
    "default_probe_window_ms",
    "max_probe_ms",
    "min_probe_ms",
    "tpot_p95_slo_ms",
    "ttft_p95_slo_ms",
}

# Fields of the "bootstrap" sub-section.
KNOWN_BOOTSTRAP_FIELDS = {
    "allow_aggressive_down",
    "min_confidence",
    "min_support",
}

# Latency fields whose values are range-checked (must be positive) when
# present. NOTE(review): despite the name, a missing field is not reported.
REQUIRED_LATENCY_FIELDS = {"e2e_p95", "tpot_p95", "ttft_p95"}


@dataclass
class ProfileValidationResult:
    """Outcome of validating a model-profile document.

    ``valid`` starts out True and becomes False as soon as any error is
    recorded. Warnings are collected alongside but never affect ``valid``.
    """

    valid: bool = True
    # Messages that make the document invalid.
    errors: List[str] = field(default_factory=list)
    # Informational messages (e.g. unknown fields).
    warnings: List[str] = field(default_factory=list)

    def add_error(self, msg: str) -> None:
        """Record *msg* as an error and mark the result as invalid."""
        self.valid = False
        self.errors.append(msg)

    def add_warning(self, msg: str) -> None:
        """Record *msg* as a warning; validity is left unchanged."""
        self.warnings += [msg]


def _check_unknown_keys(
    section: Dict[str, Any],
    known: Set[str],
    path: str,
    result: ProfileValidationResult,
) -> None:
    unknown = set(section.keys()) - known
    for key in sorted(unknown):
        result.add_warning(f"{path}: unknown field '{key}'")


def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


@dataclass(frozen=True)
class ModelProfile:
    """Resolved, normalized SLO profile for a single model.

    In this module, instances are built by ``ModelProfileStore`` by
    deep-merging the model's override over the normalized defaults and then
    converting each value. ``frozen=True`` only prevents rebinding the
    attributes; the dict values themselves remain mutable.
    """

    model_name: str
    # "interactive" when not configured.
    slo_class: str
    # Keys "ttft_p95", "tpot_p95", "e2e_p95"; milliseconds per the field name.
    latency_slo_ms: Dict[str, float]
    theta_reliability_target: float
    # Keys "w_p" and "w_d".
    weights: Dict[str, float]
    # Keys from KNOWN_CONTROL_FIELDS; all values are floats ("Hsat" is
    # parsed as an int and then converted to float).
    control: Dict[str, float]
    # Keys from KNOWN_SAFESCALE_FIELDS.
    safescale: Dict[str, float]
    # "min_support" (int), "min_confidence" (float),
    # "allow_aggressive_down" (bool).
    bootstrap: Dict[str, Any]


@dataclass(frozen=True)
class ModelRuntimeConfig:
    """Per-model configuration combining the static profile with calibration.

    In this module it is built by ``build_runtime_config``. The profile fields
    are copied from the resolved ``ModelProfile``. The ``theta_*`` and
    ``calibration_*`` fields come from the calibration dict returned by the
    state store, if any. ``frozen=True`` is shallow; the dict values stay
    mutable.
    """

    model_name: str
    # ModelProfileStore.version at build time.
    profile_version: int
    slo_class: str
    latency_slo_ms: Dict[str, float]
    theta_reliability_target: float
    # From calibration["theta_m"]; None when absent or no calibration.
    theta_m: Optional[float]
    # From calibration["confidence"]; 0.0 when there is no calibration dict.
    theta_confidence: float
    # From calibration["support"]; None when absent.
    theta_support: Optional[int]
    # calibration["coverage"] when it is a dict, otherwise None.
    theta_coverage: Optional[Dict[str, Any]]
    # Calibration metadata, stringified; None when absent.
    calibration_run_id: Optional[str]
    calibration_generated_at: Optional[str]
    calibration_updated_at: Optional[str]
    # From calibration["profile_version"]; None when absent.
    calibration_profile_version: Optional[int]
    weights: Dict[str, float]
    control: Dict[str, float]
    safescale: Dict[str, float]
    bootstrap: Dict[str, Any]


class ModelProfileStore:
    """Load, validate and resolve per-model SLO profiles from a JSON file.

    The document is expected to be a JSON object with optional top-level keys
    ``version``, ``defaults`` and ``models``. ``defaults`` is deep-merged over
    the built-in defaults, and each ``models.<name>`` object is deep-merged
    over that result to produce a :class:`ModelProfile`.

    Loading is best-effort: a missing, unreadable or malformed file falls back
    to the built-in defaults and is logged rather than raised.
    """

    def __init__(self, config_path: Optional[str] = None) -> None:
        """Resolve the profile path and load it immediately.

        Args:
            config_path: Path to the profile JSON. The environment variable
                ``TRE_MODEL_SLO_PROFILE_PATH``, when set to a non-empty value,
                takes precedence over this argument. When neither is given,
                ``configs/model_slo_profiles.json`` next to this module is used.
        """
        default_path = (
            Path(__file__).resolve().parent / "configs" / "model_slo_profiles.json"
        )
        # An empty env var is treated as unset: Path("") would otherwise point
        # at the current directory and fail to load.
        env_path = os.getenv("TRE_MODEL_SLO_PROFILE_PATH")
        self.config_path = Path(env_path or config_path or str(default_path))
        self._version = 1
        self._defaults: Dict[str, Any] = {}
        self._models: Dict[str, Any] = {}
        self._last_validation = ProfileValidationResult()
        self.reload()

    @property
    def version(self) -> int:
        """The ``version`` value from the last load, via ``as_int`` (default 1)."""
        return self._version

    def reload(self) -> None:
        """Re-read the profile file and replace the in-memory state.

        Load failures fall back to an empty document, i.e. built-in defaults
        only (see ``_read_raw``). Validation errors and warnings are logged and
        stored in :attr:`last_validation`. The loaded values are applied even
        when validation reports errors.
        """
        raw = self._read_raw()

        self._version = as_int(raw.get("version"), 1)
        defaults = raw.get("defaults")
        # A non-dict "defaults" cannot be merged (and would crash _deep_merge);
        # validate() reports it as an error.
        self._defaults = self._normalize_defaults(
            defaults if isinstance(defaults, dict) else {}
        )
        models = raw.get("models")
        self._models = models if isinstance(models, dict) else {}

        validation = self.validate(raw)
        for err in validation.errors:
            logger.error(
                json.dumps({"event": "model_profile_validation_error", "error": err})
            )
        for warn in validation.warnings:
            logger.warning(
                json.dumps(
                    {"event": "model_profile_validation_warning", "warning": warn}
                )
            )
        self._last_validation = validation

    def _read_raw(self) -> Dict[str, Any]:
        """Return the parsed profile document, or ``{}`` when it cannot be used.

        A missing file is logged as a warning. Read/parse failures, and a
        top-level JSON value that is not an object, are logged as errors.
        """
        if not self.config_path.exists():
            logger.warning(
                json.dumps(
                    {
                        "event": "model_profile_missing",
                        "path": str(self.config_path),
                        "fallback": "builtin_defaults",
                    }
                )
            )
            return {}
        try:
            raw = json.loads(self.config_path.read_text(encoding="utf-8"))
        except Exception as e:  # best-effort load: any failure -> builtin defaults
            logger.error(
                json.dumps(
                    {
                        "event": "model_profile_load_error",
                        "path": str(self.config_path),
                        "error": str(e),
                    }
                )
            )
            return {}
        if not isinstance(raw, dict):
            # e.g. a JSON array or scalar; calling .get() on it would crash reload().
            logger.error(
                json.dumps(
                    {
                        "event": "model_profile_load_error",
                        "path": str(self.config_path),
                        "error": (
                            "expected a JSON object at top level, "
                            f"got {type(raw).__name__}"
                        ),
                    }
                )
            )
            return {}
        return raw

    @property
    def last_validation(self) -> ProfileValidationResult:
        """Result of the validation performed by the most recent reload()."""
        # getattr fallback kept for instances created without __init__.
        return getattr(self, "_last_validation", ProfileValidationResult())

    def validate(self, raw: Dict[str, Any]) -> ProfileValidationResult:
        """Validate the schema of a raw profile document.

        Args:
            raw: Parsed profile JSON, normally a dict.

        Returns:
            A :class:`ProfileValidationResult`; an empty document is valid.
            Unknown keys are reported as warnings; type and value problems
            are reported as errors. Never raises for malformed input.
        """
        result = ProfileValidationResult()
        if not raw:
            return result
        if not isinstance(raw, dict):
            result.add_error(f"top-level: expected dict, got {type(raw).__name__}")
            return result

        # version check: must be int-convertible and a supported version
        version = raw.get("version")
        if version is not None:
            try:
                version_num = int(version)
            except (TypeError, ValueError, OverflowError):
                result.add_error(f"version: cannot convert to int: {version!r}")
            else:
                if version_num not in SUPPORTED_SCHEMA_VERSIONS:
                    result.add_error(
                        f"unsupported profile schema version {version}; "
                        f"supported: {sorted(SUPPORTED_SCHEMA_VERSIONS)}"
                    )

        # top-level unknown keys
        _check_unknown_keys(raw, KNOWN_TOP_LEVEL_KEYS, "top-level", result)

        # defaults section (reload() ignores it unless it is a dict)
        defaults = raw.get("defaults")
        if isinstance(defaults, dict):
            self._validate_model_section(defaults, "defaults", result)
        elif defaults is not None:
            result.add_error(f"defaults: expected dict, got {type(defaults).__name__}")

        # per-model sections (reload() ignores "models" unless it is a dict)
        models = raw.get("models")
        if isinstance(models, dict):
            for model_name, model_cfg in models.items():
                if not isinstance(model_cfg, dict):
                    result.add_error(
                        f"models.{model_name}: expected dict, got {type(model_cfg).__name__}"
                    )
                    continue
                self._validate_model_section(model_cfg, f"models.{model_name}", result)
        elif models is not None:
            result.add_error(f"models: expected dict, got {type(models).__name__}")

        return result

    def _validate_model_section(
        self,
        section: Dict[str, Any],
        path: str,
        result: ProfileValidationResult,
    ) -> None:
        """Validate a ``defaults`` or ``models.<name>`` section.

        Unknown keys produce warnings. A sub-section that is present but not a
        dict (including an explicit null) produces an error: _deep_merge lets it
        replace the whole merged sub-section, so every field in it falls back
        to the hard-coded values in _normalize_profile. Latency SLO values, when
        present, must convert to a positive, finite float.
        """
        _check_unknown_keys(section, KNOWN_MODEL_SECTIONS, path, result)

        subsections = (
            ("latency_slo_ms", KNOWN_LATENCY_FIELDS),
            ("weights", KNOWN_WEIGHT_FIELDS),
            ("control", KNOWN_CONTROL_FIELDS),
            ("safescale", KNOWN_SAFESCALE_FIELDS),
            ("bootstrap", KNOWN_BOOTSTRAP_FIELDS),
        )
        for name, known_fields in subsections:
            if name not in section:
                continue
            sub = section[name]
            if not isinstance(sub, dict):
                result.add_error(
                    f"{path}.{name}: expected dict, got {type(sub).__name__}"
                )
                continue
            _check_unknown_keys(sub, known_fields, f"{path}.{name}", result)

        latency = section.get("latency_slo_ms")
        if isinstance(latency, dict):
            # sorted(): set iteration order varies between runs; keep errors stable.
            for field_name in sorted(REQUIRED_LATENCY_FIELDS):
                val = latency.get(field_name)
                if val is None:
                    continue
                try:
                    fval = float(val)
                except (TypeError, ValueError):
                    result.add_error(
                        f"{path}.latency_slo_ms.{field_name}: cannot convert to float: {val!r}"
                    )
                    continue
                if fval <= 0:
                    result.add_error(
                        f"{path}.latency_slo_ms.{field_name}: must be positive, got {fval}"
                    )
                elif not math.isfinite(fval):
                    # json.loads accepts NaN/Infinity; NaN also slips past `<= 0`.
                    result.add_error(
                        f"{path}.latency_slo_ms.{field_name}: must be finite, got {fval}"
                    )

    def list_models(self) -> List[str]:
        """Return the names that have an entry under ``models``, sorted."""
        return sorted(self._models)

    def get_model_profile(self, model_name: str) -> ModelProfile:
        """Resolve the effective profile for *model_name*.

        The model's override (ignored unless it is a dict) is deep-merged over
        the normalized defaults; unknown models get the defaults alone.
        """
        model_override = self._models.get(model_name, {})
        if not isinstance(model_override, dict):
            model_override = {}

        merged = _deep_merge(self._defaults, model_override)
        return self._normalize_profile(model_name, merged)

    def _normalize_defaults(self, defaults: Dict[str, Any]) -> Dict[str, Any]:
        """Deep-merge the file's ``defaults`` section over the built-in defaults.

        NOTE: the literal fallbacks in _normalize_profile duplicate these values;
        keep the two in sync when changing a default.
        """
        builtin_defaults = {
            "slo_class": "interactive",
            "latency_slo_ms": {
                "ttft_p95": 1200.0,
                "tpot_p95": 100.0,
                "e2e_p95": 10000.0,
            },
            "theta_reliability_target": 0.99,
            "weights": {"w_p": 0.04, "w_d": 1.0},
            "control": {
                "qmin": 1.0,
                "qsat": 4.0,
                "epsat": 0.05,
                "Hsat": 3,
                "delta_crit": 0.2,
                "delta_high": 0.25,
                "rho_rescue": 0.5,
                "rho_fairness": 0.25,
                "delta_release": 0.15,
                "delta_transfer_safe": 0.35,
                "max_zero_load_releases_per_loop": 2.0,
                "receiver_thrashing_trs": 1000.0,
                "receiver_thrashing_eff": 200.0,
                "receiver_congested_trs": 1500.0,
                "donor_waste_trs": 2500.0,
                "donor_waste_eff": 300.0,
                "donor_surplus_trs": 2000.0,
                "fairness_rich_ratio": 1.1,
                "fairness_poor_ratio": 0.9,
                "shrink_safe_trs": 1500.0,
                "shrink_relaxed_trs": 1000.0,
                "shrink_relaxed_trigger_trs": 1950.0,
                "shrink_relaxed_max_pods": 3.0,
                "proactive_release_min_trs": 2000.0,
            },
            "safescale": {
                "cdec": 2.0,
                "HQ": 0.25,
                "ttft_p95_slo_ms": 1200.0,
                "tpot_p95_slo_ms": 100.0,
                "default_probe_window_ms": 60000.0,
                "min_probe_ms": 15000.0,
                "max_probe_ms": 300000.0,
                "cw2_fallback_ms": 300000.0,
            },
            "bootstrap": {
                "min_support": 300,
                "min_confidence": 0.8,
                "allow_aggressive_down": False,
            },
        }
        return _deep_merge(builtin_defaults, defaults)

    @staticmethod
    def _dict_section(merged: Dict[str, Any], name: str) -> Dict[str, Any]:
        """Return ``merged[name]`` when it is a dict, otherwise an empty dict."""
        value = merged.get(name)
        return value if isinstance(value, dict) else {}

    def _normalize_profile(
        self, model_name: str, merged: Dict[str, Any]
    ) -> ModelProfile:
        """Convert a merged profile dict into a typed :class:`ModelProfile`.

        Each value goes through ``as_float``/``as_int``/``as_bool`` with a
        literal fallback (presumably used when the value is missing or not
        convertible; this depends on tre.utils.converters). The safescale
        TTFT/TPOT SLOs fall back to the normalized latency SLOs, and
        ``cw2_fallback_ms`` falls back to ``max_probe_ms``.
        """
        latency = self._dict_section(merged, "latency_slo_ms")
        weights = self._dict_section(merged, "weights")
        control = self._dict_section(merged, "control")
        safescale = self._dict_section(merged, "safescale")
        bootstrap = self._dict_section(merged, "bootstrap")

        normalized_latency = {
            "ttft_p95": as_float(latency.get("ttft_p95"), 1200.0),
            "tpot_p95": as_float(latency.get("tpot_p95"), 100.0),
            "e2e_p95": as_float(latency.get("e2e_p95"), 10000.0),
        }
        normalized_weights = {
            "w_p": as_float(weights.get("w_p"), 0.04),
            "w_d": as_float(weights.get("w_d"), 1.0),
        }
        normalized_control = {
            "qmin": as_float(control.get("qmin"), 1.0),
            "qsat": as_float(control.get("qsat"), 4.0),
            "epsat": as_float(control.get("epsat"), 0.05),
            "Hsat": float(as_int(control.get("Hsat"), 3)),
            "delta_crit": as_float(control.get("delta_crit"), 0.2),
            "delta_high": as_float(control.get("delta_high"), 0.25),
            "rho_rescue": as_float(control.get("rho_rescue"), 0.5),
            "rho_fairness": as_float(control.get("rho_fairness"), 0.25),
            "delta_release": as_float(control.get("delta_release"), 0.15),
            "delta_transfer_safe": as_float(control.get("delta_transfer_safe"), 0.35),
            "max_zero_load_releases_per_loop": as_float(
                control.get("max_zero_load_releases_per_loop"), 2.0
            ),
            "receiver_thrashing_trs": as_float(
                control.get("receiver_thrashing_trs"), 1000.0
            ),
            "receiver_thrashing_eff": as_float(
                control.get("receiver_thrashing_eff"), 200.0
            ),
            "receiver_congested_trs": as_float(
                control.get("receiver_congested_trs"), 1500.0
            ),
            "donor_waste_trs": as_float(control.get("donor_waste_trs"), 2500.0),
            "donor_waste_eff": as_float(control.get("donor_waste_eff"), 300.0),
            "donor_surplus_trs": as_float(control.get("donor_surplus_trs"), 2000.0),
            "fairness_rich_ratio": as_float(control.get("fairness_rich_ratio"), 1.1),
            "fairness_poor_ratio": as_float(control.get("fairness_poor_ratio"), 0.9),
            "shrink_safe_trs": as_float(control.get("shrink_safe_trs"), 1500.0),
            "shrink_relaxed_trs": as_float(control.get("shrink_relaxed_trs"), 1000.0),
            "shrink_relaxed_trigger_trs": as_float(
                control.get("shrink_relaxed_trigger_trs"), 1950.0
            ),
            "shrink_relaxed_max_pods": as_float(
                control.get("shrink_relaxed_max_pods"), 3.0
            ),
            "proactive_release_min_trs": as_float(
                control.get("proactive_release_min_trs"), 2000.0
            ),
        }
        normalized_safescale = {
            "cdec": as_float(safescale.get("cdec"), 2.0),
            "HQ": as_float(safescale.get("HQ"), 0.25),
            "ttft_p95_slo_ms": as_float(
                safescale.get("ttft_p95_slo_ms"), normalized_latency["ttft_p95"]
            ),
            "tpot_p95_slo_ms": as_float(
                safescale.get("tpot_p95_slo_ms"), normalized_latency["tpot_p95"]
            ),
            "default_probe_window_ms": as_float(
                safescale.get("default_probe_window_ms"), 60000.0
            ),
            "min_probe_ms": as_float(safescale.get("min_probe_ms"), 15000.0),
            "max_probe_ms": as_float(safescale.get("max_probe_ms"), 300000.0),
            "cw2_fallback_ms": as_float(
                safescale.get("cw2_fallback_ms"),
                as_float(safescale.get("max_probe_ms"), 300000.0),
            ),
        }
        normalized_bootstrap = {
            "min_support": as_int(bootstrap.get("min_support"), 300),
            "min_confidence": as_float(bootstrap.get("min_confidence"), 0.8),
            "allow_aggressive_down": as_bool(
                bootstrap.get("allow_aggressive_down"), False
            ),
        }

        # An explicit JSON null would otherwise be stringified to "None".
        slo_class = merged.get("slo_class")

        return ModelProfile(
            model_name=model_name,
            slo_class="interactive" if slo_class is None else str(slo_class),
            latency_slo_ms=normalized_latency,
            theta_reliability_target=as_float(
                merged.get("theta_reliability_target"), 0.99
            ),
            weights=normalized_weights,
            control=normalized_control,
            safescale=normalized_safescale,
            bootstrap=normalized_bootstrap,
        )


def _optional_str(value: Any) -> Optional[str]:
    """Return ``str(value)``, or ``None`` when *value* is ``None``."""
    return None if value is None else str(value)


def build_runtime_config(
    model_name: str,
    profile_store: ModelProfileStore,
    state_store: Any,
) -> ModelRuntimeConfig:
    """Build the runtime config for *model_name* from its profile and calibration.

    Args:
        model_name: Model to resolve.
        profile_store: Supplies the merged ``ModelProfile`` and the profile
            schema version.
        state_store: Object providing ``load_model_calibration(model_name)``,
            or ``None`` to skip the calibration lookup.

    Returns:
        A ``ModelRuntimeConfig`` whose profile-derived fields come from
        ``profile_store``. Calibration fields are read from the calibration
        value only when it is a dict. Keys that are absent or ``None`` leave
        the corresponding field as ``None``, except ``theta_confidence``,
        which defaults to ``0.0``.
    """
    profile = profile_store.get_model_profile(model_name)
    # Compare with None rather than relying on truthiness: a store object that
    # happens to be falsy (e.g. defines __len__ returning 0) must still be used.
    calibration = (
        state_store.load_model_calibration(model_name)
        if state_store is not None
        else None
    )

    theta_m: Optional[float] = None
    theta_confidence = 0.0
    theta_support: Optional[int] = None
    theta_coverage: Optional[Dict[str, Any]] = None
    calibration_run_id: Optional[str] = None
    calibration_generated_at: Optional[str] = None
    calibration_updated_at: Optional[str] = None
    calibration_profile_version: Optional[int] = None
    if isinstance(calibration, dict):
        raw_theta = calibration.get("theta_m")
        if raw_theta is not None:
            theta_m = as_float(raw_theta, 0.0)
        theta_confidence = as_float(calibration.get("confidence"), 0.0)
        raw_support = calibration.get("support")
        if raw_support is not None:
            theta_support = as_int(raw_support, 0)
        coverage = calibration.get("coverage")
        if isinstance(coverage, dict):
            theta_coverage = coverage
        calibration_run_id = _optional_str(calibration.get("run_id"))
        calibration_generated_at = _optional_str(calibration.get("generated_at"))
        calibration_updated_at = _optional_str(calibration.get("updated_at"))
        raw_profile_version = calibration.get("profile_version")
        if raw_profile_version is not None:
            calibration_profile_version = as_int(
                raw_profile_version, profile_store.version
            )

    return ModelRuntimeConfig(
        model_name=model_name,
        profile_version=profile_store.version,
        slo_class=profile.slo_class,
        latency_slo_ms=profile.latency_slo_ms,
        theta_reliability_target=profile.theta_reliability_target,
        theta_m=theta_m,
        theta_confidence=theta_confidence,
        theta_support=theta_support,
        theta_coverage=theta_coverage,
        calibration_run_id=calibration_run_id,
        calibration_generated_at=calibration_generated_at,
        calibration_updated_at=calibration_updated_at,
        calibration_profile_version=calibration_profile_version,
        weights=profile.weights,
        control=profile.control,
        safescale=profile.safescale,
        bootstrap=profile.bootstrap,
    )

评论