Harden seed control normalization
This commit is contained in:
+48
-16
@@ -41,12 +41,58 @@ SEED_LOCK_AXES = (
|
||||
"composition",
|
||||
)
|
||||
SEED_MODE_CHOICES = ["auto", "follow_main", "fixed", "random"]
|
||||
SEED_REROLL_GROUPS = {
|
||||
"none": (),
|
||||
"category": ("category",),
|
||||
"subcategory": ("subcategory",),
|
||||
"content": ("content",),
|
||||
"person": ("person",),
|
||||
"scene": ("scene",),
|
||||
"pose": ("pose", "role"),
|
||||
"role": ("role",),
|
||||
"expression": ("expression",),
|
||||
"composition": ("composition",),
|
||||
"content_pose": ("content", "pose", "role"),
|
||||
"scene_pose": ("scene", "pose", "role"),
|
||||
}
|
||||
SEED_REROLL_AXIS_CHOICES = list(SEED_REROLL_GROUPS.keys())
|
||||
|
||||
|
||||
def _normal_key(value: Any) -> str:
|
||||
return str(value or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
|
||||
|
||||
def seed_mode_choices() -> list[str]:
|
||||
return list(SEED_MODE_CHOICES)
|
||||
|
||||
|
||||
def normalize_seed_mode(value: Any) -> str:
|
||||
normalized = _normal_key(value)
|
||||
aliases = {
|
||||
"follow": "follow_main",
|
||||
"followmain": "follow_main",
|
||||
"follow_main_seed": "follow_main",
|
||||
"main": "follow_main",
|
||||
"main_seed": "follow_main",
|
||||
}
|
||||
normalized = aliases.get(normalized, normalized)
|
||||
return normalized if normalized in SEED_MODE_CHOICES else "auto"
|
||||
|
||||
|
||||
def seed_reroll_axis_choices() -> list[str]:
|
||||
return list(SEED_REROLL_AXIS_CHOICES)
|
||||
|
||||
|
||||
def normalize_reroll_axis(value: Any) -> str:
|
||||
normalized = _normal_key(value)
|
||||
aliases = {
|
||||
"contentpose": "content_pose",
|
||||
"scenepose": "scene_pose",
|
||||
}
|
||||
normalized = aliases.get(normalized, normalized)
|
||||
return normalized if normalized in SEED_REROLL_GROUPS else "none"
|
||||
|
||||
|
||||
def row_seed(seed: int, row_number: int, salt: int = 0) -> int:
|
||||
return int(seed) + int(row_number) * 1009 + salt * 9176
|
||||
|
||||
@@ -74,7 +120,7 @@ def build_seed_config_json(
|
||||
rng = random.SystemRandom()
|
||||
|
||||
def axis_seed(value: int, mode: str) -> int:
|
||||
mode = mode if mode in SEED_MODE_CHOICES else "auto"
|
||||
mode = normalize_seed_mode(mode)
|
||||
if mode == "auto":
|
||||
return int(value)
|
||||
if mode == "random":
|
||||
@@ -107,21 +153,7 @@ def build_seed_lock_config_json(
|
||||
) -> str:
|
||||
base_seed = int(base_seed)
|
||||
reroll_seed = int(reroll_seed)
|
||||
reroll_groups = {
|
||||
"none": (),
|
||||
"category": ("category",),
|
||||
"subcategory": ("subcategory",),
|
||||
"content": ("content",),
|
||||
"person": ("person",),
|
||||
"scene": ("scene",),
|
||||
"pose": ("pose", "role"),
|
||||
"role": ("role",),
|
||||
"expression": ("expression",),
|
||||
"composition": ("composition",),
|
||||
"content_pose": ("content", "pose", "role"),
|
||||
"scene_pose": ("scene", "pose", "role"),
|
||||
}
|
||||
reroll = set(reroll_groups.get(str(reroll_axis or "none"), ()))
|
||||
reroll = set(SEED_REROLL_GROUPS[normalize_reroll_axis(reroll_axis)])
|
||||
config: dict[str, int] = {}
|
||||
for axis in SEED_LOCK_AXES:
|
||||
config[f"{axis}_seed"] = reroll_seed if axis in reroll else base_seed
|
||||
|
||||
Reference in New Issue
Block a user