Wire seed config into built-in clothing routes
This commit is contained in:
@@ -204,6 +204,7 @@ def build_prompt_result(request: PromptBuildRequest, deps: PromptBuildDependenci
|
|||||||
minimal_ratio,
|
minimal_ratio,
|
||||||
pose_ratio,
|
pose_ratio,
|
||||||
seed,
|
seed,
|
||||||
|
seed_config=parsed_seed_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
row = deps.build_custom_row(
|
row = deps.build_custom_row(
|
||||||
|
|||||||
@@ -3232,8 +3232,9 @@ def make_group_or_layout(index: int, batch: int, rng: random.Random, expr_deck:
|
|||||||
return row
|
return row
|
||||||
|
|
||||||
|
|
||||||
def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED) -> list[dict]:
|
def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity: str = "any", poses: str = "standard", backside_bias: float = 0.0, figure: str = "curvy", no_plus: bool = False, no_black: bool = False, minimal_clothing_ratio: float | None = None, standard_pose_ratio: float | None = None, seed: int = DEFAULT_RNG_SEED, expression_seed: int = EXPRESSION_SEED, clothing_rng: random.Random | None = None) -> list[dict]:
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
|
wardrobe_rng = clothing_rng or rng
|
||||||
expr_deck = ExpressionDeck(EXPRESSIONS, random.Random(expression_seed))
|
expr_deck = ExpressionDeck(EXPRESSIONS, random.Random(expression_seed))
|
||||||
rows: list[dict] = []
|
rows: list[dict] = []
|
||||||
batch_quotas = batch_category_quotas()
|
batch_quotas = batch_category_quotas()
|
||||||
@@ -3245,21 +3246,21 @@ def build_rows(total: int, start_index: int, clothing: str = "full", ethnicity:
|
|||||||
index = start_index
|
index = start_index
|
||||||
for batch in range(1, batch_count + 1):
|
for batch in range(1, batch_count + 1):
|
||||||
batch_rows: list[dict] = []
|
batch_rows: list[dict] = []
|
||||||
clothing_modes = batch_clothing_modes(rng, clothing, minimal_clothing_ratio)
|
clothing_modes = batch_clothing_modes(wardrobe_rng, clothing, minimal_clothing_ratio)
|
||||||
single_pose_modes = batch_single_pose_modes(rng, poses, standard_pose_ratio, single_subject_count)
|
single_pose_modes = batch_single_pose_modes(rng, poses, standard_pose_ratio, single_subject_count)
|
||||||
for category, count in batch_quotas:
|
for category, count in batch_quotas:
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
row_clothing = clothing_modes.pop()
|
row_clothing = clothing_modes.pop()
|
||||||
if category == "woman":
|
if category == "woman":
|
||||||
row_pose = single_pose_modes.pop()
|
row_pose = single_pose_modes.pop()
|
||||||
row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black)
|
row = make_single(index, batch, rng, "woman", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||||
elif category == "man":
|
elif category == "man":
|
||||||
row_pose = single_pose_modes.pop()
|
row_pose = single_pose_modes.pop()
|
||||||
row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black)
|
row = make_single(index, batch, rng, "man", expr_deck, row_clothing, ethnicity, row_pose, backside_bias, figure, no_plus, no_black, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||||
elif category == "couple":
|
elif category == "couple":
|
||||||
row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus)
|
row = make_couple(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||||
else:
|
else:
|
||||||
row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus)
|
row = make_group_or_layout(index, batch, rng, expr_deck, row_clothing, ethnicity, no_plus, clothing_rng=wardrobe_rng if clothing_rng else None)
|
||||||
batch_rows.append(row)
|
batch_rows.append(row)
|
||||||
index += 1
|
index += 1
|
||||||
rng.shuffle(batch_rows)
|
rng.shuffle(batch_rows)
|
||||||
|
|||||||
@@ -1410,6 +1410,7 @@ def _build_auto_weighted_row(
|
|||||||
minimal_clothing_ratio: float | None,
|
minimal_clothing_ratio: float | None,
|
||||||
standard_pose_ratio: float | None,
|
standard_pose_ratio: float | None,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
seed_config: dict[str, int] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return row_generation_policy.build_auto_weighted_row(
|
return row_generation_policy.build_auto_weighted_row(
|
||||||
row_number,
|
row_number,
|
||||||
@@ -1424,6 +1425,7 @@ def _build_auto_weighted_row(
|
|||||||
minimal_clothing_ratio,
|
minimal_clothing_ratio,
|
||||||
standard_pose_ratio,
|
standard_pose_ratio,
|
||||||
seed,
|
seed,
|
||||||
|
seed_config=seed_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -78,8 +78,12 @@ def build_auto_weighted_row(
|
|||||||
minimal_clothing_ratio: float | None,
|
minimal_clothing_ratio: float | None,
|
||||||
standard_pose_ratio: float | None,
|
standard_pose_ratio: float | None,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
seed_config: dict[str, int] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
batch_number = max(1, ((row_number - 1) // g.BATCH_SIZE) + 1)
|
||||||
|
clothing_rng = None
|
||||||
|
if seed_config is not None:
|
||||||
|
clothing_rng = seed_policy.axis_rng(seed_config, "clothing", seed, row_number)
|
||||||
rows = g.build_rows(
|
rows = g.build_rows(
|
||||||
batch_number * g.BATCH_SIZE,
|
batch_number * g.BATCH_SIZE,
|
||||||
start_index,
|
start_index,
|
||||||
@@ -94,6 +98,7 @@ def build_auto_weighted_row(
|
|||||||
standard_pose_ratio,
|
standard_pose_ratio,
|
||||||
seed,
|
seed,
|
||||||
g.EXPRESSION_SEED + seed,
|
g.EXPRESSION_SEED + seed,
|
||||||
|
clothing_rng=clothing_rng,
|
||||||
)
|
)
|
||||||
row = rows[row_number - 1]
|
row = rows[row_number - 1]
|
||||||
row["main_category"] = "auto_weighted"
|
row["main_category"] = "auto_weighted"
|
||||||
|
|||||||
+60
-4
@@ -479,6 +479,10 @@ def _prompt_row(
|
|||||||
subcategory: str,
|
subcategory: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
seed_config: str | dict[str, Any] | None = None,
|
seed_config: str | dict[str, Any] | None = None,
|
||||||
|
clothing: str = "random",
|
||||||
|
poses: str = "random",
|
||||||
|
minimal_clothing_ratio: float = 0.5,
|
||||||
|
standard_pose_ratio: float = 0.5,
|
||||||
character_cast: str = "",
|
character_cast: str = "",
|
||||||
women_count: int = 1,
|
women_count: int = 1,
|
||||||
men_count: int = 1,
|
men_count: int = 1,
|
||||||
@@ -494,15 +498,15 @@ def _prompt_row(
|
|||||||
row_number=1,
|
row_number=1,
|
||||||
start_index=1,
|
start_index=1,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
clothing="random",
|
clothing=clothing,
|
||||||
ethnicity="any",
|
ethnicity="any",
|
||||||
poses="random",
|
poses=poses,
|
||||||
backside_bias=0.35,
|
backside_bias=0.35,
|
||||||
figure="random",
|
figure="random",
|
||||||
no_plus_women=False,
|
no_plus_women=False,
|
||||||
no_black=False,
|
no_black=False,
|
||||||
minimal_clothing_ratio=0.5,
|
minimal_clothing_ratio=minimal_clothing_ratio,
|
||||||
standard_pose_ratio=0.5,
|
standard_pose_ratio=standard_pose_ratio,
|
||||||
trigger=Trigger,
|
trigger=Trigger,
|
||||||
prepend_trigger_to_prompt=True,
|
prepend_trigger_to_prompt=True,
|
||||||
extra_positive="",
|
extra_positive="",
|
||||||
@@ -14394,6 +14398,58 @@ def smoke_seed_config_policy() -> None:
|
|||||||
|
|
||||||
clothing_axis_seed = 42001
|
clothing_axis_seed = 42001
|
||||||
|
|
||||||
|
auto_weighted_seeded = _prompt_row(
|
||||||
|
name="seed_config_policy_auto_weighted_seed_config",
|
||||||
|
category="auto_weighted",
|
||||||
|
subcategory="random",
|
||||||
|
seed=clothing_axis_seed,
|
||||||
|
seed_config=pb.build_seed_lock_config_json(base_seed=clothing_axis_seed),
|
||||||
|
women_count=1,
|
||||||
|
men_count=0,
|
||||||
|
)
|
||||||
|
_expect(auto_weighted_seeded.get("source") == "built_in_generator", "auto_weighted prompt with seed_config should build")
|
||||||
|
|
||||||
|
def direct_builtin_woman(seed_config_value: str | dict[str, Any], *, name: str) -> dict[str, Any]:
|
||||||
|
return _prompt_row(
|
||||||
|
name=name,
|
||||||
|
category="woman",
|
||||||
|
subcategory="random",
|
||||||
|
seed=43001,
|
||||||
|
seed_config=seed_config_value,
|
||||||
|
clothing="full",
|
||||||
|
poses="standard",
|
||||||
|
minimal_clothing_ratio=-1,
|
||||||
|
standard_pose_ratio=-1,
|
||||||
|
women_count=1,
|
||||||
|
men_count=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_builtin_locked = direct_builtin_woman(
|
||||||
|
pb.build_seed_lock_config_json(base_seed=43001),
|
||||||
|
name="seed_config_policy_direct_builtin_locked",
|
||||||
|
)
|
||||||
|
direct_builtin_clothing_reroll = direct_builtin_woman(
|
||||||
|
pb.build_seed_lock_config_json(base_seed=43001, reroll_axis="clothing", reroll_seed=43002),
|
||||||
|
name="seed_config_policy_direct_builtin_clothing_reroll",
|
||||||
|
)
|
||||||
|
direct_builtin_stable_fields = ("primary_subject", "age_band", "body_type", "scene", "composition", "pose_mode")
|
||||||
|
_expect(
|
||||||
|
tuple(direct_builtin_locked.get(key) for key in direct_builtin_stable_fields)
|
||||||
|
== tuple(direct_builtin_clothing_reroll.get(key) for key in direct_builtin_stable_fields),
|
||||||
|
"Prompt-level direct built-in clothing reroll should keep non-clothing fields stable",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
(
|
||||||
|
direct_builtin_locked.get("prompt"),
|
||||||
|
direct_builtin_locked.get("caption"),
|
||||||
|
)
|
||||||
|
!= (
|
||||||
|
direct_builtin_clothing_reroll.get("prompt"),
|
||||||
|
direct_builtin_clothing_reroll.get("caption"),
|
||||||
|
),
|
||||||
|
"Prompt-level direct built-in clothing reroll should change clothing text",
|
||||||
|
)
|
||||||
|
|
||||||
def clothing_category_row(seed_config_value: str | dict[str, Any], *, name: str = "seed_config_policy_clothing_axis") -> dict[str, Any]:
|
def clothing_category_row(seed_config_value: str | dict[str, Any], *, name: str = "seed_config_policy_clothing_axis") -> dict[str, Any]:
|
||||||
return pb.build_prompt(
|
return pb.build_prompt(
|
||||||
category="Casual clothes",
|
category="Casual clothes",
|
||||||
|
|||||||
Reference in New Issue
Block a user