Wire seed config into built-in clothing routes

This commit is contained in:
2026-07-01 16:50:46 +02:00
parent 8c3f61ea6d
commit 12c5f73104
5 changed files with 75 additions and 10 deletions
+1
View File
@@ -204,6 +204,7 @@ def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependenci
minimal_ratio, minimal_ratio,
pose_ratio, pose_ratio,
seed, seed,
seed_config=parsed_seed_config,
) )
else: else:
row = deps.build_custom_row( row = deps.build_custom_row(
+7 -6
View File
@@ -3232,8 +3232,9 @@ def make_group_or_layout(index: int, batch: int, rng: random.Random, expr_deck:
return row return row
def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED) -> list[dict]: def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED, clothing_rng: random.Random | None = None) -> list[dict]:
rng = random.Random(seed) rng = random.Random(seed)
wardrobe_rng = clothing_rng or rng
expr_deck = ExpressionDeck(EXPRESSIONS, random.Random(expression_seed)) expr_deck = ExpressionDeck(EXPRESSIONS, random.Random(expression_seed))
rows: list[dict] = [] rows: list[dict] = []
batch_quotas = batch_category_quotas() batch_quotas = batch_category_quotas()
@@ -3245,21 +3246,21 @@ def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity:
index = start_index index = start_index
for batch in range(1, batch_count + 1): for batch in range(1, batch_count + 1):
batch_rows: list[dict] = [] batch_rows: list[dict] = []
clothing_modes = batch_clothing_modes(rng, clothing, minimal_clothing_ratio) clothing_modes = batch_clothing_modes(wardrobe_rng, clothing, minimal_clothing_ratio)
single_pose_modes = batch_single_pose_modes(rng, poses, standard_pose_ratio, single_subject_count) single_pose_modes = batch_single_pose_modes(rng, poses, standard_pose_ratio, single_subject_count)
for category, count in batch_quotas: for category, count in batch_quotas:
for _ in range(count): for _ in range(count):
row_clothing = clothing_modes.pop() row_clothing = clothing_modes.pop()
if category == "woman": if category == "woman":
row_pose = single_pose_modes.pop() row_pose = single_pose_modes.pop()
row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black) row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
elif category == "man": elif category == "man":
row_pose = single_pose_modes.pop() row_pose = single_pose_modes.pop()
row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black) row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
elif category == "couple": elif category == "couple":
row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus) row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
else: else:
row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus) row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
batch_rows.append(row) batch_rows.append(row)
index += 1 index += 1
rng.shuffle(batch_rows) rng.shuffle(batch_rows)
+2
View File
@@ -1410,6 +1410,7 @@ def _build_auto_weighted_row(
minimal_clothing_ratio: float | None, minimal_clothing_ratio: float | None,
standard_pose_ratio: float | None, standard_pose_ratio: float | None,
seed: int, seed: int,
seed_config: dict[str, int] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
return row_generation_policy.build_auto_weighted_row( return row_generation_policy.build_auto_weighted_row(
row_number, row_number,
@@ -1424,6 +1425,7 @@ def _build_auto_weighted_row(
minimal_clothing_ratio, minimal_clothing_ratio,
standard_pose_ratio, standard_pose_ratio,
seed, seed,
seed_config=seed_config,
) )
+5
View File
@@ -78,8 +78,12 @@ def build_auto_weighted_row(
minimal_clothing_ratio: float | None, minimal_clothing_ratio: float | None,
standard_pose_ratio: float | None, standard_pose_ratio: float | None,
seed: int, seed: int,
seed_config: dict[str, int] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1) batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
clothing_rng = None
if seed_config is not None:
clothing_rng = seed_policy.axis_rng(seed_config, "clothing", seed, row_number)
rows = g.build_rows( rows = g.build_rows(
batch_number * g.BATCH_SIZE, batch_number * g.BATCH_SIZE,
start_index, start_index,
@@ -94,6 +98,7 @@ def build_auto_weighted_row(
standard_pose_ratio, standard_pose_ratio,
seed, seed,
g.EXPRESSION_SEED + seed, g.EXPRESSION_SEED + seed,
clothing_rng=clothing_rng,
) )
row = rows[row_number - 1] row = rows[row_number - 1]
row["main_category"] = "auto_weighted" row["main_category"] = "auto_weighted"
+60 -4
View File
@@ -479,6 +479,10 @@ def _prompt_row(
subcategory: str, subcategory: str,
seed: int, seed: int,
seed_config: str | dict[str, Any] | None = None, seed_config: str | dict[str, Any] | None = None,
clothing: str = "random",
poses: str = "random",
minimal_clothing_ratio: float = 0.5,
standard_pose_ratio: float = 0.5,
character_cast: str = "", character_cast: str = "",
women_count: int = 1, women_count: int = 1,
men_count: int = 1, men_count: int = 1,
@@ -494,15 +498,15 @@ def _prompt_row(
row_number=1, row_number=1,
start_index=1, start_index=1,
seed=seed, seed=seed,
clothing="random", clothing=clothing,
ethnicity="any", ethnicity="any",
poses="random", poses=poses,
backside_bias=0.35, backside_bias=0.35,
figure="random", figure="random",
no_plus_women=False, no_plus_women=False,
no_black=False, no_black=False,
minimal_clothing_ratio=0.5, minimal_clothing_ratio=minimal_clothing_ratio,
standard_pose_ratio=0.5, standard_pose_ratio=standard_pose_ratio,
trigger=Trigger, trigger=Trigger,
prepend_trigger_to_prompt=True, prepend_trigger_to_prompt=True,
extra_positive="", extra_positive="",
@@ -14394,6 +14398,58 @@ def smoke_seed_config_policy() -> None:
clothing_axis_seed = 42001 clothing_axis_seed = 42001
auto_weighted_seeded = _prompt_row(
name="seed_config_policy_auto_weighted_seed_config",
category="auto_weighted",
subcategory="random",
seed=clothing_axis_seed,
seed_config=pb.build_seed_lock_config_json(base_seed=clothing_axis_seed),
women_count=1,
men_count=0,
)
_expect(auto_weighted_seeded.get("source") == "built_in_generator", "auto_weighted prompt with seed_config should build")
def direct_builtin_woman(seed_config_value: str | dict[str, Any], *, name: str) -> dict[str, Any]:
return _prompt_row(
name=name,
category="woman",
subcategory="random",
seed=43001,
seed_config=seed_config_value,
clothing="full",
poses="standard",
minimal_clothing_ratio=-1,
standard_pose_ratio=-1,
women_count=1,
men_count=0,
)
direct_builtin_locked = direct_builtin_woman(
pb.build_seed_lock_config_json(base_seed=43001),
name="seed_config_policy_direct_builtin_locked",
)
direct_builtin_clothing_reroll = direct_builtin_woman(
pb.build_seed_lock_config_json(base_seed=43001, reroll_axis="clothing", reroll_seed=43002),
name="seed_config_policy_direct_builtin_clothing_reroll",
)
direct_builtin_stable_fields = ("primary_subject", "age_band", "body_type", "scene", "composition", "pose_mode")
_expect(
tuple(direct_builtin_locked.get(key) for key in direct_builtin_stable_fields)
== tuple(direct_builtin_clothing_reroll.get(key) for key in direct_builtin_stable_fields),
"Prompt-level direct built-in clothing reroll should keep non-clothing fields stable",
)
_expect(
(
direct_builtin_locked.get("prompt"),
direct_builtin_locked.get("caption"),
)
!= (
direct_builtin_clothing_reroll.get("prompt"),
direct_builtin_clothing_reroll.get("caption"),
),
"Prompt-level direct built-in clothing reroll should change clothing text",
)
def clothing_category_row(seed_config_value: str | dict[str, Any], *, name: str = "seed_config_policy_clothing_axis") -> dict[str, Any]: def clothing_category_row(seed_config_value: str | dict[str, Any], *, name: str = "seed_config_policy_clothing_axis") -> dict[str, Any]:
return pb.build_prompt( return pb.build_prompt(
category="Casual clothes", category="Casual clothes",