From 40ee843bafff170269bba1fb6b8f3e53aac81cf9 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Mon, 29 Jun 2026 02:31:03 +0200 Subject: [PATCH] Add Krea2 pose variant catalog loader --- docs/krea2-pov-pose-atlas.md | 3 +- krea2_pose_variant_catalog.py | 93 +++++++++++++++++++++++++++++++++++ tools/prompt_smoke.py | 38 ++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 krea2_pose_variant_catalog.py diff --git a/docs/krea2-pov-pose-atlas.md b/docs/krea2-pov-pose-atlas.md index f6d2c21..ebcf2cf 100644 --- a/docs/krea2-pov-pose-atlas.md +++ b/docs/krea2-pov-pose-atlas.md @@ -14,7 +14,8 @@ Machine-readable pose variants live in than the full atlas: it only contains variants that are proven or useful candidates for fixed-seed Krea2 tuning. Add a variant there when it has a compact geometry summary, cue phrases, avoid phrases, references, and a known generator -hook. +hook. Code should read it through `krea2_pose_variant_catalog.py` instead of +parsing the JSON directly. ## Inventory diff --git a/krea2_pose_variant_catalog.py b/krea2_pose_variant_catalog.py new file mode 100644 index 0000000..403cc9a --- /dev/null +++ b/krea2_pose_variant_catalog.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import copy +import json +from functools import lru_cache +from pathlib import Path +from typing import Any + + +ROOT = Path(__file__).resolve().parent +DEFAULT_CATALOG_PATH = ROOT / "categories" / "krea2_pov_pose_variants.json" + + +def _path_key(path: str | Path | None = None) -> str: + return str(Path(path or DEFAULT_CATALOG_PATH).resolve()) + + +@lru_cache(maxsize=8) +def _load_raw_catalog(path_key: str) -> dict[str, Any]: + with Path(path_key).open("r", encoding="utf-8") as handle: + data = json.load(handle) + return data if isinstance(data, dict) else {} + + +def clear_cache() -> None: + _load_raw_catalog.cache_clear() + + +def load_catalog(path: str | Path | None = None) -> dict[str, Any]: + return copy.deepcopy(_load_raw_catalog(_path_key(path))) + + +def variants( + *, + status: str | None = None, + family: str | None = None, + action_family: str | None = None, + path: str | Path | None = None, +) -> list[dict[str, Any]]: + catalog = load_catalog(path) + rows = catalog.get("variants") or [] + if not isinstance(rows, list): + return [] + filtered: list[dict[str, Any]] = [] + for row in rows: + if not isinstance(row, dict): + continue + if status is not None and row.get("status") != status: + continue + if family is not None and row.get("family") != family: + continue + if action_family is not None and row.get("action_family") != action_family: + continue + filtered.append(row) + return filtered + + +def variant_keys( + *, + status: str | None = None, + family: str | None = None, + action_family: str | None = None, + path: str | Path | None = None, +) -> list[str]: + return [ + str(row.get("key")) + for row in variants(status=status, family=family, action_family=action_family, path=path) + if row.get("key") + ] + + +def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]: + for row in variants(path=path): + if row.get("key") == key: + return row + return {} + + +def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]: + catalog = load_catalog(path) + atlas_root = Path(str(catalog.get("atlas_root") or "")) + variant = get_variant(key, path=path) + refs = variant.get("reference_images") or [] + if not isinstance(refs, list): + return [] + paths: list[Path] = [] + for ref in refs: + ref_path = Path(str(ref)) + if ".." in ref_path.parts: + continue + paths.append(atlas_root / ref_path) + return paths + diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index e104783..1a85b97 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -61,6 +61,7 @@ import krea_format_route # noqa: E402 import krea_formatter # noqa: E402 import krea_normal_formatter # noqa: E402 import krea_pair_formatter # noqa: E402 +import krea2_pose_variant_catalog # noqa: E402 import krea_row_fields # noqa: E402 import location_config # noqa: E402 import loop_nodes # noqa: E402 @@ -6766,6 +6767,42 @@ def smoke_krea2_pov_pose_variant_catalog() -> None: _expect(required_keys.issubset(seen_keys), "Krea2 POV pose-variant catalog lost a proven starter variant") +def smoke_krea2_pose_variant_catalog_policy() -> None: + catalog = krea2_pose_variant_catalog.load_catalog() + _expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog") + proven = krea2_pose_variant_catalog.variant_keys(status="proven") + _expect( + proven == [ + "pov_doggy_top_down_rear_entry", + "pov_boobjob_upright_cleavage", + "pov_handjob_upright_centered", + ], + f"Krea2 pose-variant proven keys changed unexpectedly: {proven}", + ) + outercourse = krea2_pose_variant_catalog.variant_keys(action_family="outercourse") + _expect( + outercourse == [ + "pov_boobjob_upright_cleavage", + "pov_handjob_upright_centered", + "pov_ballsucking_low_head", + ], + f"Krea2 pose-variant outercourse filtering changed unexpectedly: {outercourse}", + ) + handjob = krea2_pose_variant_catalog.get_variant("pov_handjob_upright_centered") + _expect( + any("woman's right hand wraps" in str(cue) for cue in handjob.get("prompt_cues", [])), + "Handjob variant lost hand ownership cue", + ) + handjob["prompt_cues"].append("mutation should not leak") + clean_handjob = krea2_pose_variant_catalog.get_variant("pov_handjob_upright_centered") + _expect("mutation should not leak" not in clean_handjob.get("prompt_cues", []), "Catalog loader leaked caller mutation") + refs = krea2_pose_variant_catalog.reference_paths("pov_boobjob_upright_cleavage") + _expect(refs and all(path.name.endswith(".png") for path in refs), "Boobjob reference paths are not image paths") + _expect(all("bg" not in str(path).lower() for path in refs), "Reference paths should not include background-only atlas images") + missing = krea2_pose_variant_catalog.get_variant("missing_pose_variant") + _expect(missing == {}, "Missing pose variant should return an empty mapping") + + def smoke_krea_pov_penetration_route() -> None: pair = pb.build_insta_of_pair( row_number=1, @@ -9726,6 +9763,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("insta_pair_camera_split", smoke_insta_pair_camera_split), ("pov_camera_scene", smoke_pov_camera_scene), ("krea2_pov_pose_variant_catalog", smoke_krea2_pov_pose_variant_catalog), + ("krea2_pose_variant_catalog_policy", smoke_krea2_pose_variant_catalog_policy), ("krea_pov_penetration_route", smoke_krea_pov_penetration_route), ("pov_outercourse_position_routes", smoke_pov_outercourse_position_routes), ("pov_oral_position_routes", smoke_pov_oral_position_routes),