Extract category cast config policy
This commit is contained in:
@@ -0,0 +1,114 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
RANDOM_SUBCATEGORY = "random"
|
||||||
|
|
||||||
|
CATEGORY_PRESETS = {
|
||||||
|
"auto_weighted": ("auto_weighted", RANDOM_SUBCATEGORY),
|
||||||
|
"auto_full": ("auto_full", RANDOM_SUBCATEGORY),
|
||||||
|
"women_casual": ("Casual clothes", RANDOM_SUBCATEGORY),
|
||||||
|
"men_casual": ("Men casual clothes", RANDOM_SUBCATEGORY),
|
||||||
|
"couple_casual": ("Couple casual clothes", RANDOM_SUBCATEGORY),
|
||||||
|
"provocative_erotic": ("Provocative erotic clothes", RANDOM_SUBCATEGORY),
|
||||||
|
"hardcore_pose": ("Hardcore sexual poses", RANDOM_SUBCATEGORY),
|
||||||
|
"custom_random": ("custom_random", RANDOM_SUBCATEGORY),
|
||||||
|
}
|
||||||
|
|
||||||
|
CAST_PRESETS = {
|
||||||
|
"solo_woman": (1, 0),
|
||||||
|
"solo_man": (0, 1),
|
||||||
|
"mixed_couple": (1, 1),
|
||||||
|
"two_women": (2, 0),
|
||||||
|
"two_men": (0, 2),
|
||||||
|
"threesome_2w1m": (2, 1),
|
||||||
|
"small_group_3w2m": (3, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def category_preset_choices() -> list[str]:
|
||||||
|
return list(CATEGORY_PRESETS)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_preset_choices() -> list[str]:
|
||||||
|
return list(CAST_PRESETS) + ["custom_counts"]
|
||||||
|
|
||||||
|
|
||||||
|
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
|
||||||
|
category, default_subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
|
||||||
|
chosen_subcategory = subcategory if subcategory and subcategory != RANDOM_SUBCATEGORY else default_subcategory
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"preset": preset if preset in CATEGORY_PRESETS else "auto_weighted",
|
||||||
|
"category": category,
|
||||||
|
"subcategory": chosen_subcategory,
|
||||||
|
},
|
||||||
|
ensure_ascii=True,
|
||||||
|
sort_keys=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]:
|
||||||
|
if not category_config:
|
||||||
|
return CATEGORY_PRESETS["auto_weighted"]
|
||||||
|
if isinstance(category_config, dict):
|
||||||
|
raw = category_config
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
raw = json.loads(str(category_config))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid category_config JSON: {exc}") from exc
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError("category_config must be a JSON object")
|
||||||
|
preset = str(raw.get("preset") or "auto_weighted")
|
||||||
|
category, subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
|
||||||
|
category = str(raw.get("category") or category)
|
||||||
|
subcategory = str(raw.get("subcategory") or subcategory or RANDOM_SUBCATEGORY)
|
||||||
|
return category, subcategory
|
||||||
|
|
||||||
|
|
||||||
|
def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str:
|
||||||
|
if cast_mode in CAST_PRESETS:
|
||||||
|
women_count, men_count = CAST_PRESETS[cast_mode]
|
||||||
|
else:
|
||||||
|
women_count = max(0, min(12, int(women_count)))
|
||||||
|
men_count = max(0, min(12, int(men_count)))
|
||||||
|
if women_count + men_count == 0:
|
||||||
|
women_count = 1
|
||||||
|
cast_mode = "custom_counts"
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"cast_mode": cast_mode,
|
||||||
|
"women_count": int(women_count),
|
||||||
|
"men_count": int(men_count),
|
||||||
|
},
|
||||||
|
ensure_ascii=True,
|
||||||
|
sort_keys=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]:
|
||||||
|
if not cast_config:
|
||||||
|
return {"cast_mode": "mixed_couple", "women_count": 1, "men_count": 1}
|
||||||
|
if isinstance(cast_config, dict):
|
||||||
|
raw = cast_config
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
raw = json.loads(str(cast_config))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid cast_config JSON: {exc}") from exc
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError("cast_config must be a JSON object")
|
||||||
|
return json.loads(
|
||||||
|
build_cast_config_json(
|
||||||
|
str(raw.get("cast_mode") or "custom_counts"),
|
||||||
|
raw.get("women_count", 1),
|
||||||
|
raw.get("men_count", 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_parse_category_config = parse_category_config
|
||||||
|
_parse_cast_config = parse_cast_config
|
||||||
@@ -104,6 +104,9 @@ Already isolated:
|
|||||||
- JSON category loading, subcategory normalization, named scene/expression/
|
- JSON category loading, subcategory normalization, named scene/expression/
|
||||||
composition pool loading, cast compatibility filtering, exact subcategory
|
composition pool loading, cast compatibility filtering, exact subcategory
|
||||||
lookup, and inheritance-based pool merging live in `category_library.py`.
|
lookup, and inheritance-based pool merging live in `category_library.py`.
|
||||||
|
- category/cast route preset schemas, config JSON builders, choice lists, and
|
||||||
|
parsers live in `category_cast_config.py`; `prompt_builder.py` keeps public
|
||||||
|
delegate wrappers for existing nodes and tests.
|
||||||
- location/composition config presets, themed location packs, custom
|
- location/composition config presets, themed location packs, custom
|
||||||
location/composition entry parsing, merge behavior, and config parsing live
|
location/composition entry parsing, merge behavior, and config parsing live
|
||||||
in `location_config.py`; `prompt_builder.py` still applies selected configs
|
in `location_config.py`; `prompt_builder.py` still applies selected configs
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ Core helper ownership:
|
|||||||
| Python module | What it owns |
|
| Python module | What it owns |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. |
|
| `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. |
|
||||||
|
| `category_cast_config.py` | Category preset and cast preset schemas, category/cast config JSON builders, choice lists, and config parsers used by route nodes. |
|
||||||
| `camera_config.py` | Camera option schema, direct/orbit/Qwen camera JSON builders, camera config parsing, plain camera directive text, and camera caption labels. |
|
| `camera_config.py` | Camera option schema, direct/orbit/Qwen camera JSON builders, camera config parsing, plain camera directive text, and camera caption labels. |
|
||||||
| `seed_config.py` | Seed axis salts/aliases, seed mode choices, global/axis lock JSON builders, seed config parsing, row seed math, and deterministic axis RNG construction. |
|
| `seed_config.py` | Seed axis salts/aliases, seed mode choices, global/axis lock JSON builders, seed config parsing, row seed math, and deterministic axis RNG construction. |
|
||||||
| `location_config.py` | Location/composition preset schemas, themed location packs, custom location/composition parsing, pool merge behavior, and location/composition config parsing. |
|
| `location_config.py` | Location/composition preset schemas, themed location packs, custom location/composition parsing, pool merge behavior, and location/composition config parsing. |
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ import json
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .prompt_builder import (
|
from .category_cast_config import (
|
||||||
build_cast_config_json,
|
build_cast_config_json,
|
||||||
build_category_config_json,
|
build_category_config_json,
|
||||||
cast_preset_choices,
|
cast_preset_choices,
|
||||||
category_preset_choices,
|
category_preset_choices,
|
||||||
|
)
|
||||||
|
from .prompt_builder import (
|
||||||
subcategory_choices,
|
subcategory_choices,
|
||||||
)
|
)
|
||||||
from .location_config import (
|
from .location_config import (
|
||||||
@@ -20,11 +22,13 @@ try:
|
|||||||
location_theme_choices,
|
location_theme_choices,
|
||||||
)
|
)
|
||||||
except ImportError: # Allows local smoke tests from the repository root.
|
except ImportError: # Allows local smoke tests from the repository root.
|
||||||
from prompt_builder import (
|
from category_cast_config import (
|
||||||
build_cast_config_json,
|
build_cast_config_json,
|
||||||
build_category_config_json,
|
build_category_config_json,
|
||||||
cast_preset_choices,
|
cast_preset_choices,
|
||||||
category_preset_choices,
|
category_preset_choices,
|
||||||
|
)
|
||||||
|
from prompt_builder import (
|
||||||
subcategory_choices,
|
subcategory_choices,
|
||||||
)
|
)
|
||||||
from location_config import (
|
from location_config import (
|
||||||
|
|||||||
+10
-78
@@ -24,6 +24,7 @@ try:
|
|||||||
template_list as _template_list,
|
template_list as _template_list,
|
||||||
)
|
)
|
||||||
from . import camera_config as camera_policy
|
from . import camera_config as camera_policy
|
||||||
|
from . import category_cast_config as category_cast_policy
|
||||||
from . import generate_prompt_batches as g
|
from . import generate_prompt_batches as g
|
||||||
from . import location_config as location_policy
|
from . import location_config as location_policy
|
||||||
from . import pair_clothing
|
from . import pair_clothing
|
||||||
@@ -62,6 +63,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
|||||||
template_list as _template_list,
|
template_list as _template_list,
|
||||||
)
|
)
|
||||||
import camera_config as camera_policy
|
import camera_config as camera_policy
|
||||||
|
import category_cast_config as category_cast_policy
|
||||||
import generate_prompt_batches as g
|
import generate_prompt_batches as g
|
||||||
import location_config as location_policy
|
import location_config as location_policy
|
||||||
import pair_clothing
|
import pair_clothing
|
||||||
@@ -1030,26 +1032,8 @@ def seed_mode_choices() -> list[str]:
|
|||||||
return seed_policy.seed_mode_choices()
|
return seed_policy.seed_mode_choices()
|
||||||
|
|
||||||
|
|
||||||
CATEGORY_PRESETS = {
|
CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS
|
||||||
"auto_weighted": ("auto_weighted", RANDOM_SUBCATEGORY),
|
CAST_PRESETS = category_cast_policy.CAST_PRESETS
|
||||||
"auto_full": ("auto_full", RANDOM_SUBCATEGORY),
|
|
||||||
"women_casual": ("Casual clothes", RANDOM_SUBCATEGORY),
|
|
||||||
"men_casual": ("Men casual clothes", RANDOM_SUBCATEGORY),
|
|
||||||
"couple_casual": ("Couple casual clothes", RANDOM_SUBCATEGORY),
|
|
||||||
"provocative_erotic": ("Provocative erotic clothes", RANDOM_SUBCATEGORY),
|
|
||||||
"hardcore_pose": ("Hardcore sexual poses", RANDOM_SUBCATEGORY),
|
|
||||||
"custom_random": ("custom_random", RANDOM_SUBCATEGORY),
|
|
||||||
}
|
|
||||||
|
|
||||||
CAST_PRESETS = {
|
|
||||||
"solo_woman": (1, 0),
|
|
||||||
"solo_man": (0, 1),
|
|
||||||
"mixed_couple": (1, 1),
|
|
||||||
"two_women": (2, 0),
|
|
||||||
"two_men": (0, 2),
|
|
||||||
"threesome_2w1m": (2, 1),
|
|
||||||
"small_group_3w2m": (3, 2),
|
|
||||||
}
|
|
||||||
|
|
||||||
GENERATION_PROFILE_PRESETS = {
|
GENERATION_PROFILE_PRESETS = {
|
||||||
"balanced": {
|
"balanced": {
|
||||||
@@ -1122,11 +1106,11 @@ GENERATION_PROFILE_PRESETS = {
|
|||||||
|
|
||||||
|
|
||||||
def category_preset_choices() -> list[str]:
|
def category_preset_choices() -> list[str]:
|
||||||
return list(CATEGORY_PRESETS)
|
return category_cast_policy.category_preset_choices()
|
||||||
|
|
||||||
|
|
||||||
def cast_preset_choices() -> list[str]:
|
def cast_preset_choices() -> list[str]:
|
||||||
return list(CAST_PRESETS) + ["custom_counts"]
|
return category_cast_policy.cast_preset_choices()
|
||||||
|
|
||||||
|
|
||||||
def generation_profile_choices() -> list[str]:
|
def generation_profile_choices() -> list[str]:
|
||||||
@@ -1134,71 +1118,19 @@ def generation_profile_choices() -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
|
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
|
||||||
category, default_subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
|
return category_cast_policy.build_category_config_json(preset=preset, subcategory=subcategory)
|
||||||
chosen_subcategory = subcategory if subcategory and subcategory != RANDOM_SUBCATEGORY else default_subcategory
|
|
||||||
return json.dumps(
|
|
||||||
{
|
|
||||||
"preset": preset if preset in CATEGORY_PRESETS else "auto_weighted",
|
|
||||||
"category": category,
|
|
||||||
"subcategory": chosen_subcategory,
|
|
||||||
},
|
|
||||||
ensure_ascii=True,
|
|
||||||
sort_keys=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]:
|
def _parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]:
|
||||||
if not category_config:
|
return category_cast_policy.parse_category_config(category_config)
|
||||||
return CATEGORY_PRESETS["auto_weighted"]
|
|
||||||
if isinstance(category_config, dict):
|
|
||||||
raw = category_config
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
raw = json.loads(str(category_config))
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
raise ValueError(f"Invalid category_config JSON: {exc}") from exc
|
|
||||||
if not isinstance(raw, dict):
|
|
||||||
raise ValueError("category_config must be a JSON object")
|
|
||||||
preset = str(raw.get("preset") or "auto_weighted")
|
|
||||||
category, subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
|
|
||||||
category = str(raw.get("category") or category)
|
|
||||||
subcategory = str(raw.get("subcategory") or subcategory or RANDOM_SUBCATEGORY)
|
|
||||||
return category, subcategory
|
|
||||||
|
|
||||||
|
|
||||||
def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str:
|
def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str:
|
||||||
if cast_mode in CAST_PRESETS:
|
return category_cast_policy.build_cast_config_json(cast_mode=cast_mode, women_count=women_count, men_count=men_count)
|
||||||
women_count, men_count = CAST_PRESETS[cast_mode]
|
|
||||||
else:
|
|
||||||
women_count = max(0, min(12, int(women_count)))
|
|
||||||
men_count = max(0, min(12, int(men_count)))
|
|
||||||
if women_count + men_count == 0:
|
|
||||||
women_count = 1
|
|
||||||
cast_mode = "custom_counts"
|
|
||||||
return json.dumps(
|
|
||||||
{
|
|
||||||
"cast_mode": cast_mode,
|
|
||||||
"women_count": int(women_count),
|
|
||||||
"men_count": int(men_count),
|
|
||||||
},
|
|
||||||
ensure_ascii=True,
|
|
||||||
sort_keys=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]:
|
def _parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]:
|
||||||
if not cast_config:
|
return category_cast_policy.parse_cast_config(cast_config)
|
||||||
return {"cast_mode": "mixed_couple", "women_count": 1, "men_count": 1}
|
|
||||||
if isinstance(cast_config, dict):
|
|
||||||
raw = cast_config
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
raw = json.loads(str(cast_config))
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
raise ValueError(f"Invalid cast_config JSON: {exc}") from exc
|
|
||||||
if not isinstance(raw, dict):
|
|
||||||
raise ValueError("cast_config must be a JSON object")
|
|
||||||
return json.loads(build_cast_config_json(str(raw.get("cast_mode") or "custom_counts"), raw.get("women_count", 1), raw.get("men_count", 1)))
|
|
||||||
|
|
||||||
|
|
||||||
def build_generation_profile_json(
|
def build_generation_profile_json(
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ if str(ROOT) not in sys.path:
|
|||||||
sys.path.insert(0, str(ROOT))
|
sys.path.insert(0, str(ROOT))
|
||||||
|
|
||||||
import caption_naturalizer # noqa: E402
|
import caption_naturalizer # noqa: E402
|
||||||
|
import category_cast_config # noqa: E402
|
||||||
import category_library # noqa: E402
|
import category_library # noqa: E402
|
||||||
import __init__ as sxcp_nodes # noqa: E402
|
import __init__ as sxcp_nodes # noqa: E402
|
||||||
import krea_formatter # noqa: E402
|
import krea_formatter # noqa: E402
|
||||||
@@ -554,6 +555,29 @@ def smoke_location_config_policy() -> None:
|
|||||||
_expect(json.loads(themed_composition).get("composition_entries"), "Themed location did not output compositions")
|
_expect(json.loads(themed_composition).get("composition_entries"), "Themed location did not output compositions")
|
||||||
|
|
||||||
|
|
||||||
|
def smoke_category_cast_config_policy() -> None:
|
||||||
|
_expect(pb.CATEGORY_PRESETS is category_cast_config.CATEGORY_PRESETS, "Prompt builder category presets are not delegated")
|
||||||
|
_expect(pb.CAST_PRESETS is category_cast_config.CAST_PRESETS, "Prompt builder cast presets are not delegated")
|
||||||
|
_expect("hardcore_pose" in category_cast_config.category_preset_choices(), "Category preset choices lost hardcore_pose")
|
||||||
|
_expect("custom_counts" in category_cast_config.cast_preset_choices(), "Cast preset choices lost custom_counts")
|
||||||
|
|
||||||
|
category_config = json.loads(pb.build_category_config_json("hardcore_pose", "Foreplay and teasing"))
|
||||||
|
_expect(category_config.get("category") == "Hardcore sexual poses", "Category config lost hardcore category mapping")
|
||||||
|
_expect(category_config.get("subcategory") == "Foreplay and teasing", "Category config lost explicit subcategory")
|
||||||
|
_expect(pb._parse_category_config(category_config) == ("Hardcore sexual poses", "Foreplay and teasing"), "Category parser wrapper drifted")
|
||||||
|
|
||||||
|
fallback_config = json.loads(category_cast_config.build_category_config_json("unknown", "random"))
|
||||||
|
_expect(fallback_config.get("preset") == "auto_weighted", "Unknown category preset did not fall back")
|
||||||
|
_expect(pb._parse_category_config({"preset": "unknown"}) == ("auto_weighted", "random"), "Unknown category parser fallback changed")
|
||||||
|
|
||||||
|
cast_config = json.loads(pb.build_cast_config_json("mixed_couple", 9, 9))
|
||||||
|
_expect((cast_config.get("women_count"), cast_config.get("men_count")) == (1, 1), "Cast preset did not override manual counts")
|
||||||
|
custom_cast = json.loads(category_cast_config.build_cast_config_json("custom_counts", -5, 99))
|
||||||
|
_expect((custom_cast.get("women_count"), custom_cast.get("men_count")) == (0, 12), "Custom cast counts were not clamped")
|
||||||
|
empty_cast = pb._parse_cast_config({"cast_mode": "custom_counts", "women_count": 0, "men_count": 0})
|
||||||
|
_expect((empty_cast.get("women_count"), empty_cast.get("men_count")) == (1, 0), "Empty custom cast was not corrected")
|
||||||
|
|
||||||
|
|
||||||
def smoke_category_library_route() -> None:
|
def smoke_category_library_route() -> None:
|
||||||
categories = category_library.load_category_library()
|
categories = category_library.load_category_library()
|
||||||
_expect(len(categories) >= 3, "category library should load JSON categories")
|
_expect(len(categories) >= 3, "category library should load JSON categories")
|
||||||
@@ -2461,6 +2485,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
|||||||
("camera_scene_single", smoke_camera_scene_single),
|
("camera_scene_single", smoke_camera_scene_single),
|
||||||
("config_route_location_theme", smoke_config_route_location_theme),
|
("config_route_location_theme", smoke_config_route_location_theme),
|
||||||
("location_config_policy", smoke_location_config_policy),
|
("location_config_policy", smoke_location_config_policy),
|
||||||
|
("category_cast_config_policy", smoke_category_cast_config_policy),
|
||||||
("category_library_route", smoke_category_library_route),
|
("category_library_route", smoke_category_library_route),
|
||||||
("hardcore_category_routes", smoke_hardcore_category_routes),
|
("hardcore_category_routes", smoke_hardcore_category_routes),
|
||||||
("krea_close_foreplay_route", smoke_krea_close_foreplay_route),
|
("krea_close_foreplay_route", smoke_krea_close_foreplay_route),
|
||||||
|
|||||||
Reference in New Issue
Block a user