Add Krea2 fixed-seed eval log
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parent
|
||||
DEFAULT_EVAL_LOG_PATH = ROOT / "docs" / "krea2-eval-log.json"
|
||||
|
||||
|
||||
def _path_key(path: str | Path | None = None) -> str:
|
||||
return str(Path(path or DEFAULT_EVAL_LOG_PATH).resolve())
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _load_raw_eval_log(path_key: str) -> dict[str, Any]:
|
||||
with Path(path_key).open("r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
_load_raw_eval_log.cache_clear()
|
||||
|
||||
|
||||
def load_eval_log(path: str | Path | None = None) -> dict[str, Any]:
|
||||
return copy.deepcopy(_load_raw_eval_log(_path_key(path)))
|
||||
|
||||
|
||||
def entries(
|
||||
*,
|
||||
variant_key: str | None = None,
|
||||
result: str | None = None,
|
||||
decision: str | None = None,
|
||||
path: str | Path | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
log = load_eval_log(path)
|
||||
rows = log.get("entries") or []
|
||||
if not isinstance(rows, list):
|
||||
return []
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
if variant_key is not None and row.get("variant_key") != variant_key:
|
||||
continue
|
||||
if result is not None and row.get("result") != result:
|
||||
continue
|
||||
if decision is not None and row.get("decision") != decision:
|
||||
continue
|
||||
filtered.append(row)
|
||||
return filtered
|
||||
|
||||
|
||||
def entries_for_variant(
|
||||
variant_key: str,
|
||||
*,
|
||||
result: str | None = None,
|
||||
decision: str | None = None,
|
||||
path: str | Path | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return entries(variant_key=variant_key, result=result, decision=decision, path=path)
|
||||
|
||||
|
||||
def variant_keys(
|
||||
*,
|
||||
result: str | None = None,
|
||||
decision: str | None = None,
|
||||
path: str | Path | None = None,
|
||||
) -> list[str]:
|
||||
keys: list[str] = []
|
||||
for row in entries(result=result, decision=decision, path=path):
|
||||
key = row.get("variant_key")
|
||||
if key and key not in keys:
|
||||
keys.append(str(key))
|
||||
return keys
|
||||
|
||||
Reference in New Issue
Block a user