Extract category cast config policy

This commit is contained in:
2026-06-27 00:22:17 +02:00
parent f3f9929df5
commit 4c31553409
6 changed files with 159 additions and 80 deletions
+10 -78
View File
@@ -24,6 +24,7 @@ try:
template_list as _template_list,
)
from . import camera_config as camera_policy
from . import category_cast_config as category_cast_policy
from . import generate_prompt_batches as g
from . import location_config as location_policy
from . import pair_clothing
@@ -62,6 +63,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
template_list as _template_list,
)
import camera_config as camera_policy
import category_cast_config as category_cast_policy
import generate_prompt_batches as g
import location_config as location_policy
import pair_clothing
@@ -1030,26 +1032,8 @@ def seed_mode_choices() -> list[str]:
return seed_policy.seed_mode_choices()
CATEGORY_PRESETS = {
"auto_weighted": ("auto_weighted", RANDOM_SUBCATEGORY),
"auto_full": ("auto_full", RANDOM_SUBCATEGORY),
"women_casual": ("Casual clothes", RANDOM_SUBCATEGORY),
"men_casual": ("Men casual clothes", RANDOM_SUBCATEGORY),
"couple_casual": ("Couple casual clothes", RANDOM_SUBCATEGORY),
"provocative_erotic": ("Provocative erotic clothes", RANDOM_SUBCATEGORY),
"hardcore_pose": ("Hardcore sexual poses", RANDOM_SUBCATEGORY),
"custom_random": ("custom_random", RANDOM_SUBCATEGORY),
}
CAST_PRESETS = {
"solo_woman": (1, 0),
"solo_man": (0, 1),
"mixed_couple": (1, 1),
"two_women": (2, 0),
"two_men": (0, 2),
"threesome_2w1m": (2, 1),
"small_group_3w2m": (3, 2),
}
CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS
CAST_PRESETS = category_cast_policy.CAST_PRESETS
GENERATION_PROFILE_PRESETS = {
"balanced": {
@@ -1122,11 +1106,11 @@ GENERATION_PROFILE_PRESETS = {
def category_preset_choices() -> list[str]:
return list(CATEGORY_PRESETS)
return category_cast_policy.category_preset_choices()
def cast_preset_choices() -> list[str]:
return list(CAST_PRESETS) + ["custom_counts"]
return category_cast_policy.cast_preset_choices()
def generation_profile_choices() -> list[str]:
@@ -1134,71 +1118,19 @@ def generation_profile_choices() -> list[str]:
def build_category_config_json(preset: str = "auto_weighted", subcategory: str = RANDOM_SUBCATEGORY) -> str:
category, default_subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
chosen_subcategory = subcategory if subcategory and subcategory != RANDOM_SUBCATEGORY else default_subcategory
return json.dumps(
{
"preset": preset if preset in CATEGORY_PRESETS else "auto_weighted",
"category": category,
"subcategory": chosen_subcategory,
},
ensure_ascii=True,
sort_keys=True,
)
return category_cast_policy.build_category_config_json(preset=preset, subcategory=subcategory)
def _parse_category_config(category_config: str | dict[str, Any] | None) -> tuple[str, str]:
if not category_config:
return CATEGORY_PRESETS["auto_weighted"]
if isinstance(category_config, dict):
raw = category_config
else:
try:
raw = json.loads(str(category_config))
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid category_config JSON: {exc}") from exc
if not isinstance(raw, dict):
raise ValueError("category_config must be a JSON object")
preset = str(raw.get("preset") or "auto_weighted")
category, subcategory = CATEGORY_PRESETS.get(preset, CATEGORY_PRESETS["auto_weighted"])
category = str(raw.get("category") or category)
subcategory = str(raw.get("subcategory") or subcategory or RANDOM_SUBCATEGORY)
return category, subcategory
return category_cast_policy.parse_category_config(category_config)
def build_cast_config_json(cast_mode: str = "mixed_couple", women_count: int = 1, men_count: int = 1) -> str:
if cast_mode in CAST_PRESETS:
women_count, men_count = CAST_PRESETS[cast_mode]
else:
women_count = max(0, min(12, int(women_count)))
men_count = max(0, min(12, int(men_count)))
if women_count + men_count == 0:
women_count = 1
cast_mode = "custom_counts"
return json.dumps(
{
"cast_mode": cast_mode,
"women_count": int(women_count),
"men_count": int(men_count),
},
ensure_ascii=True,
sort_keys=True,
)
return category_cast_policy.build_cast_config_json(cast_mode=cast_mode, women_count=women_count, men_count=men_count)
def _parse_cast_config(cast_config: str | dict[str, Any] | None) -> dict[str, int | str]:
if not cast_config:
return {"cast_mode": "mixed_couple", "women_count": 1, "men_count": 1}
if isinstance(cast_config, dict):
raw = cast_config
else:
try:
raw = json.loads(str(cast_config))
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid cast_config JSON: {exc}") from exc
if not isinstance(raw, dict):
raise ValueError("cast_config must be a JSON object")
return json.loads(build_cast_config_json(str(raw.get("cast_mode") or "custom_counts"), raw.get("women_count", 1), raw.get("men_count", 1)))
return category_cast_policy.parse_cast_config(cast_config)
def build_generation_profile_json(