"""Load and query a JSON catalog of pose variants.

By default the catalog is ``categories/krea2_pov_pose_variants.json`` next to
this module. Parsed files are memoised per resolved path, and every public
accessor returns copies, so callers may mutate results freely.
"""

from __future__ import annotations

import copy
import json
from functools import lru_cache
from pathlib import Path
from typing import Any

ROOT = Path(__file__).resolve().parent
DEFAULT_CATALOG_PATH = ROOT / "categories" / "krea2_pov_pose_variants.json"


def _path_key(path: str | Path | None = None) -> str:
    """Return the resolved absolute catalog path used as the cache key.

    A falsy ``path`` (``None`` or ``""``) selects ``DEFAULT_CATALOG_PATH``.
    """
    return str(Path(path or DEFAULT_CATALOG_PATH).resolve())


@lru_cache(maxsize=8)
def _load_raw_catalog(path_key: str) -> dict[str, Any]:
    """Parse the catalog file at ``path_key``.

    The result is cached and shared, so it must never be mutated; public
    callers go through ``load_catalog``, which deep-copies it. A JSON document
    whose top level is not an object yields ``{}``.

    Raises:
        OSError: the file cannot be opened.
        json.JSONDecodeError: the file is not valid JSON.
    """
    with Path(path_key).open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    return data if isinstance(data, dict) else {}


def clear_cache() -> None:
    """Forget all memoised catalogs so the next access re-reads from disk."""
    _load_raw_catalog.cache_clear()


def load_catalog(path: str | Path | None = None) -> dict[str, Any]:
    """Return a deep copy of the catalog at ``path`` (default catalog if falsy)."""
    return copy.deepcopy(_load_raw_catalog(_path_key(path)))


def variants(
    *,
    status: str | None = None,
    family: str | None = None,
    action_family: str | None = None,
    path: str | Path | None = None,
) -> list[dict[str, Any]]:
    """Return the catalog's variant dicts, optionally filtered by exact field match.

    Each filter that is not ``None`` must equal the variant's value for that
    key. Non-dict rows are skipped. If ``variants`` is missing or is not a
    list, the result is ``[]``.
    """
    rows = load_catalog(path).get("variants") or []
    if not isinstance(rows, list):
        return []
    # Only filters the caller actually supplied take part in matching.
    wanted = {
        field: value
        for field, value in (
            ("status", status),
            ("family", family),
            ("action_family", action_family),
        )
        if value is not None
    }
    return [
        row
        for row in rows
        if isinstance(row, dict)
        and all(row.get(field) == value for field, value in wanted.items())
    ]


def variant_keys(
    *,
    status: str | None = None,
    family: str | None = None,
    action_family: str | None = None,
    path: str | Path | None = None,
) -> list[str]:
    """Return the ``key`` of every matching variant that has a truthy key.

    The filters behave exactly as in ``variants``.
    """
    return [
        str(row.get("key"))
        for row in variants(status=status, family=family, action_family=action_family, path=path)
        if row.get("key")
    ]


def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]:
    """Return the first variant whose ``key`` equals ``key``, or ``{}`` if none does."""
    for row in variants(path=path):
        if row.get("key") == key:
            return row
    return {}


def _cue_list(value: Any) -> list[str]:
    """Normalise one cue container into a list of non-blank cue strings.

    ``value`` may be a list of cues, or a dict holding that list under
    ``prompt_cues`` (preferred) or ``cues``. Anything else yields ``[]``.
    ``None`` entries are dropped, because ``str(None)`` would otherwise put the
    literal text ``"None"`` into prompts. Whitespace-only entries are also
    dropped. Kept entries are stringified but not stripped.
    """
    if isinstance(value, dict):
        value = value.get("prompt_cues") or value.get("cues")
    if not isinstance(value, list):
        return []
    return [str(cue) for cue in value if cue is not None and str(cue).strip()]


def prompt_cue_sets(
    variant_or_key: dict[str, Any] | str,
    *,
    path: str | Path | None = None,
) -> list[list[str]]:
    """Return the prompt cue sets for a variant.

    Args:
        variant_or_key: a variant dict, or a key to look up with ``get_variant``.
        path: catalog to look the key up in; ignored when a dict is passed.

    Returns:
        The baseline ``prompt_cues`` set (if non-empty), followed by each
        non-empty set from ``prompt_variant_cues``. If all of these are empty,
        returns ``[[canonical_geometry]]`` when that field is non-blank.
        Otherwise returns ``[]``, which is also the result for an unknown key
        or an empty dict.
    """
    if isinstance(variant_or_key, str):
        variant = get_variant(variant_or_key, path=path)
    else:
        variant = dict(variant_or_key or {})
    if not variant:
        return []

    cue_sets: list[list[str]] = []
    baseline = _cue_list(variant.get("prompt_cues"))
    if baseline:
        cue_sets.append(baseline)

    extra = variant.get("prompt_variant_cues")
    # Guard so a malformed scalar value does not raise TypeError on iteration.
    if isinstance(extra, (list, tuple)):
        for cue_set in extra:
            cues = _cue_list(cue_set)
            if cues:
                cue_sets.append(cues)

    if not cue_sets:
        fallback = str(variant.get("canonical_geometry") or "").strip()
        if fallback:
            cue_sets.append([fallback])
    return cue_sets


def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
    """Return variant ``key``'s ``reference_images`` joined onto the catalog's ``atlas_root``.

    Several kinds of entry are skipped:

    - ``None`` and blank entries.
    - Entries containing a ``..`` component.
    - Anchored entries (absolute, or carrying a drive/root). For these,
      ``atlas_root / entry`` would discard ``atlas_root`` entirely and escape
      it just as ``..`` would.

    An unknown key or a non-list ``reference_images`` yields ``[]``.
    """
    catalog = load_catalog(path)
    # NOTE(review): a relative or missing ``atlas_root`` (missing becomes
    # ``Path("")``, i.e. ``.``) stays relative. Returned paths are therefore
    # interpreted against the process CWD, not the catalog's directory.
    # Confirm this is intended.
    atlas_root = Path(str(catalog.get("atlas_root") or ""))
    refs = get_variant(key, path=path).get("reference_images") or []
    if not isinstance(refs, list):
        return []

    paths: list[Path] = []
    for ref in refs:
        if ref is None:
            continue
        text = str(ref)
        if not text.strip():
            continue
        ref_path = Path(text)
        if ref_path.anchor or ".." in ref_path.parts:
            continue
        paths.append(atlas_root / ref_path)
    return paths