175 lines
5.4 KiB
Python
175 lines
5.4 KiB
Python
from __future__ import annotations
|
|
|
|
import random
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import category_library as category_policy
|
|
from . import generate_prompt_batches as g
|
|
from . import row_item as row_item_policy
|
|
from . import seed_config as seed_policy
|
|
except ImportError: # Allows local smoke tests with top-level imports.
|
|
import category_library as category_policy
|
|
import generate_prompt_batches as g
|
|
import row_item as row_item_policy
|
|
import seed_config as seed_policy
|
|
|
|
|
|
def ratio_or_none(value: float) -> float | None:
|
|
try:
|
|
ratio = float(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
if ratio < 0:
|
|
return None
|
|
return max(0.0, min(1.0, ratio))
|
|
|
|
|
|
def clamped_float(value: Any, default: float = 0.5, min_value: float = 0.0, max_value: float = 1.0) -> float:
    """Convert *value* to ``float`` and clamp it into ``[min_value, max_value]``.

    Returns *default* (not clamped) when *value* cannot be converted,
    overflows a float, or is NaN.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default

    # NaN compares False against everything, so min()/max() would silently
    # turn it into max_value instead of treating it as unusable input.
    if math.isnan(number):
        return default

    return max(min_value, min(max_value, number))
|
|
|
|
|
|
def pick_clothing_mode(rng: random.Random, clothing: str, minimal_ratio: float | None) -> str:
|
|
if clothing == "random":
|
|
return "minimal" if rng.random() < 0.5 else "full"
|
|
if minimal_ratio is None:
|
|
return clothing
|
|
return "minimal" if rng.random() < minimal_ratio else "full"
|
|
|
|
|
|
def pick_pose_mode(rng: random.Random, poses: str, standard_ratio: float | None) -> str:
|
|
if poses == "random":
|
|
return "standard" if rng.random() < 0.5 else "evocative"
|
|
if standard_ratio is None:
|
|
return poses
|
|
return "standard" if rng.random() < standard_ratio else "evocative"
|
|
|
|
|
|
def pick_figure_bias(rng: random.Random, figure: str) -> str:
    """Return *figure* if it is a known figure bias; otherwise pick one via ``g.choose``.

    The known biases are ``"curvy"``, ``"balanced"`` and ``"bombshell"``.
    Any other value (e.g. a "random" option) delegates to ``g.choose`` with
    *rng*.
    """
    known = ["curvy", "balanced", "bombshell"]
    if figure not in known:
        return g.choose(rng, known)
    return figure
|
|
|
|
|
|
def pick_expression_intensity(rng: random.Random, expression_intensity: Any) -> tuple[float, str]:
    """Resolve a row's expression intensity and report where it came from.

    Returns ``(value, source)``:

    * ``(0.5, "default")`` if *expression_intensity* is not a usable number
      (unconvertible, too large for a float, or NaN);
    * ``(rng.random() rounded to 2 places, "random")`` if it is negative, so
      a negative input acts as a "randomize" sentinel and draws once from *rng*;
    * ``(clamped_float(value, 0.5), "input")`` otherwise.
    """
    import math

    try:
        value = float(expression_intensity)
    except (TypeError, ValueError, OverflowError):
        return 0.5, "default"

    # Without this check NaN would fall through to the clamp and be reported
    # as an explicit "input" even though no usable number was supplied.
    if math.isnan(value):
        return 0.5, "default"

    if value < 0:
        return round(rng.random(), 2), "random"

    return clamped_float(value, 0.5), "input"
|
|
|
|
|
|
def build_auto_weighted_row(
|
|
row_number: int,
|
|
start_index: int,
|
|
clothing: str,
|
|
ethnicity: str,
|
|
poses: str,
|
|
backside_bias: float,
|
|
figure: str,
|
|
no_plus_women: bool,
|
|
no_black: bool,
|
|
minimal_clothing_ratio: float | None,
|
|
standard_pose_ratio: float | None,
|
|
seed: int,
|
|
) -> dict[str, Any]:
|
|
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
|
rows = g.build_rows(
|
|
batch_number * g.BATCH_SIZE,
|
|
start_index,
|
|
clothing,
|
|
ethnicity,
|
|
poses,
|
|
backside_bias,
|
|
figure,
|
|
no_plus_women,
|
|
no_black,
|
|
minimal_clothing_ratio,
|
|
standard_pose_ratio,
|
|
seed,
|
|
g.EXPRESSION_SEED + seed,
|
|
)
|
|
row = rows[row_number - 1]
|
|
row["main_category"] = "auto_weighted"
|
|
row["subcategory"] = row.get("primary_subject", "auto")
|
|
row["source"] = "built_in_generator"
|
|
return row
|
|
|
|
|
|
def build_direct_builtin_row(
|
|
category: str,
|
|
row_number: int,
|
|
start_index: int,
|
|
clothing: str,
|
|
ethnicity: str,
|
|
poses: str,
|
|
backside_bias: float,
|
|
figure: str,
|
|
no_plus_women: bool,
|
|
no_black: bool,
|
|
minimal_clothing_ratio: float | None,
|
|
standard_pose_ratio: float | None,
|
|
seed: int,
|
|
) -> dict[str, Any]:
|
|
rng = random.Random(seed_policy.row_seed(seed, row_number))
|
|
expr_deck = g.ExpressionDeck(
|
|
g.EXPRESSIONS,
|
|
random.Random(seed_policy.row_seed(g.EXPRESSION_SEED + seed, row_number)),
|
|
)
|
|
batch = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
|
index = start_index + row_number - 1
|
|
row_clothing = pick_clothing_mode(rng, clothing, minimal_clothing_ratio)
|
|
row_poses = pick_pose_mode(rng, poses, standard_pose_ratio)
|
|
|
|
if category == "woman":
|
|
row = g.make_single(
|
|
index,
|
|
batch,
|
|
rng,
|
|
"woman",
|
|
expr_deck,
|
|
row_clothing,
|
|
ethnicity,
|
|
row_poses,
|
|
backside_bias,
|
|
figure,
|
|
no_plus_women,
|
|
no_black,
|
|
)
|
|
elif category == "man":
|
|
row = g.make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_poses, backside_bias, figure)
|
|
elif category == "couple":
|
|
row = g.make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
|
|
elif category == "group_or_layout":
|
|
row = g.make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus_women)
|
|
else:
|
|
raise ValueError(f"Unknown built-in category: {category}")
|
|
|
|
row["main_category"] = category
|
|
row["subcategory"] = row.get("pose_mode", category)
|
|
row["source"] = "built_in_generator"
|
|
return row
|
|
|
|
|
|
def auto_full_choice(seed_config: dict[str, int], seed: int, row_number: int) -> str:
    """Pick which category to use for a row in "auto full" mode.

    The category library is loaded via ``category_policy.load_category_library``.
    If it is empty, the result is always ``"auto_weighted"``. Otherwise a
    weighted draw is made over ``"auto_weighted"`` (weight 1.0) plus each
    library entry's ``"name"`` (weight from its ``"weight"`` key, default
    1.0). The draw uses the per-row ``"category"`` axis RNG from
    ``seed_policy``. A falsy pick falls back to ``"auto_weighted"``.
    """
    fallback = "auto_weighted"

    library = category_policy.load_category_library()
    if not library:
        return fallback

    axis_rng = seed_policy.axis_rng(seed_config, "category", seed, row_number)

    options: list[dict[str, Any]] = [
        {"category": fallback, "weight": 1.0},
        *(
            {"category": entry["name"], "weight": entry.get("weight", 1.0)}
            for entry in library
        ),
    ]

    picked = row_item_policy.weighted_choice(axis_rng, options)
    return str(picked.get("category") or fallback)
|