Extract row generation policy
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import category_library as category_policy
|
||||
from . import generate_prompt_batches as g
|
||||
from . import row_item as row_item_policy
|
||||
from . import seed_config as seed_policy
|
||||
except ImportError: # Allows local smoke tests with top-level imports.
|
||||
import category_library as category_policy
|
||||
import generate_prompt_batches as g
|
||||
import row_item as row_item_policy
|
||||
import seed_config as seed_policy
|
||||
|
||||
|
||||
def ratio_or_none(value: float) -> float | None:
|
||||
try:
|
||||
ratio = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if ratio < 0:
|
||||
return None
|
||||
return max(0.0, min(1.0, ratio))
|
||||
|
||||
|
||||
def clamped_float(value: Any, default: float = 0.5, min_value: float = 0.0, max_value: float = 1.0) -> float:
|
||||
try:
|
||||
number = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return max(min_value, min(max_value, number))
|
||||
|
||||
|
||||
def pick_clothing_mode(rng: random.Random, clothing: str, minimal_ratio: float | None) -> str:
|
||||
if clothing == "random":
|
||||
return "minimal" if rng.random() < 0.5 else "full"
|
||||
if minimal_ratio is None:
|
||||
return clothing
|
||||
return "minimal" if rng.random() < minimal_ratio else "full"
|
||||
|
||||
|
||||
def pick_pose_mode(rng: random.Random, poses: str, standard_ratio: float | None) -> str:
|
||||
if poses == "random":
|
||||
return "standard" if rng.random() < 0.5 else "evocative"
|
||||
if standard_ratio is None:
|
||||
return poses
|
||||
return "standard" if rng.random() < standard_ratio else "evocative"
|
||||
|
||||
|
||||
def pick_figure_bias(rng: random.Random, figure: str) -> str:
|
||||
if figure in ("curvy", "balanced", "bombshell"):
|
||||
return figure
|
||||
return g.choose(rng, ["curvy", "balanced", "bombshell"])
|
||||
|
||||
|
||||
def pick_expression_intensity(rng: random.Random, expression_intensity: Any) -> tuple[float, str]:
|
||||
try:
|
||||
value = float(expression_intensity)
|
||||
except (TypeError, ValueError):
|
||||
return 0.5, "default"
|
||||
if value < 0:
|
||||
return round(rng.random(), 2), "random"
|
||||
return clamped_float(value, 0.5), "input"
|
||||
|
||||
|
||||
def build_auto_weighted_row(
|
||||
row_number: int,
|
||||
start_index: int,
|
||||
clothing: str,
|
||||
ethnicity: str,
|
||||
poses: str,
|
||||
backside_bias: float,
|
||||
figure: str,
|
||||
no_plus_women: bool,
|
||||
no_black: bool,
|
||||
minimal_clothing_ratio: float | None,
|
||||
standard_pose_ratio: float | None,
|
||||
seed: int,
|
||||
) -> dict[str, Any]:
|
||||
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
||||
rows = g.build_rows(
|
||||
batch_number * g.BATCH_SIZE,
|
||||
start_index,
|
||||
clothing,
|
||||
ethnicity,
|
||||
poses,
|
||||
backside_bias,
|
||||
figure,
|
||||
no_plus_women,
|
||||
no_black,
|
||||
minimal_clothing_ratio,
|
||||
standard_pose_ratio,
|
||||
seed,
|
||||
g.EXPRESSION_SEED + seed,
|
||||
)
|
||||
row = rows[row_number - 1]
|
||||
row["main_category"] = "auto_weighted"
|
||||
row["subcategory"] = row.get("primary_subject", "auto")
|
||||
row["source"] = "built_in_generator"
|
||||
return row
|
||||
|
||||
|
||||
def build_direct_builtin_row(
|
||||
category: str,
|
||||
row_number: int,
|
||||
start_index: int,
|
||||
clothing: str,
|
||||
ethnicity: str,
|
||||
poses: str,
|
||||
backside_bias: float,
|
||||
figure: str,
|
||||
no_plus_women: bool,
|
||||
no_black: bool,
|
||||
minimal_clothing_ratio: float | None,
|
||||
standard_pose_ratio: float | None,
|
||||
seed: int,
|
||||
) -> dict[str, Any]:
|
||||
rng = random.Random(seed_policy.row_seed(seed, row_number))
|
||||
expr_deck = g.ExpressionDeck(
|
||||
g.EXPRESSIONS,
|
||||
random.Random(seed_policy.row_seed(g.EXPRESSION_SEED + seed, row_number)),
|
||||
)
|
||||
batch = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
||||
index = start_index + row_number - 1
|
||||
row_clothing = pick_clothing_mode(rng, clothing, minimal_clothing_ratio)
|
||||
row_poses = pick_pose_mode(rng, poses, standard_pose_ratio)
|
||||
|
||||
if category == "woman":
|
||||
row = g.make_single(
|
||||
index,
|
||||
batch,
|
||||
rng,
|
||||
"woman",
|
||||
expr_deck,
|
||||
row_clothing,
|
||||
ethnicity,
|
||||
row_poses,
|
||||
backside_bias,
|
||||
figure,
|
||||
no_plus_women,
|
||||
no_black,
|
||||
)
|
||||
elif category == "man":
|
||||
row = g.make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_poses, backside_bias, figure)
|
||||
elif category == "couple":
|
||||
row = g.make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
|
||||
elif category == "group_or_layout":
|
||||
row = g.make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
|
||||
else:
|
||||
raise ValueError(f"Unknown built-in category: {category}")
|
||||
|
||||
row["main_category"] = category
|
||||
row["subcategory"] = row.get("pose_mode", category)
|
||||
row["source"] = "built_in_generator"
|
||||
return row
|
||||
|
||||
|
||||
def auto_full_choice(seed_config: dict[str, int], seed: int, row_number: int) -> str:
|
||||
categories = category_policy.load_category_library()
|
||||
if not categories:
|
||||
return "auto_weighted"
|
||||
category_rng = seed_policy.axis_rng(seed_config, "category", seed, row_number)
|
||||
choices: list[dict[str, Any]] = [{"category": "auto_weighted", "weight": 1.0}]
|
||||
choices.extend(
|
||||
{
|
||||
"category": category["name"],
|
||||
"weight": category.get("weight", 1.0),
|
||||
}
|
||||
for category in categories
|
||||
)
|
||||
choice = row_item_policy.weighted_choice(category_rng, choices)
|
||||
return str(choice.get("category") or "auto_weighted")
|
||||
Reference in New Issue
Block a user