From 3c28de37126e089c230b5674bb5a4789bfa7c4fc Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 24 Jun 2026 11:28:11 +0200 Subject: [PATCH] Fix custom cast count selection --- README.md | 9 +++++-- prompt_builder.py | 60 +++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bfd37d8..fd28edb 100644 --- a/README.md +++ b/README.md @@ -375,8 +375,13 @@ Supported constraints: - `cast` or `requires`: `women_only`, `men_only`, `mixed`, `has_women`, `has_men`, `solo`, `couple`, `threesome`, `group` -If an exact subcategory is not compatible with `women_count` and `men_count`, -the node raises a clear error instead of generating an impossible prompt. +If an exact subcategory has a larger minimum cast size than the current +`women_count` and `men_count`, the node raises the effective cast count to that +minimum instead of failing. The original and effective counts are recorded in +`metadata_json.cast_count_adjustment`. Other impossible cast constraints still +raise a clear error instead of generating an impossible prompt. +When both cast counts are `0`, custom category selection treats the effective +configured cast as one adult woman so random filtering still has a valid cast. Use the `subcategory` dropdown to select either `random` or an exact `Main category / Subcategory` path. Exact paths override the `category` dropdown, diff --git a/prompt_builder.py b/prompt_builder.py index fa1e037..75895d6 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -1055,6 +1055,39 @@ def _find_category(categories: list[dict[str, Any]], name_or_slug: str) -> dict[ return None +def _base_cast_counts(women_count: int, men_count: int) -> tuple[int, int]: + women_count = max(0, int(women_count)) + men_count = max(0, int(men_count)) + if women_count + men_count == 0: + women_count = 1 + return women_count, men_count + + +def _counts_for_exact_subcategory( + subcategory: dict[str, Any], + women_count: int, + men_count: int, +) -> tuple[int, int]: + women_count, men_count = _base_cast_counts(women_count, men_count) + + min_women = _constraint_int(subcategory, "min_women") + if min_women is not None and women_count < min_women: + women_count = min_women + min_men = _constraint_int(subcategory, "min_men") + if min_men is not None and men_count < min_men: + men_count = min_men + + min_people = _constraint_int(subcategory, "min_people") + if min_people is not None: + missing = min_people - (women_count + men_count) + if missing > 0: + if women_count > 0 or men_count == 0: + women_count += missing + else: + men_count += missing + return women_count, men_count + + def _find_subcategory( categories: list[dict[str, Any]], category_choice: str, @@ -1063,7 +1096,8 @@ def _find_subcategory( subcategory_rng: random.Random, women_count: int = 1, men_count: int = 1, -) -> tuple[dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, Any], dict[str, Any], int, int]: + women_count, men_count = _base_cast_counts(women_count, men_count) if subcategory_choice and subcategory_choice != RANDOM_SUBCATEGORY and " / " in subcategory_choice: category_name, subcategory_name = subcategory_choice.split(" / ", 1) category = _find_category(categories, category_name) @@ -1072,12 +1106,17 @@ def _find_subcategory( wanted = subcategory_name.strip().lower() for subcategory in category["subcategories"]: if subcategory["name"].lower() == wanted or subcategory["slug"].lower() == wanted: - if not _compatible_entry(subcategory, women_count, men_count): + adjusted_women_count, adjusted_men_count = _counts_for_exact_subcategory( + subcategory, + women_count, + men_count, + ) + if not _compatible_entry(subcategory, adjusted_women_count, adjusted_men_count): raise ValueError( f"Subcategory '{subcategory['name']}' is not compatible with " f"women_count={women_count}, men_count={men_count}" ) - return category, subcategory + return category, subcategory, adjusted_women_count, adjusted_men_count raise ValueError(f"Unknown subcategory '{subcategory_name}' for category '{category_name}'") if category_choice == "custom_random": @@ -1090,7 +1129,7 @@ def _find_subcategory( raise ValueError(f"Unknown custom category: {category_choice}") subcategories = _compatible_entries(category["subcategories"], women_count, men_count) subcategory = _weighted_choice(subcategory_rng, subcategories) - return category, subcategory + return category, subcategory, women_count, men_count def _merged_field(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str, default: Any = None) -> Any: @@ -1680,7 +1719,9 @@ def _build_custom_row( expression_rng = _axis_rng(seed_config, "expression", seed, row_number) composition_rng = _axis_rng(seed_config, "composition", seed, row_number) - category, subcategory = _find_subcategory( + requested_women_count = women_count + requested_men_count = men_count + category, subcategory, women_count, men_count = _find_subcategory( categories, category_choice, subcategory_choice, @@ -1689,6 +1730,14 @@ def _build_custom_row( women_count, men_count, ) + count_adjustment = {} + if women_count != requested_women_count or men_count != requested_men_count: + count_adjustment = { + "requested_women_count": requested_women_count, + "requested_men_count": requested_men_count, + "effective_women_count": women_count, + "effective_men_count": men_count, + } content_axis = "pose" if _is_pose_content_category(category, subcategory) else "content" content_rng = _axis_rng(seed_config, content_axis, seed, row_number) items = _list_from(subcategory.get("items", [subcategory["name"]])) @@ -1809,6 +1858,7 @@ def _build_custom_row( "women_count": context.get("women_count", ""), "men_count": context.get("men_count", ""), "person_count": context.get("person_count", ""), + "cast_count_adjustment": count_adjustment if subject_type == "configured_cast" else {}, "source": "json_category", } )