Wire seed config into built-in clothing routes
This commit is contained in:
@@ -204,6 +204,7 @@ def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependenci
|
||||
minimal_ratio,
|
||||
pose_ratio,
|
||||
seed,
|
||||
seed_config=parsed_seed_config,
|
||||
)
|
||||
else:
|
||||
row = deps.build_custom_row(
|
||||
|
||||
@@ -3232,8 +3232,9 @@ def make_group_or_layout(index: int, batch: int, rng: random.Random, expr_deck:
|
||||
return row
|
||||
|
||||
|
||||
def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED) -> list[dict]:
|
||||
def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED, clothing_rng: random.Random | None = None) -> list[dict]:
|
||||
rng = random.Random(seed)
|
||||
wardrobe_rng = clothing_rng or rng
|
||||
expr_deck = ExpressionDeck(EXPRESSIONS, random.Random(expression_seed))
|
||||
rows: list[dict] = []
|
||||
batch_quotas = batch_category_quotas()
|
||||
@@ -3245,21 +3246,21 @@ def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity:
|
||||
index = start_index
|
||||
for batch in range(1, batch_count + 1):
|
||||
batch_rows: list[dict] = []
|
||||
clothing_modes = batch_clothing_modes(rng, clothing, minimal_clothing_ratio)
|
||||
clothing_modes = batch_clothing_modes(wardrobe_rng, clothing, minimal_clothing_ratio)
|
||||
single_pose_modes = batch_single_pose_modes(rng, poses, standard_pose_ratio, single_subject_count)
|
||||
for category, count in batch_quotas:
|
||||
for _ in range(count):
|
||||
row_clothing = clothing_modes.pop()
|
||||
if category == "woman":
|
||||
row_pose = single_pose_modes.pop()
|
||||
row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black)
|
||||
row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||
elif category == "man":
|
||||
row_pose = single_pose_modes.pop()
|
||||
row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black)
|
||||
row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||
elif category == "couple":
|
||||
row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus)
|
||||
row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||
else:
|
||||
row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus)
|
||||
row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||
batch_rows.append(row)
|
||||
index += 1
|
||||
rng.shuffle(batch_rows)
|
||||
|
||||
@@ -1410,6 +1410,7 @@ def _build_auto_weighted_row(
|
||||
minimal_clothing_ratio: float | None,
|
||||
standard_pose_ratio: float | None,
|
||||
seed: int,
|
||||
seed_config: dict[str, int] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return row_generation_policy.build_auto_weighted_row(
|
||||
row_number,
|
||||
@@ -1424,6 +1425,7 @@ def _build_auto_weighted_row(
|
||||
minimal_clothing_ratio,
|
||||
standard_pose_ratio,
|
||||
seed,
|
||||
seed_config=seed_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -78,8 +78,12 @@ def build_auto_weighted_row(
|
||||
minimal_clothing_ratio: float | None,
|
||||
standard_pose_ratio: float | None,
|
||||
seed: int,
|
||||
seed_config: dict[str, int] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
||||
clothing_rng = None
|
||||
if seed_config is not None:
|
||||
clothing_rng = seed_policy.axis_rng(seed_config, "clothing", seed, row_number)
|
||||
rows = g.build_rows(
|
||||
batch_number * g.BATCH_SIZE,
|
||||
start_index,
|
||||
@@ -94,6 +98,7 @@ def build_auto_weighted_row(
|
||||
standard_pose_ratio,
|
||||
seed,
|
||||
g.EXPRESSION_SEED + seed,
|
||||
clothing_rng=clothing_rng,
|
||||
)
|
||||
row = rows[row_number - 1]
|
||||
row["main_category"] = "auto_weighted"
|
||||
|
||||
+60
-4
@@ -479,6 +479,10 @@ def _prompt_row(
|
||||
subcategory: str,
|
||||
seed: int,
|
||||
seed_config: str | dict[str, Any] | None = None,
|
||||
clothing: str = "random",
|
||||
poses: str = "random",
|
||||
minimal_clothing_ratio: float = 0.5,
|
||||
standard_pose_ratio: float = 0.5,
|
||||
character_cast: str = "",
|
||||
women_count: int = 1,
|
||||
men_count: int = 1,
|
||||
@@ -494,15 +498,15 @@ def _prompt_row(
|
||||
row_number=1,
|
||||
start_index=1,
|
||||
seed=seed,
|
||||
clothing="random",
|
||||
clothing=clothing,
|
||||
ethnicity="any",
|
||||
poses="random",
|
||||
poses=poses,
|
||||
backside_bias=0.35,
|
||||
figure="random",
|
||||
no_plus_women=False,
|
||||
no_black=False,
|
||||
minimal_clothing_ratio=0.5,
|
||||
standard_pose_ratio=0.5,
|
||||
minimal_clothing_ratio=minimal_clothing_ratio,
|
||||
standard_pose_ratio=standard_pose_ratio,
|
||||
trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
extra_positive="",
|
||||
@@ -14394,6 +14398,58 @@ def smoke_seed_config_policy() -> None:
|
||||
|
||||
clothing_axis_seed = 42001
|
||||
|
||||
auto_weighted_seeded = _prompt_row(
|
||||
name="seed_config_policy_auto_weighted_seed_config",
|
||||
category="auto_weighted",
|
||||
subcategory="random",
|
||||
seed=clothing_axis_seed,
|
||||
seed_config=pb.build_seed_lock_config_json(base_seed=clothing_axis_seed),
|
||||
women_count=1,
|
||||
men_count=0,
|
||||
)
|
||||
_expect(auto_weighted_seeded.get("source") == "built_in_generator", "auto_weighted prompt with seed_config should build")
|
||||
|
||||
def direct_builtin_woman(seed_config_value: str | dict[str, Any], *, name: str) -> dict[str, Any]:
|
||||
return _prompt_row(
|
||||
name=name,
|
||||
category="woman",
|
||||
subcategory="random",
|
||||
seed=43001,
|
||||
seed_config=seed_config_value,
|
||||
clothing="full",
|
||||
poses="standard",
|
||||
minimal_clothing_ratio=-1,
|
||||
standard_pose_ratio=-1,
|
||||
women_count=1,
|
||||
men_count=0,
|
||||
)
|
||||
|
||||
direct_builtin_locked = direct_builtin_woman(
|
||||
pb.build_seed_lock_config_json(base_seed=43001),
|
||||
name="seed_config_policy_direct_builtin_locked",
|
||||
)
|
||||
direct_builtin_clothing_reroll = direct_builtin_woman(
|
||||
pb.build_seed_lock_config_json(base_seed=43001, reroll_axis="clothing", reroll_seed=43002),
|
||||
name="seed_config_policy_direct_builtin_clothing_reroll",
|
||||
)
|
||||
direct_builtin_stable_fields = ("primary_subject", "age_band", "body_type", "scene", "composition", "pose_mode")
|
||||
_expect(
|
||||
tuple(direct_builtin_locked.get(key) for key in direct_builtin_stable_fields)
|
||||
== tuple(direct_builtin_clothing_reroll.get(key) for key in direct_builtin_stable_fields),
|
||||
"Prompt-level direct built-in clothing reroll should keep non-clothing fields stable",
|
||||
)
|
||||
_expect(
|
||||
(
|
||||
direct_builtin_locked.get("prompt"),
|
||||
direct_builtin_locked.get("caption"),
|
||||
)
|
||||
!= (
|
||||
direct_builtin_clothing_reroll.get("prompt"),
|
||||
direct_builtin_clothing_reroll.get("caption"),
|
||||
),
|
||||
"Prompt-level direct built-in clothing reroll should change clothing text",
|
||||
)
|
||||
|
||||
def clothing_category_row(seed_config_value: str | dict[str, Any], *, name: str = "seed_config_policy_clothing_axis") -> dict[str, Any]:
|
||||
return pb.build_prompt(
|
||||
category="Casual clothes",
|
||||
|
||||
Reference in New Issue
Block a user