Extract row generation policy

This commit is contained in:
2026-06-27 09:10:51 +02:00
parent 58ddda82d7
commit 23bcb1b526
5 changed files with 278 additions and 91 deletions
+26 -90
View File
@@ -40,6 +40,7 @@ try:
from . import row_normalization as row_policy
from . import row_camera as row_camera_policy
from . import row_expression as row_expression_policy
from . import row_generation as row_generation_policy
from . import row_item as row_item_policy
from . import row_location as row_location_policy
from . import row_pools as row_pool_policy
@@ -84,6 +85,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
import row_normalization as row_policy
import row_camera as row_camera_policy
import row_expression as row_expression_policy
import row_generation as row_generation_policy
import row_item as row_item_policy
import row_location as row_location_policy
import row_pools as row_pool_policy
@@ -771,21 +773,11 @@ def _apply_hardcore_position_config_to_subcategory(
def _ratio_or_none(value: float) -> float | None:
try:
ratio = float(value)
except (TypeError, ValueError):
return None
if ratio < 0:
return None
return max(0.0, min(1.0, ratio))
return row_generation_policy.ratio_or_none(value)
def _clamped_float(value: Any, default: float = 0.5, min_value: float = 0.0, max_value: float = 1.0) -> float:
try:
number = float(value)
except (TypeError, ValueError):
return default
return max(min_value, min(max_value, number))
return row_generation_policy.clamped_float(value, default, min_value, max_value)
def build_seed_config_json(
@@ -1251,35 +1243,19 @@ def _row_seed(seed: int, row_number: int, salt: int = 0) -> int:
def _pick_clothing_mode(rng: random.Random, clothing: str, minimal_ratio: float | None) -> str:
if clothing == "random":
return "minimal" if rng.random() < 0.5 else "full"
if minimal_ratio is None:
return clothing
return "minimal" if rng.random() < minimal_ratio else "full"
return row_generation_policy.pick_clothing_mode(rng, clothing, minimal_ratio)
def _pick_pose_mode(rng: random.Random, poses: str, standard_ratio: float | None) -> str:
if poses == "random":
return "standard" if rng.random() < 0.5 else "evocative"
if standard_ratio is None:
return poses
return "standard" if rng.random() < standard_ratio else "evocative"
return row_generation_policy.pick_pose_mode(rng, poses, standard_ratio)
def _pick_figure_bias(rng: random.Random, figure: str) -> str:
if figure in ("curvy", "balanced", "bombshell"):
return figure
return g.choose(rng, ["curvy", "balanced", "bombshell"])
return row_generation_policy.pick_figure_bias(rng, figure)
def _pick_expression_intensity(rng: random.Random, expression_intensity: Any) -> tuple[float, str]:
try:
value = float(expression_intensity)
except (TypeError, ValueError):
return 0.5, "default"
if value < 0:
return round(rng.random(), 2), "random"
return _clamped_float(value, 0.5), "input"
return row_generation_policy.pick_expression_intensity(rng, expression_intensity)
def _build_auto_weighted_row(
@@ -1296,9 +1272,8 @@ def _build_auto_weighted_row(
standard_pose_ratio: float | None,
seed: int,
) -> dict[str, Any]:
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
rows = g.build_rows(
batch_number * g.BATCH_SIZE,
return row_generation_policy.build_auto_weighted_row(
row_number,
start_index,
clothing,
ethnicity,
@@ -1310,13 +1285,7 @@ def _build_auto_weighted_row(
minimal_clothing_ratio,
standard_pose_ratio,
seed,
g.EXPRESSION_SEED + seed,
)
row = rows[row_number - 1]
row["main_category"] = "auto_weighted"
row["subcategory"] = row.get("primary_subject", "auto")
row["source"] = "built_in_generator"
return row
def _build_direct_builtin_row(
@@ -1334,58 +1303,25 @@ def _build_direct_builtin_row(
standard_pose_ratio: float | None,
seed: int,
) -> dict[str, Any]:
rng = random.Random(_row_seed(seed, row_number))
expr_deck = g.ExpressionDeck(g.EXPRESSIONS, random.Random(_row_seed(g.EXPRESSION_SEED + seed, row_number)))
batch = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
index = start_index + row_number - 1
row_clothing = _pick_clothing_mode(rng, clothing, minimal_clothing_ratio)
row_poses = _pick_pose_mode(rng, poses, standard_pose_ratio)
if category == "woman":
row = g.make_single(
index,
batch,
rng,
"woman",
expr_deck,
row_clothing,
ethnicity,
row_poses,
backside_bias,
figure,
no_plus_women,
no_black,
)
elif category == "man":
row = g.make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_poses, backside_bias, figure)
elif category == "couple":
row = g.make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
elif category == "group_or_layout":
row = g.make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
else:
raise ValueError(f"Unknown built-in category: {category}")
row["main_category"] = category
row["subcategory"] = row.get("pose_mode", category)
row["source"] = "built_in_generator"
return row
return row_generation_policy.build_direct_builtin_row(
category,
row_number,
start_index,
clothing,
ethnicity,
poses,
backside_bias,
figure,
no_plus_women,
no_black,
minimal_clothing_ratio,
standard_pose_ratio,
seed,
)
def _auto_full_choice(seed_config: dict[str, int], seed: int, row_number: int) -> str:
categories = load_category_library()
if not categories:
return "auto_weighted"
category_rng = _axis_rng(seed_config, "category", seed, row_number)
choices: list[dict[str, Any]] = [{"category": "auto_weighted", "weight": 1.0}]
choices.extend(
{
"category": category["name"],
"weight": category.get("weight", 1.0),
}
for category in categories
)
choice = _weighted_choice(category_rng, choices)
return str(choice.get("category") or "auto_weighted")
return row_generation_policy.auto_full_choice(seed_config, seed, row_number)
def _body_phrase(body: Any, figure_note: Any = "") -> str: