Fix custom cast count selection

This commit is contained in:
2026-06-24 11:28:11 +02:00
parent af8fe355f7
commit 3c28de3712
2 changed files with 62 additions and 7 deletions
+55 -5
View File
@@ -1055,6 +1055,39 @@ def _find_category(categories: list[dict[str, Any]], name_or_slug: str) -> dict[
return None
def _base_cast_counts(women_count: int, men_count: int) -> tuple[int, int]:
women_count = max(0, int(women_count))
men_count = max(0, int(men_count))
if women_count + men_count == 0:
women_count = 1
return women_count, men_count
def _counts_for_exact_subcategory(
subcategory: dict[str, Any],
women_count: int,
men_count: int,
) -> tuple[int, int]:
women_count, men_count = _base_cast_counts(women_count, men_count)
min_women = _constraint_int(subcategory, "min_women")
if min_women is not None and women_count < min_women:
women_count = min_women
min_men = _constraint_int(subcategory, "min_men")
if min_men is not None and men_count < min_men:
men_count = min_men
min_people = _constraint_int(subcategory, "min_people")
if min_people is not None:
missing = min_people - (women_count + men_count)
if missing > 0:
if women_count > 0 or men_count == 0:
women_count += missing
else:
men_count += missing
return women_count, men_count
def _find_subcategory(
categories: list[dict[str, Any]],
category_choice: str,
@@ -1063,7 +1096,8 @@ def _find_subcategory(
subcategory_rng: random.Random,
women_count: int = 1,
men_count: int = 1,
) -> tuple[dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any], int, int]:
women_count, men_count = _base_cast_counts(women_count, men_count)
if subcategory_choice and subcategory_choice != RANDOM_SUBCATEGORY and " / " in subcategory_choice:
category_name, subcategory_name = subcategory_choice.split(" / ", 1)
category = _find_category(categories, category_name)
@@ -1072,12 +1106,17 @@ def _find_subcategory(
wanted = subcategory_name.strip().lower()
for subcategory in category["subcategories"]:
if subcategory["name"].lower() == wanted or subcategory["slug"].lower() == wanted:
if not _compatible_entry(subcategory, women_count, men_count):
adjusted_women_count, adjusted_men_count = _counts_for_exact_subcategory(
subcategory,
women_count,
men_count,
)
if not _compatible_entry(subcategory, adjusted_women_count, adjusted_men_count):
raise ValueError(
f"Subcategory '{subcategory['name']}' is not compatible with "
f"women_count={women_count}, men_count={men_count}"
)
return category, subcategory
return category, subcategory, adjusted_women_count, adjusted_men_count
raise ValueError(f"Unknown subcategory '{subcategory_name}' for category '{category_name}'")
if category_choice == "custom_random":
@@ -1090,7 +1129,7 @@ def _find_subcategory(
raise ValueError(f"Unknown custom category: {category_choice}")
subcategories = _compatible_entries(category["subcategories"], women_count, men_count)
subcategory = _weighted_choice(subcategory_rng, subcategories)
return category, subcategory
return category, subcategory, women_count, men_count
def _merged_field(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str, default: Any = None) -> Any:
@@ -1680,7 +1719,9 @@ def _build_custom_row(
expression_rng = _axis_rng(seed_config, "expression", seed, row_number)
composition_rng = _axis_rng(seed_config, "composition", seed, row_number)
category, subcategory = _find_subcategory(
requested_women_count = women_count
requested_men_count = men_count
category, subcategory, women_count, men_count = _find_subcategory(
categories,
category_choice,
subcategory_choice,
@@ -1689,6 +1730,14 @@ def _build_custom_row(
women_count,
men_count,
)
count_adjustment = {}
if women_count != requested_women_count or men_count != requested_men_count:
count_adjustment = {
"requested_women_count": requested_women_count,
"requested_men_count": requested_men_count,
"effective_women_count": women_count,
"effective_men_count": men_count,
}
content_axis = "pose" if _is_pose_content_category(category, subcategory) else "content"
content_rng = _axis_rng(seed_config, content_axis, seed, row_number)
items = _list_from(subcategory.get("items", [subcategory["name"]]))
@@ -1809,6 +1858,7 @@ def _build_custom_row(
"women_count": context.get("women_count", ""),
"men_count": context.get("men_count", ""),
"person_count": context.get("person_count", ""),
"cast_count_adjustment": count_adjustment if subject_type == "configured_cast" else {},
"source": "json_category",
}
)