Add Krea2 pose variant catalog loader

This commit is contained in:
2026-06-29 02:31:03 +02:00
parent 484fb40638
commit 40ee843baf
3 changed files with 133 additions and 1 deletions
+2 -1
View File
@@ -14,7 +14,8 @@ Machine-readable pose variants live in
than the full atlas: it only contains variants that are proven or useful than the full atlas: it only contains variants that are proven or useful
candidates for fixed-seed Krea2 tuning. Add a variant there when it has a compact candidates for fixed-seed Krea2 tuning. Add a variant there when it has a compact
geometry summary, cue phrases, avoid phrases, references, and a known generator geometry summary, cue phrases, avoid phrases, references, and a known generator
hook. hook. Code should read it through `krea2_pose_variant_catalog.py` instead of
parsing the JSON directly.
## Inventory ## Inventory
+93
View File
@@ -0,0 +1,93 @@
from __future__ import annotations
import copy
import json
from functools import lru_cache
from pathlib import Path
from typing import Any
ROOT = Path(__file__).resolve().parent
DEFAULT_CATALOG_PATH = ROOT / "categories" / "krea2_pov_pose_variants.json"
def _path_key(path: str | Path | None = None) -> str:
return str(Path(path or DEFAULT_CATALOG_PATH).resolve())
@lru_cache(maxsize=8)
def _load_raw_catalog(path_key: str) -> dict[str, Any]:
with Path(path_key).open("r", encoding="utf-8") as handle:
data = json.load(handle)
return data if isinstance(data, dict) else {}
def clear_cache() -> None:
_load_raw_catalog.cache_clear()
def load_catalog(path: str | Path | None = None) -> dict[str, Any]:
return copy.deepcopy(_load_raw_catalog(_path_key(path)))
def variants(
*,
status: str | None = None,
family: str | None = None,
action_family: str | None = None,
path: str | Path | None = None,
) -> list[dict[str, Any]]:
catalog = load_catalog(path)
rows = catalog.get("variants") or []
if not isinstance(rows, list):
return []
filtered: list[dict[str, Any]] = []
for row in rows:
if not isinstance(row, dict):
continue
if status is not None and row.get("status") != status:
continue
if family is not None and row.get("family") != family:
continue
if action_family is not None and row.get("action_family") != action_family:
continue
filtered.append(row)
return filtered
def variant_keys(
*,
status: str | None = None,
family: str | None = None,
action_family: str | None = None,
path: str | Path | None = None,
) -> list[str]:
return [
str(row.get("key"))
for row in variants(status=status, family=family, action_family=action_family, path=path)
if row.get("key")
]
def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]:
for row in variants(path=path):
if row.get("key") == key:
return row
return {}
def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
catalog = load_catalog(path)
atlas_root = Path(str(catalog.get("atlas_root") or ""))
variant = get_variant(key, path=path)
refs = variant.get("reference_images") or []
if not isinstance(refs, list):
return []
paths: list[Path] = []
for ref in refs:
ref_path = Path(str(ref))
if ".." in ref_path.parts:
continue
paths.append(atlas_root / ref_path)
return paths
+38
View File
@@ -61,6 +61,7 @@ import krea_format_route # noqa: E402
import krea_formatter # noqa: E402 import krea_formatter # noqa: E402
import krea_normal_formatter # noqa: E402 import krea_normal_formatter # noqa: E402
import krea_pair_formatter # noqa: E402 import krea_pair_formatter # noqa: E402
import krea2_pose_variant_catalog # noqa: E402
import krea_row_fields # noqa: E402 import krea_row_fields # noqa: E402
import location_config # noqa: E402 import location_config # noqa: E402
import loop_nodes # noqa: E402 import loop_nodes # noqa: E402
@@ -6766,6 +6767,42 @@ def smoke_krea2_pov_pose_variant_catalog() -> None:
_expect(required_keys.issubset(seen_keys), "Krea2 POV pose-variant catalog lost a proven starter variant") _expect(required_keys.issubset(seen_keys), "Krea2 POV pose-variant catalog lost a proven starter variant")
def smoke_krea2_pose_variant_catalog_policy() -> None:
catalog = krea2_pose_variant_catalog.load_catalog()
_expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog")
proven = krea2_pose_variant_catalog.variant_keys(status="proven")
_expect(
proven == [
"pov_doggy_top_down_rear_entry",
"pov_boobjob_upright_cleavage",
"pov_handjob_upright_centered",
],
f"Krea2 pose-variant proven keys changed unexpectedly: {proven}",
)
outercourse = krea2_pose_variant_catalog.variant_keys(action_family="outercourse")
_expect(
outercourse == [
"pov_boobjob_upright_cleavage",
"pov_handjob_upright_centered",
"pov_ballsucking_low_head",
],
f"Krea2 pose-variant outercourse filtering changed unexpectedly: {outercourse}",
)
handjob = krea2_pose_variant_catalog.get_variant("pov_handjob_upright_centered")
_expect(
any("woman's right hand wraps" in str(cue) for cue in handjob.get("prompt_cues", [])),
"Handjob variant lost hand ownership cue",
)
handjob["prompt_cues"].append("mutation should not leak")
clean_handjob = krea2_pose_variant_catalog.get_variant("pov_handjob_upright_centered")
_expect("mutation should not leak" not in clean_handjob.get("prompt_cues", []), "Catalog loader leaked caller mutation")
refs = krea2_pose_variant_catalog.reference_paths("pov_boobjob_upright_cleavage")
_expect(refs and all(path.name.endswith(".png") for path in refs), "Boobjob reference paths are not image paths")
_expect(all("bg" not in str(path).lower() for path in refs), "Reference paths should not include background-only atlas images")
missing = krea2_pose_variant_catalog.get_variant("missing_pose_variant")
_expect(missing == {}, "Missing pose variant should return an empty mapping")
def smoke_krea_pov_penetration_route() -> None: def smoke_krea_pov_penetration_route() -> None:
pair = pb.build_insta_of_pair( pair = pb.build_insta_of_pair(
row_number=1, row_number=1,
@@ -9726,6 +9763,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
("insta_pair_camera_split", smoke_insta_pair_camera_split), ("insta_pair_camera_split", smoke_insta_pair_camera_split),
("pov_camera_scene", smoke_pov_camera_scene), ("pov_camera_scene", smoke_pov_camera_scene),
("krea2_pov_pose_variant_catalog", smoke_krea2_pov_pose_variant_catalog), ("krea2_pov_pose_variant_catalog", smoke_krea2_pov_pose_variant_catalog),
("krea2_pose_variant_catalog_policy", smoke_krea2_pose_variant_catalog_policy),
("krea_pov_penetration_route", smoke_krea_pov_penetration_route), ("krea_pov_penetration_route", smoke_krea_pov_penetration_route),
("pov_outercourse_position_routes", smoke_pov_outercourse_position_routes), ("pov_outercourse_position_routes", smoke_pov_outercourse_position_routes),
("pov_oral_position_routes", smoke_pov_oral_position_routes), ("pov_oral_position_routes", smoke_pov_oral_position_routes),