Add atlas refine cue seed workflow

This commit is contained in:
2026-07-01 14:10:23 +02:00
parent 83dfecc55b
commit 5f602db06b
34 changed files with 12162 additions and 18 deletions
+99 -12
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
import random
try:
from . import krea2_eval_log
@@ -97,6 +98,57 @@ def _selected_variant_keys(variants):
return [str(variant.get("key")) for variant in variants if variant.get("key")]
def _int_seed(value, default=-1):
try:
seed = int(value)
except (TypeError, ValueError):
return default
return seed if seed >= 0 else default
def _seeded_prompt_variant_indices(variants, atlas_cue_seed=-1):
seed = _int_seed(atlas_cue_seed)
if seed < 0:
return {}, seed
indices = {}
for variant in variants:
key = str(variant.get("key") or "").strip()
if not key:
continue
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
if len(cue_sets) <= 1:
continue
rng = random.Random(f"sxcp_krea2_atlas_cue:{seed}:{key}")
indices[key] = rng.randrange(len(cue_sets))
return indices, seed
def _normalized_prompt_variant_indices(value):
if not isinstance(value, dict):
return {}
indices = {}
for key, index in value.items():
key_text = str(key or "").strip()
if not key_text:
continue
try:
indices[key_text] = int(index)
except (TypeError, ValueError):
continue
return indices
def _summary_without_variant_metadata(summary):
return "; ".join(
part
for part in (str(summary or "").split(";"))
if part.strip()
and not part.strip().startswith("variants=")
and not part.strip().startswith("cue_seed=")
and not part.strip().startswith("cue_indices=")
).strip()
def _merged_family_for_variant_filter(incoming_config, combine_mode, family):
family = _variant_family(family)
if combine_mode != "add":
@@ -120,7 +172,7 @@ def _empty_or_incoming_config(incoming_config, combine_mode):
return json.dumps(config, ensure_ascii=True, sort_keys=True)
def _merge_variant_metadata(config_json, variants):
def _merge_variant_metadata(config_json, variants, atlas_cue_seed=-1):
config = json.loads(config_json)
selected_keys = _selected_variant_keys(variants)
existing_keys = config.get("krea2_variant_keys") or []
@@ -133,10 +185,32 @@ def _merge_variant_metadata(config_json, variants):
existing_statuses = config.get("krea2_variant_statuses") if isinstance(config.get("krea2_variant_statuses"), dict) else {}
config["krea2_variant_statuses"] = {**existing_statuses, **selected_statuses}
base_summary = str(config.get("summary") or hardcore_position_summary(config))
if variant_keys and "variants=" not in base_summary:
base_summary = f"{base_summary}; variants={','.join(variant_keys)}"
config["summary"] = base_summary
existing_indices = _normalized_prompt_variant_indices(config.get("krea2_prompt_variant_indices"))
seeded_indices, seed = _seeded_prompt_variant_indices(variants, atlas_cue_seed)
prompt_variant_indices = {**existing_indices, **seeded_indices}
if prompt_variant_indices:
config["krea2_prompt_variant_indices"] = prompt_variant_indices
if seeded_indices:
config["krea2_prompt_variant_seed"] = seed
config["krea2_prompt_variant_seed_axis"] = "atlas_cue_seed"
base_summary = _summary_without_variant_metadata(config.get("summary") or hardcore_position_summary(config))
summary_parts = [base_summary] if base_summary else []
if variant_keys:
summary_parts.append("variants=" + ",".join(variant_keys))
if seeded_indices:
summary_parts.append(f"cue_seed={seed}")
selected_indices = {
key: prompt_variant_indices[key]
for key in variant_keys
if key in prompt_variant_indices
}
if selected_indices:
summary_parts.append(
"cue_indices="
+ ",".join(f"{key}:{selected_indices[key]}" for key in variant_keys if key in selected_indices)
)
config["summary"] = "; ".join(part for part in summary_parts if part)
return json.dumps(config, ensure_ascii=True, sort_keys=True)
@@ -195,6 +269,7 @@ class SxCPKrea2PoseVariant:
},
"optional": {
"hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
"atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
},
}
@@ -210,7 +285,7 @@ class SxCPKrea2PoseVariant:
FUNCTION = "build"
CATEGORY = "prompt_builder"
def build(self, variant_key, combine_mode="replace", hardcore_position_config=""):
def build(self, variant_key, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1):
variant = krea2_pose_variant_catalog.get_variant(variant_key)
if not variant:
empty = {
@@ -228,12 +303,23 @@ class SxCPKrea2PoseVariant:
family=family,
selected_positions=positions,
)
config = _merge_variant_metadata(config, [variant], atlas_cue_seed=atlas_cue_seed)
parsed_config = json.loads(config)
prompt_cues = "; ".join(str(cue) for cue in variant.get("prompt_cues", []) if str(cue).strip())
avoid_cues = "; ".join(str(cue) for cue in variant.get("avoid_cues", []) if str(cue).strip())
summary = (
f"variant={variant.get('key')}; status={variant.get('status')}; "
f"family={family}; positions={','.join(positions) or 'none'}"
)
summary_parts = [
f"variant={variant.get('key')}",
f"status={variant.get('status')}",
f"family={family}",
f"positions={','.join(positions) or 'none'}",
]
if parsed_config.get("krea2_prompt_variant_seed") is not None:
summary_parts.append(f"cue_seed={parsed_config.get('krea2_prompt_variant_seed')}")
prompt_variant_indices = _normalized_prompt_variant_indices(parsed_config.get("krea2_prompt_variant_indices"))
selected_index = prompt_variant_indices.get(str(variant.get("key") or ""))
if selected_index is not None:
summary_parts.append(f"cue_indices={variant.get('key')}:{selected_index}")
summary = "; ".join(summary_parts)
return (
config,
str(variant.get("key") or variant_key),
@@ -252,6 +338,7 @@ class _SxCPKrea2POVVariantFilter:
def INPUT_TYPES(cls):
required = {
"combine_mode": (["replace", "add"], {"default": "replace"}),
"atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
}
for variant in _variants_for_action_family(cls.ACTION_FAMILY):
required[_variant_input_key(variant.get("key"))] = ("BOOLEAN", {"default": False})
@@ -274,7 +361,7 @@ class _SxCPKrea2POVVariantFilter:
FUNCTION = "build"
CATEGORY = "prompt_builder"
def build(self, combine_mode="replace", hardcore_position_config="", **kwargs):
def build(self, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1, **kwargs):
variants = _selected_variant_rows(self.ACTION_FAMILY, kwargs)
if not variants:
config = _empty_or_incoming_config(hardcore_position_config or "", combine_mode)
@@ -292,7 +379,7 @@ class _SxCPKrea2POVVariantFilter:
family=family,
selected_positions=positions,
)
config = _merge_variant_metadata(config, variants)
config = _merge_variant_metadata(config, variants, atlas_cue_seed=atlas_cue_seed)
parsed = json.loads(config)
selected_keys = parsed.get("krea2_variant_keys") or []
selected_positions = parsed.get("positions") or []