"""Load, validate, save, and query the Krea2 evaluation log.

The log is a JSON object (by default ``docs/krea2-eval-log.json`` next to this
module) whose ``"entries"`` key holds a list of evaluation rows.
"""
from __future__ import annotations

import copy
import json
from functools import lru_cache
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parent
DEFAULT_EVAL_LOG_PATH = ROOT / "docs" / "krea2-eval-log.json"

# Closed vocabularies accepted for an entry's ``result`` and ``decision`` fields.
VALID_RESULTS = {"accepted", "rejected", "inconclusive"}
VALID_DECISIONS = {
    "generator_patch",
    "prompt_guide_rule",
    "prompt_only_retry",
    "needs_more_tests",
}


def _path_key(path: str | Path | None = None) -> str:
    """Return the resolved absolute path string used as the load-cache key.

    A falsy *path* (``None`` or ``""``) selects ``DEFAULT_EVAL_LOG_PATH``.
    """
    return str(Path(path or DEFAULT_EVAL_LOG_PATH).resolve())


@lru_cache(maxsize=8)
def _load_raw_eval_log(path_key: str) -> dict[str, Any]:
    """Read and parse the JSON file at *path_key*, caching per resolved path.

    Returns an empty dict when the top-level JSON value is not an object.
    The returned object is shared by the cache, so callers must not mutate it;
    public readers go through ``load_eval_log``, which deep-copies it.
    """
    with Path(path_key).open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    return data if isinstance(data, dict) else {}


def clear_cache() -> None:
    """Drop every cached parse so the next load re-reads from disk."""
    _load_raw_eval_log.cache_clear()


def load_eval_log(path: str | Path | None = None) -> dict[str, Any]:
    """Return a private deep copy of the eval log at *path* (default log if omitted).

    Raises:
        OSError / json.JSONDecodeError: propagated from opening or parsing the
            file (e.g. ``FileNotFoundError`` when it does not exist).
    """
    return copy.deepcopy(_load_raw_eval_log(_path_key(path)))


def _text(value: Any) -> str:
    """Return *value* when it is a string, otherwise ``""``."""
    return value if isinstance(value, str) else ""


def _require_text(errors: list[str], entry: dict[str, Any], key: str, min_len: int) -> None:
    """Append an error unless ``entry[key]`` is a string of at least *min_len* chars.

    Surrounding whitespace is stripped before measuring; non-strings count as empty.
    """
    value = _text(entry.get(key)).strip()
    if len(value) < min_len:
        errors.append(f"{key} must be at least {min_len} characters")


def _is_one_of(value: Any, allowed: set[str]) -> bool:
    """Return True when *value* is a string contained in *allowed*.

    The ``isinstance`` guard keeps unhashable JSON values (lists, objects) from
    raising ``TypeError`` on the set-membership test.
    """
    return isinstance(value, str) and value in allowed


def validate_entry(
    entry: dict[str, Any],
    *,
    existing_entries: list[dict[str, Any]] | None = None,
    catalog_keys: set[str] | None = None,
) -> list[str]:
    """Return human-readable problems with *entry*; an empty list means valid.

    Args:
        entry: Candidate log row. Anything other than a dict is rejected outright.
        existing_entries: Rows already in the log, used to reject a duplicate
            ``id``. Non-dict rows are ignored.
        catalog_keys: Known variant keys. When ``None`` the ``variant_key`` is
            only length-checked, not looked up.

    Malformed field values never raise; each problem is reported as a string.
    """
    if not isinstance(entry, dict):
        return ["entry must be an object"]
    errors: list[str] = []

    _require_text(errors, entry, "id", 6)
    entry_id = _text(entry.get("id")).strip()
    if entry_id and existing_entries:
        existing_ids = {
            _text(row.get("id")).strip() for row in existing_entries if isinstance(row, dict)
        }
        if entry_id in existing_ids:
            errors.append(f"duplicate id {entry_id!r}")

    _require_text(errors, entry, "variant_key", 8)
    variant_key = _text(entry.get("variant_key")).strip()
    if variant_key and catalog_keys is not None and variant_key not in catalog_keys:
        errors.append(f"unknown variant {variant_key!r}")

    seed = entry.get("seed")
    # bool is a subclass of int, so True/False must be rejected explicitly.
    if not isinstance(seed, int) or isinstance(seed, bool):
        errors.append("seed must be an integer")

    if not _is_one_of(entry.get("result"), VALID_RESULTS):
        errors.append(f"result must be one of {sorted(VALID_RESULTS)}")
    if not _is_one_of(entry.get("decision"), VALID_DECISIONS):
        errors.append(f"decision must be one of {sorted(VALID_DECISIONS)}")

    _require_text(errors, entry, "baseline_prompt_summary", 20)
    _require_text(errors, entry, "candidate_prompt_summary", 20)
    _require_text(errors, entry, "observation", 30)

    for image_key in ("baseline_image", "candidate_image"):
        image_path = _text(entry.get(image_key)).strip()
        if not image_path:
            continue  # image references are optional
        path = Path(image_path)
        if not path.is_absolute():
            errors.append(f"{image_key} must be absolute when present")
        if path.suffix.lower() != ".png":
            errors.append(f"{image_key} must reference a PNG artifact")
    return errors


def save_eval_log(log: dict[str, Any], *, path: str | Path | None = None) -> None:
    """Write *log* as indented ASCII JSON to *path* (default log if omitted).

    The JSON is written to a sibling ``<name>.tmp`` file first and then moved
    over the target with ``Path.replace``. An interrupted write therefore
    cannot leave a truncated, unparseable log in place of the old one. On
    POSIX the final rename is atomic. The load cache is cleared afterwards so
    later reads see the new content.
    """
    target = Path(path or DEFAULT_EVAL_LOG_PATH)
    payload = json.dumps(log, ensure_ascii=True, indent=2) + "\n"
    staging = target.with_name(target.name + ".tmp")
    try:
        staging.write_text(payload, encoding="utf-8")
        staging.replace(target)
    except BaseException:
        # Do not leave a half-written staging file behind; re-raise the cause.
        staging.unlink(missing_ok=True)
        raise
    clear_cache()


def append_entry(
    entry: dict[str, Any],
    *,
    path: str | Path | None = None,
    catalog_path: str | Path | None = None,
    dry_run: bool = False,
) -> dict[str, Any]:
    """Validate *entry* and append a copy of it to the eval log.

    The entry is checked against the existing rows (duplicate ``id``) and the
    keys returned by ``krea2_pose_variant_catalog.variant_keys``.

    Args:
        entry: The new row. It is deep-copied and never mutated.
        path: Eval log file (default log if omitted).
        catalog_path: Forwarded to ``krea2_pose_variant_catalog.variant_keys``.
        dry_run: When True, validate and build the updated log but do not write it.

    Returns:
        A deep copy of the updated log, including the new row.

    Raises:
        ValueError: If validation fails; all problems are joined with ``"; "``.
    """
    try:
        from . import krea2_pose_variant_catalog
    except ImportError:  # Allows local smoke tests from the repository root.
        import krea2_pose_variant_catalog

    log = load_eval_log(path)
    rows = log.get("entries")
    if not isinstance(rows, list):
        # Missing or malformed "entries": start a fresh list inside the log.
        rows = []
        log["entries"] = rows

    new_entry = copy.deepcopy(entry)
    errors = validate_entry(
        new_entry,
        existing_entries=[row for row in rows if isinstance(row, dict)],
        catalog_keys=set(krea2_pose_variant_catalog.variant_keys(path=catalog_path)),
    )
    if errors:
        raise ValueError("; ".join(errors))

    rows.append(new_entry)
    if not dry_run:
        save_eval_log(log, path=path)
    return copy.deepcopy(log)


def entries(
    *,
    variant_key: str | None = None,
    result: str | None = None,
    decision: str | None = None,
    path: str | Path | None = None,
) -> list[dict[str, Any]]:
    """Return log rows whose fields equal every filter that is not ``None``.

    Non-dict rows are skipped, and a non-list ``"entries"`` value yields ``[]``.
    Rows come from a fresh deep copy of the log, so callers may mutate them.
    """
    rows = load_eval_log(path).get("entries") or []
    if not isinstance(rows, list):
        return []
    criteria = {
        field: wanted
        for field, wanted in (
            ("variant_key", variant_key),
            ("result", result),
            ("decision", decision),
        )
        if wanted is not None
    }
    return [
        row
        for row in rows
        if isinstance(row, dict)
        and all(row.get(field) == wanted for field, wanted in criteria.items())
    ]


def entries_for_variant(
    variant_key: str,
    *,
    result: str | None = None,
    decision: str | None = None,
    path: str | Path | None = None,
) -> list[dict[str, Any]]:
    """Return rows for *variant_key*, optionally narrowed by *result* / *decision*."""
    return entries(variant_key=variant_key, result=result, decision=decision, path=path)


def variant_keys(
    *,
    result: str | None = None,
    decision: str | None = None,
    path: str | Path | None = None,
) -> list[str]:
    """Return distinct variant keys of matching rows, as strings, in first-seen order.

    Rows with a missing or falsy ``variant_key`` are ignored. Deduplication
    uses the string form, so ``5`` and ``"5"`` count as one key.
    """
    seen: dict[str, None] = {}  # insertion-ordered set
    for row in entries(result=result, decision=decision, path=path):
        key = row.get("variant_key")
        if key:
            seen.setdefault(str(key), None)
    return list(seen)