Files
ComfyUI-Ethanfel-Prompt-Bui…/krea2_pose_variant_catalog.py
T

120 lines
3.5 KiB
Python

from __future__ import annotations
import copy
import json
from functools import lru_cache
from pathlib import Path
from typing import Any
# Directory that contains this module; bundled data files are located relative to it.
ROOT = Path(__file__).resolve().parents[0]

# Catalog file used whenever a caller does not supply an explicit path.
DEFAULT_CATALOG_PATH = ROOT.joinpath("categories", "krea2_pov_pose_variants.json")
def _path_key(path: str | Path | None = None) -> str:
    """Return the resolved, absolute string form of a catalog path.

    A falsy ``path`` (``None`` or an empty string) selects
    ``DEFAULT_CATALOG_PATH``. The resulting string is what the catalog
    loader's cache is keyed on.
    """
    target = path if path else DEFAULT_CATALOG_PATH
    resolved = Path(target).resolve()
    return str(resolved)
@lru_cache(maxsize=8)
def _load_raw_catalog(path_key: str) -> dict[str, Any]:
    """Read and parse the JSON catalog stored at ``path_key``.

    The result is memoised per path string (up to 8 entries), so the
    returned object is shared between calls and must not be mutated by
    callers. A JSON document whose top level is not an object yields ``{}``.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    text = Path(path_key).read_text(encoding="utf-8")
    parsed = json.loads(text)
    if isinstance(parsed, dict):
        return parsed
    return {}
def clear_cache() -> None:
    """Drop every memoised catalog so the next load re-reads the file from disk."""
    _load_raw_catalog.cache_clear()
def load_catalog(path: str | Path | None = None) -> dict[str, Any]:
    """Return an independent copy of the catalog at ``path``.

    Args:
        path: Catalog file to read; a falsy value selects the default catalog.

    Returns:
        A deep copy of the cached catalog mapping, so callers may mutate the
        result without affecting the cache or other callers.
    """
    cached = _load_raw_catalog(_path_key(path))
    return copy.deepcopy(cached)
def variants(
    *,
    status: str | None = None,
    family: str | None = None,
    action_family: str | None = None,
    path: str | Path | None = None,
) -> list[dict[str, Any]]:
    """Return the catalog's variant rows that satisfy every supplied filter.

    Args:
        status: When not ``None``, keep only rows whose ``"status"`` equals it.
        family: When not ``None``, keep only rows whose ``"family"`` equals it.
        action_family: When not ``None``, keep only rows whose
            ``"action_family"`` equals it.
        path: Catalog file to read; a falsy value selects the default catalog.

    Returns:
        Matching rows in catalog order. Non-dict entries are skipped, and an
        empty list is returned when ``"variants"`` is missing or not a list.
    """
    entries = load_catalog(path).get("variants") or []
    if not isinstance(entries, list):
        return []

    # Only the filters the caller actually supplied take part in matching.
    requested = (
        ("status", status),
        ("family", family),
        ("action_family", action_family),
    )
    active = {field: wanted for field, wanted in requested if wanted is not None}

    return [
        entry
        for entry in entries
        if isinstance(entry, dict)
        and all(entry.get(field) == wanted for field, wanted in active.items())
    ]
def variant_keys(
    *,
    status: str | None = None,
    family: str | None = None,
    action_family: str | None = None,
    path: str | Path | None = None,
) -> list[str]:
    """Return the ``"key"`` of every variant matching the given filters.

    The filters and ``path`` are forwarded unchanged to ``variants``. Rows
    whose key is missing or falsy are skipped; remaining keys are converted
    to ``str``.
    """
    matching = variants(
        status=status,
        family=family,
        action_family=action_family,
        path=path,
    )
    keys: list[str] = []
    for entry in matching:
        key = entry.get("key")
        if key:
            keys.append(str(key))
    return keys
def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]:
    """Return the first variant row whose ``"key"`` equals ``key``.

    Args:
        key: Variant key to look up.
        path: Catalog file to read; a falsy value selects the default catalog.

    Returns:
        The matching row, or an empty dict when no row has that key.
    """
    matches = (entry for entry in variants(path=path) if entry.get("key") == key)
    return next(matches, {})
def _cue_list(value: Any) -> list[str]:
    """Normalise a cue container into a list of non-blank cue strings.

    ``value`` may be a list of cues, or a dict carrying the list under
    ``"prompt_cues"`` (preferred) or ``"cues"``. Anything else yields ``[]``.
    Cues are converted with ``str``; ``None`` and whitespace-only cues are
    dropped.
    """
    if isinstance(value, dict):
        value = value.get("prompt_cues") or value.get("cues")
    if not isinstance(value, list):
        return []
    # Skip None explicitly: str(None) is the non-blank text "None", which
    # would otherwise end up as a literal cue.
    return [str(cue) for cue in value if cue is not None and str(cue).strip()]


def prompt_cue_sets(
    variant_or_key: dict[str, Any] | str,
    *,
    path: str | Path | None = None,
) -> list[list[str]]:
    """Collect every non-empty prompt cue list defined for a variant.

    Args:
        variant_or_key: A variant mapping, or a variant key that is looked up
            with ``get_variant``.
        path: Catalog file used when ``variant_or_key`` is a key; a falsy
            value selects the default catalog. Ignored for mapping input.

    Returns:
        The baseline ``"prompt_cues"`` list first (when non-empty), followed
        by one list per usable entry of ``"prompt_variant_cues"``. When
        neither yields anything, a single-item list holding the stripped
        ``"canonical_geometry"`` text, or ``[]`` if that is blank too. An
        unknown key or empty mapping also yields ``[]``.
    """
    if isinstance(variant_or_key, str):
        variant = get_variant(variant_or_key, path=path)
    else:
        variant = dict(variant_or_key or {})
    if not variant:
        return []

    alternates = variant.get("prompt_variant_cues")
    if not isinstance(alternates, (list, tuple)):
        # Malformed data (e.g. a number) must not raise; treat it as "no
        # alternates", matching how the other accessors handle non-list fields.
        alternates = []

    candidates = [variant.get("prompt_cues"), *alternates]
    cue_sets = [cues for cues in map(_cue_list, candidates) if cues]
    if cue_sets:
        return cue_sets

    # Nothing usable was declared: fall back to the geometry description.
    fallback = str(variant.get("canonical_geometry") or "").strip()
    return [[fallback]] if fallback else []
def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
    """Build the reference-image paths declared for a variant.

    Args:
        key: Variant key to look up.
        path: Catalog file to read; a falsy value selects the default catalog.

    Returns:
        One path per accepted ``"reference_images"`` entry, each joined onto
        the catalog's ``"atlas_root"``. Entries that could point outside the
        atlas root are skipped. The paths are not checked for existence. An
        unknown key, or a ``"reference_images"`` value that is not a list,
        yields ``[]``.
    """
    catalog = load_catalog(path)
    atlas_root = Path(str(catalog.get("atlas_root") or ""))
    refs = get_variant(key, path=path).get("reference_images") or []
    if not isinstance(refs, list):
        return []

    resolved: list[Path] = []
    for ref in refs:
        ref_path = Path(str(ref))
        # Joining an anchored (absolute or drive-qualified) path discards
        # atlas_root entirely, so it escapes the atlas just like "..".
        if ref_path.anchor or ".." in ref_path.parts:
            continue
        resolved.append(atlas_root / ref_path)
    return resolved