From 4c315534097baedd84695254882dd5d56d5a61a6 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 00:22:17 +0200 Subject: [PATCH] Extract category cast config policy --- category_cast_config.py | 114 +++++++++++++++++++ docs/prompt-architecture-improvement-plan.md | 3 + docs/prompt-pool-routing-map.md | 1 + node_route_config.py | 8 +- prompt_builder.py | 88 ++------------ tools/prompt_smoke.py | 25 ++++ 6 files changed, 159 insertions(+), 80 deletions(-) create mode 100644 category_cast_config.py diff --git a/category_cast_config.py b/category_cast_config.py new file mode 100644 index 0000000..e358b88 --- /dev/null +++ b/category_cast_config.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import json +from typing import Any + + +RANDOM_SUBCATEGORY = "random" + +CATEGORY_PRESETS = { + "auto_weighted": ("auto_weighted", RANDOM_SUBCATEGORY), + "auto_full": ("auto_full", RANDOM_SUBCATEGORY), + "women_casual": ("Casual clothes", RANDOM_SUBCATEGORY), + "men_casual": ("Men casual clothes", RANDOM_SUBCATEGORY), + "couple_casual": ("Couple casual clothes", RANDOM_SUBCATEGORY), + "provocative_erotic": ("Provocative erotic clothes", RANDOM_SUBCATEGORY), + "hardcore_pose": ("Hardcore sexual poses", RANDOM_SUBCATEGORY), + "custom_random": ("custom_random", RANDOM_SUBCATEGORY), +} + +CAST_PRESETS = { + "solo_woman": (1, 0), + "solo_man": (0, 1), + "mixed_couple": (1, 1), + "two_women": (2, 0), + "two_men": (0, 2), + "threesome_2w1m": (2, 1), + "small_group_3w2m": (3, 2), +} + + +def category_preset_choices() -> list[str]: + return list(CATEGORY_PRESETS) + + +def cast_preset_choices() -> list[str]: + return list(CAST_PRESETS) + ["custom_counts"] + + +def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str: + category, default_subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"]) + chosen_subcategory = subcategory if subcategory and subcategory != RANDOM_SUBCATEGORY else default_subcategory + return json.dumps( + { + "preset": preset if preset in CATEGORY_PRESETS else "auto_weighted", + "category": category, + "subcategory": chosen_subcategory, + }, + ensure_ascii=True, + sort_keys=True, + ) + + +def parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]: + if not category_config: + return CATEGORY_PRESETS["auto_weighted"] + if isinstance(category_config, dict): + raw = category_config + else: + try: + raw = json.loads(str(category_config)) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid category_config JSON: {exc}") from exc + if not isinstance(raw, dict): + raise ValueError("category_config must be a JSON object") + preset = str(raw.get("preset") or "auto_weighted") + category, subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"]) + category = str(raw.get("category") or category) + subcategory = str(raw.get("subcategory") or subcategory or RANDOM_SUBCATEGORY) + return category, subcategory + + +def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str: + if cast_mode in CAST_PRESETS: + women_count, men_count = CAST_PRESETS[cast_mode] + else: + women_count = max(0, min(12, int(women_count))) + men_count = max(0, min(12, int(men_count))) + if women_count + men_count == 0: + women_count = 1 + cast_mode = "custom_counts" + return json.dumps( + { + "cast_mode": cast_mode, + "women_count": int(women_count), + "men_count": int(men_count), + }, + ensure_ascii=True, + sort_keys=True, + ) + + +def parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]: + if not cast_config: + return {"cast_mode": "mixed_couple", "women_count": 1, "men_count": 1} + if isinstance(cast_config, dict): + raw = cast_config + else: + try: + raw = json.loads(str(cast_config)) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid cast_config JSON: {exc}") from exc + if not isinstance(raw, dict): + raise ValueError("cast_config must be a JSON object") + return json.loads( + build_cast_config_json( + str(raw.get("cast_mode") or "custom_counts"), + raw.get("women_count", 1), + raw.get("men_count", 1), + ) + ) + + +_parse_category_config = parse_category_config +_parse_cast_config = parse_cast_config diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 1f4287b..49bd6c0 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -104,6 +104,9 @@ Already isolated: - JSON category loading, subcategory normalization, named scene/expression/ composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging live in `category_library.py`. +- category/cast route preset schemas, config JSON builders, choice lists, and + parsers live in `category_cast_config.py`; `prompt_builder.py` keeps public + delegate wrappers for existing nodes and tests. - location/composition config presets, themed location packs, custom location/composition entry parsing, merge behavior, and config parsing live in `location_config.py`; `prompt_builder.py` still applies selected configs diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 1e6af1d..af57165 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -68,6 +68,7 @@ Core helper ownership: | Python module | What it owns | | --- | --- | | `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. | +| `category_cast_config.py` | Category preset and cast preset schemas, category/cast config JSON builders, choice lists, and config parsers used by route nodes. | | `camera_config.py` | Camera option schema, direct/orbit/Qwen camera JSON builders, camera config parsing, plain camera directive text, and camera caption labels. | | `seed_config.py` | Seed axis salts/aliases, seed mode choices, global/axis lock JSON builders, seed config parsing, row seed math, and deterministic axis RNG construction. | | `location_config.py` | Location/composition preset schemas, themed location packs, custom location/composition parsing, pool merge behavior, and location/composition config parsing. | diff --git a/node_route_config.py b/node_route_config.py index 104ce7d..72876de 100644 --- a/node_route_config.py +++ b/node_route_config.py @@ -4,11 +4,13 @@ import json import random try: - from .prompt_builder import ( + from .category_cast_config import ( build_cast_config_json, build_category_config_json, cast_preset_choices, category_preset_choices, + ) + from .prompt_builder import ( subcategory_choices, ) from .location_config import ( @@ -20,11 +22,13 @@ try: location_theme_choices, ) except ImportError: # Allows local smoke tests from the repository root. - from prompt_builder import ( + from category_cast_config import ( build_cast_config_json, build_category_config_json, cast_preset_choices, category_preset_choices, + ) + from prompt_builder import ( subcategory_choices, ) from location_config import ( diff --git a/prompt_builder.py b/prompt_builder.py index f582e55..c6d2824 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -24,6 +24,7 @@ try: template_list as _template_list, ) from . import camera_config as camera_policy + from . import category_cast_config as category_cast_policy from . import generate_prompt_batches as g from . import location_config as location_policy from . import pair_clothing @@ -62,6 +63,7 @@ except ImportError: # Allows local smoke tests with `python -c`. template_list as _template_list, ) import camera_config as camera_policy + import category_cast_config as category_cast_policy import generate_prompt_batches as g import location_config as location_policy import pair_clothing @@ -1030,26 +1032,8 @@ def seed_mode_choices() -> list[str]: return seed_policy.seed_mode_choices() -CATEGORY_PRESETS = { - "auto_weighted": ("auto_weighted", RANDOM_SUBCATEGORY), - "auto_full": ("auto_full", RANDOM_SUBCATEGORY), - "women_casual": ("Casual clothes", RANDOM_SUBCATEGORY), - "men_casual": ("Men casual clothes", RANDOM_SUBCATEGORY), - "couple_casual": ("Couple casual clothes", RANDOM_SUBCATEGORY), - "provocative_erotic": ("Provocative erotic clothes", RANDOM_SUBCATEGORY), - "hardcore_pose": ("Hardcore sexual poses", RANDOM_SUBCATEGORY), - "custom_random": ("custom_random", RANDOM_SUBCATEGORY), -} - -CAST_PRESETS = { - "solo_woman": (1, 0), - "solo_man": (0, 1), - "mixed_couple": (1, 1), - "two_women": (2, 0), - "two_men": (0, 2), - "threesome_2w1m": (2, 1), - "small_group_3w2m": (3, 2), -} +CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS +CAST_PRESETS = category_cast_policy.CAST_PRESETS GENERATION_PROFILE_PRESETS = { "balanced": { @@ -1122,11 +1106,11 @@ GENERATION_PROFILE_PRESETS = { def category_preset_choices() -> list[str]: - return list(CATEGORY_PRESETS) + return category_cast_policy.category_preset_choices() def cast_preset_choices() -> list[str]: - return list(CAST_PRESETS) + ["custom_counts"] + return category_cast_policy.cast_preset_choices() def generation_profile_choices() -> list[str]: @@ -1134,71 +1118,19 @@ def generation_profile_choices() -> list[str]: def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str: - category, default_subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"]) - chosen_subcategory = subcategory if subcategory and subcategory != RANDOM_SUBCATEGORY else default_subcategory - return json.dumps( - { - "preset": preset if preset in CATEGORY_PRESETS else "auto_weighted", - "category": category, - "subcategory": chosen_subcategory, - }, - ensure_ascii=True, - sort_keys=True, - ) + return category_cast_policy.build_category_config_json(preset=preset, subcategory=subcategory) def _parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]: - if not category_config: - return CATEGORY_PRESETS["auto_weighted"] - if isinstance(category_config, dict): - raw = category_config - else: - try: - raw = json.loads(str(category_config)) - except json.JSONDecodeError as exc: - raise ValueError(f"Invalid category_config JSON: {exc}") from exc - if not isinstance(raw, dict): - raise ValueError("category_config must be a JSON object") - preset = str(raw.get("preset") or "auto_weighted") - category, subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"]) - category = str(raw.get("category") or category) - subcategory = str(raw.get("subcategory") or subcategory or RANDOM_SUBCATEGORY) - return category, subcategory + return category_cast_policy.parse_category_config(category_config) def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str: - if cast_mode in CAST_PRESETS: - women_count, men_count = CAST_PRESETS[cast_mode] - else: - women_count = max(0, min(12, int(women_count))) - men_count = max(0, min(12, int(men_count))) - if women_count + men_count == 0: - women_count = 1 - cast_mode = "custom_counts" - return json.dumps( - { - "cast_mode": cast_mode, - "women_count": int(women_count), - "men_count": int(men_count), - }, - ensure_ascii=True, - sort_keys=True, - ) + return category_cast_policy.build_cast_config_json(cast_mode=cast_mode, women_count=women_count, men_count=men_count) def _parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]: - if not cast_config: - return {"cast_mode": "mixed_couple", "women_count": 1, "men_count": 1} - if isinstance(cast_config, dict): - raw = cast_config - else: - try: - raw = json.loads(str(cast_config)) - except json.JSONDecodeError as exc: - raise ValueError(f"Invalid cast_config JSON: {exc}") from exc - if not isinstance(raw, dict): - raise ValueError("cast_config must be a JSON object") - return json.loads(build_cast_config_json(str(raw.get("cast_mode") or "custom_counts"), raw.get("women_count", 1), raw.get("men_count", 1))) + return category_cast_policy.parse_cast_config(cast_config) def build_generation_profile_json( diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 88bb8e3..5151e38 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -24,6 +24,7 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import caption_naturalizer # noqa: E402 +import category_cast_config # noqa: E402 import category_library # noqa: E402 import __init__ as sxcp_nodes # noqa: E402 import krea_formatter # noqa: E402 @@ -554,6 +555,29 @@ def smoke_location_config_policy() -> None: _expect(json.loads(themed_composition).get("composition_entries"), "Themed location did not output compositions") +def smoke_category_cast_config_policy() -> None: + _expect(pb.CATEGORY_PRESETS is category_cast_config.CATEGORY_PRESETS, "Prompt builder category presets are not delegated") + _expect(pb.CAST_PRESETS is category_cast_config.CAST_PRESETS, "Prompt builder cast presets are not delegated") + _expect("hardcore_pose" in category_cast_config.category_preset_choices(), "Category preset choices lost hardcore_pose") + _expect("custom_counts" in category_cast_config.cast_preset_choices(), "Cast preset choices lost custom_counts") + + category_config = json.loads(pb.build_category_config_json("hardcore_pose", "Foreplay and teasing")) + _expect(category_config.get("category") == "Hardcore sexual poses", "Category config lost hardcore category mapping") + _expect(category_config.get("subcategory") == "Foreplay and teasing", "Category config lost explicit subcategory") + _expect(pb._parse_category_config(category_config) == ("Hardcore sexual poses", "Foreplay and teasing"), "Category parser wrapper drifted") + + fallback_config = json.loads(category_cast_config.build_category_config_json("unknown", "random")) + _expect(fallback_config.get("preset") == "auto_weighted", "Unknown category preset did not fall back") + _expect(pb._parse_category_config({"preset": "unknown"}) == ("auto_weighted", "random"), "Unknown category parser fallback changed") + + cast_config = json.loads(pb.build_cast_config_json("mixed_couple", 9, 9)) + _expect((cast_config.get("women_count"), cast_config.get("men_count")) == (1, 1), "Cast preset did not override manual counts") + custom_cast = json.loads(category_cast_config.build_cast_config_json("custom_counts", -5, 99)) + _expect((custom_cast.get("women_count"), custom_cast.get("men_count")) == (0, 12), "Custom cast counts were not clamped") + empty_cast = pb._parse_cast_config({"cast_mode": "custom_counts", "women_count": 0, "men_count": 0}) + _expect((empty_cast.get("women_count"), empty_cast.get("men_count")) == (1, 0), "Empty custom cast was not corrected") + + def smoke_category_library_route() -> None: categories = category_library.load_category_library() _expect(len(categories) >= 3, "category library should load JSON categories") @@ -2461,6 +2485,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("camera_scene_single", smoke_camera_scene_single), ("config_route_location_theme", smoke_config_route_location_theme), ("location_config_policy", smoke_location_config_policy), + ("category_cast_config_policy", smoke_category_cast_config_policy), ("category_library_route", smoke_category_library_route), ("hardcore_category_routes", smoke_hardcore_category_routes), ("krea_close_foreplay_route", smoke_krea_close_foreplay_route),