Add Krea2 pose variant catalog loader
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parent
|
||||
DEFAULT_CATALOG_PATH = ROOT / "categories" / "krea2_pov_pose_variants.json"
|
||||
|
||||
|
||||
def _path_key(path: str | Path | None = None) -> str:
|
||||
return str(Path(path or DEFAULT_CATALOG_PATH).resolve())
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _load_raw_catalog(path_key: str) -> dict[str, Any]:
|
||||
with Path(path_key).open("r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def clear_cache() -> None:
|
||||
_load_raw_catalog.cache_clear()
|
||||
|
||||
|
||||
def load_catalog(path: str | Path | None = None) -> dict[str, Any]:
|
||||
return copy.deepcopy(_load_raw_catalog(_path_key(path)))
|
||||
|
||||
|
||||
def variants(
|
||||
*,
|
||||
status: str | None = None,
|
||||
family: str | None = None,
|
||||
action_family: str | None = None,
|
||||
path: str | Path | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
catalog = load_catalog(path)
|
||||
rows = catalog.get("variants") or []
|
||||
if not isinstance(rows, list):
|
||||
return []
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
if status is not None and row.get("status") != status:
|
||||
continue
|
||||
if family is not None and row.get("family") != family:
|
||||
continue
|
||||
if action_family is not None and row.get("action_family") != action_family:
|
||||
continue
|
||||
filtered.append(row)
|
||||
return filtered
|
||||
|
||||
|
||||
def variant_keys(
|
||||
*,
|
||||
status: str | None = None,
|
||||
family: str | None = None,
|
||||
action_family: str | None = None,
|
||||
path: str | Path | None = None,
|
||||
) -> list[str]:
|
||||
return [
|
||||
str(row.get("key"))
|
||||
for row in variants(status=status, family=family, action_family=action_family, path=path)
|
||||
if row.get("key")
|
||||
]
|
||||
|
||||
|
||||
def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]:
|
||||
for row in variants(path=path):
|
||||
if row.get("key") == key:
|
||||
return row
|
||||
return {}
|
||||
|
||||
|
||||
def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
|
||||
catalog = load_catalog(path)
|
||||
atlas_root = Path(str(catalog.get("atlas_root") or ""))
|
||||
variant = get_variant(key, path=path)
|
||||
refs = variant.get("reference_images") or []
|
||||
if not isinstance(refs, list):
|
||||
return []
|
||||
paths: list[Path] = []
|
||||
for ref in refs:
|
||||
ref_path = Path(str(ref))
|
||||
if ".." in ref_path.parts:
|
||||
continue
|
||||
paths.append(atlas_root / ref_path)
|
||||
return paths
|
||||
|
||||
Reference in New Issue
Block a user