Harden seed control normalization
This commit is contained in:
+8
-16
@@ -8,12 +8,16 @@ try:
|
||||
from .seed_config import (
|
||||
build_seed_config_json,
|
||||
build_seed_lock_config_json,
|
||||
normalize_reroll_axis,
|
||||
seed_reroll_axis_choices,
|
||||
seed_mode_choices,
|
||||
)
|
||||
except ImportError: # Allows local smoke tests from the repository root.
|
||||
from seed_config import (
|
||||
build_seed_config_json,
|
||||
build_seed_lock_config_json,
|
||||
normalize_reroll_axis,
|
||||
seed_reroll_axis_choices,
|
||||
seed_mode_choices,
|
||||
)
|
||||
|
||||
@@ -180,20 +184,7 @@ class SxCPSeedLocker:
|
||||
"required": {
|
||||
"base_seed": ("INT", seed_spec),
|
||||
"reroll_axis": (
|
||||
[
|
||||
"none",
|
||||
"category",
|
||||
"subcategory",
|
||||
"content",
|
||||
"person",
|
||||
"scene",
|
||||
"pose",
|
||||
"role",
|
||||
"expression",
|
||||
"composition",
|
||||
"content_pose",
|
||||
"scene_pose",
|
||||
],
|
||||
seed_reroll_axis_choices(),
|
||||
{"default": "none"},
|
||||
),
|
||||
"reroll_seed": ("INT", reroll_seed_spec),
|
||||
@@ -206,8 +197,9 @@ class SxCPSeedLocker:
|
||||
CATEGORY = "prompt_builder"
|
||||
|
||||
def build(self, base_seed, reroll_axis, reroll_seed):
|
||||
config = build_seed_lock_config_json(base_seed=base_seed, reroll_axis=reroll_axis, reroll_seed=reroll_seed)
|
||||
summary = f"base {base_seed}; reroll {reroll_axis} with {'main seed' if int(reroll_seed) < 0 else reroll_seed}"
|
||||
normalized_axis = normalize_reroll_axis(reroll_axis)
|
||||
config = build_seed_lock_config_json(base_seed=base_seed, reroll_axis=normalized_axis, reroll_seed=reroll_seed)
|
||||
summary = f"base {base_seed}; reroll {normalized_axis} with {'main seed' if int(reroll_seed) < 0 else reroll_seed}"
|
||||
return config, summary
|
||||
|
||||
|
||||
|
||||
@@ -356,6 +356,18 @@ def seed_mode_choices() -> list[str]:
|
||||
return seed_policy.seed_mode_choices()
|
||||
|
||||
|
||||
def seed_reroll_axis_choices() -> list[str]:
|
||||
return seed_policy.seed_reroll_axis_choices()
|
||||
|
||||
|
||||
def normalize_seed_mode(value: Any) -> str:
|
||||
return seed_policy.normalize_seed_mode(value)
|
||||
|
||||
|
||||
def normalize_reroll_axis(value: Any) -> str:
|
||||
return seed_policy.normalize_reroll_axis(value)
|
||||
|
||||
|
||||
CATEGORY_PRESETS = category_cast_policy.CATEGORY_PRESETS
|
||||
CAST_PRESETS = category_cast_policy.CAST_PRESETS
|
||||
|
||||
|
||||
+48
-16
@@ -41,12 +41,58 @@ SEED_LOCK_AXES = (
|
||||
"composition",
|
||||
)
|
||||
SEED_MODE_CHOICES = ["auto", "follow_main", "fixed", "random"]
|
||||
SEED_REROLL_GROUPS = {
|
||||
"none": (),
|
||||
"category": ("category",),
|
||||
"subcategory": ("subcategory",),
|
||||
"content": ("content",),
|
||||
"person": ("person",),
|
||||
"scene": ("scene",),
|
||||
"pose": ("pose", "role"),
|
||||
"role": ("role",),
|
||||
"expression": ("expression",),
|
||||
"composition": ("composition",),
|
||||
"content_pose": ("content", "pose", "role"),
|
||||
"scene_pose": ("scene", "pose", "role"),
|
||||
}
|
||||
SEED_REROLL_AXIS_CHOICES = list(SEED_REROLL_GROUPS.keys())
|
||||
|
||||
|
||||
def _normal_key(value: Any) -> str:
|
||||
return str(value or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
|
||||
|
||||
def seed_mode_choices() -> list[str]:
|
||||
return list(SEED_MODE_CHOICES)
|
||||
|
||||
|
||||
def normalize_seed_mode(value: Any) -> str:
|
||||
normalized = _normal_key(value)
|
||||
aliases = {
|
||||
"follow": "follow_main",
|
||||
"followmain": "follow_main",
|
||||
"follow_main_seed": "follow_main",
|
||||
"main": "follow_main",
|
||||
"main_seed": "follow_main",
|
||||
}
|
||||
normalized = aliases.get(normalized, normalized)
|
||||
return normalized if normalized in SEED_MODE_CHOICES else "auto"
|
||||
|
||||
|
||||
def seed_reroll_axis_choices() -> list[str]:
|
||||
return list(SEED_REROLL_AXIS_CHOICES)
|
||||
|
||||
|
||||
def normalize_reroll_axis(value: Any) -> str:
|
||||
normalized = _normal_key(value)
|
||||
aliases = {
|
||||
"contentpose": "content_pose",
|
||||
"scenepose": "scene_pose",
|
||||
}
|
||||
normalized = aliases.get(normalized, normalized)
|
||||
return normalized if normalized in SEED_REROLL_GROUPS else "none"
|
||||
|
||||
|
||||
def row_seed(seed: int, row_number: int, salt: int = 0) -> int:
|
||||
return int(seed) + int(row_number) * 1009 + salt * 9176
|
||||
|
||||
@@ -74,7 +120,7 @@ def build_seed_config_json(
|
||||
rng = random.SystemRandom()
|
||||
|
||||
def axis_seed(value: int, mode: str) -> int:
|
||||
mode = mode if mode in SEED_MODE_CHOICES else "auto"
|
||||
mode = normalize_seed_mode(mode)
|
||||
if mode == "auto":
|
||||
return int(value)
|
||||
if mode == "random":
|
||||
@@ -107,21 +153,7 @@ def build_seed_lock_config_json(
|
||||
) -> str:
|
||||
base_seed = int(base_seed)
|
||||
reroll_seed = int(reroll_seed)
|
||||
reroll_groups = {
|
||||
"none": (),
|
||||
"category": ("category",),
|
||||
"subcategory": ("subcategory",),
|
||||
"content": ("content",),
|
||||
"person": ("person",),
|
||||
"scene": ("scene",),
|
||||
"pose": ("pose", "role"),
|
||||
"role": ("role",),
|
||||
"expression": ("expression",),
|
||||
"composition": ("composition",),
|
||||
"content_pose": ("content", "pose", "role"),
|
||||
"scene_pose": ("scene", "pose", "role"),
|
||||
}
|
||||
reroll = set(reroll_groups.get(str(reroll_axis or "none"), ()))
|
||||
reroll = set(SEED_REROLL_GROUPS[normalize_reroll_axis(reroll_axis)])
|
||||
config: dict[str, int] = {}
|
||||
for axis in SEED_LOCK_AXES:
|
||||
config[f"{axis}_seed"] = reroll_seed if axis in reroll else base_seed
|
||||
|
||||
+19
-5
@@ -5738,14 +5738,21 @@ def smoke_node_utility_registration() -> None:
|
||||
"Seed Control summary lost random resolved seed value",
|
||||
)
|
||||
|
||||
seed, seed_config, summary = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPGlobalSeed"]().build(12345)
|
||||
parsed_seed = json.loads(seed_config)
|
||||
seed_locker = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPSeedLocker"]
|
||||
locker_inputs = seed_locker.INPUT_TYPES().get("required") or {}
|
||||
_expect(
|
||||
list(locker_inputs["reroll_axis"][0]) == seed_config.seed_reroll_axis_choices(),
|
||||
"Seed Locker reroll choices drifted from seed_config",
|
||||
)
|
||||
|
||||
seed, global_seed_config, summary = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPGlobalSeed"]().build(12345)
|
||||
parsed_seed = json.loads(global_seed_config)
|
||||
_expect(seed == 12345, "Global Seed did not return the clamped seed")
|
||||
_expect(parsed_seed, "Global Seed config should not be empty")
|
||||
_expect(all(int(value) == 12345 for value in parsed_seed.values()), "Global Seed config did not lock every axis")
|
||||
_expect("all axes locked" in summary, "Global Seed summary changed unexpectedly")
|
||||
|
||||
locker_config, locker_summary = sxcp_nodes.NODE_CLASS_MAPPINGS["SxCPSeedLocker"]().build(12345, "pose", 999)
|
||||
locker_config, locker_summary = seed_locker().build(12345, "pose", 999)
|
||||
parsed_locker = json.loads(locker_config)
|
||||
_expect(parsed_locker.get("pose_seed") == 999, "Seed Locker did not apply pose reroll seed")
|
||||
_expect("reroll pose" in locker_summary, "Seed Locker summary lost reroll axis")
|
||||
@@ -5852,6 +5859,13 @@ def smoke_server_route_payload_policy() -> None:
|
||||
def smoke_seed_config_policy() -> None:
|
||||
_expect(pb.SEED_AXIS_SALTS is seed_config.SEED_AXIS_SALTS, "prompt_builder seed salts should delegate to seed_config")
|
||||
_expect(pb.seed_mode_choices() == seed_config.seed_mode_choices(), "seed mode choices drifted from seed_config")
|
||||
_expect(
|
||||
pb.seed_reroll_axis_choices() == seed_config.seed_reroll_axis_choices(),
|
||||
"seed reroll axis choices drifted from seed_config",
|
||||
)
|
||||
_expect(pb.normalize_seed_mode("follow main") == "follow_main", "seed mode normalizer should accept spaced labels")
|
||||
_expect(pb.normalize_seed_mode("FOLLOW-MAIN") == "follow_main", "seed mode normalizer should accept hyphenated labels")
|
||||
_expect(pb.normalize_reroll_axis("content pose") == "content_pose", "reroll axis normalizer should accept spaced labels")
|
||||
|
||||
fixed_config = json.loads(
|
||||
pb.build_seed_config_json(
|
||||
@@ -5861,7 +5875,7 @@ def smoke_seed_config_policy() -> None:
|
||||
role_seed=789,
|
||||
category_seed_mode="fixed",
|
||||
content_seed_mode="fixed",
|
||||
pose_seed_mode="follow_main",
|
||||
pose_seed_mode="follow-main",
|
||||
role_seed_mode="auto",
|
||||
)
|
||||
)
|
||||
@@ -5875,7 +5889,7 @@ def smoke_seed_config_policy() -> None:
|
||||
_expect(pb._configured_axis_seed(parsed, "content") == 44, "content axis should honor item_seed alias")
|
||||
_expect(pb._configured_axis_seed(parsed, "role") == 55, "role axis should honor pose seed alias")
|
||||
|
||||
locked = json.loads(pb.build_seed_lock_config_json(base_seed=100, reroll_axis="content_pose", reroll_seed=999))
|
||||
locked = json.loads(pb.build_seed_lock_config_json(base_seed=100, reroll_axis="content pose", reroll_seed=999))
|
||||
_expect(locked["content_seed"] == 999, "content_pose reroll should alter content seed")
|
||||
_expect(locked["pose_seed"] == 999 and locked["role_seed"] == 999, "content_pose reroll should alter pose and role seeds")
|
||||
_expect(locked["scene_seed"] == 100, "content_pose reroll should leave scene locked")
|
||||
|
||||
Reference in New Issue
Block a user