Add atlas refine cue seed workflow
This commit is contained in:
+99
-12
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
|
||||
try:
|
||||
from . import krea2_eval_log
|
||||
@@ -97,6 +98,57 @@ def _selected_variant_keys(variants):
|
||||
return [str(variant.get("key")) for variant in variants if variant.get("key")]
|
||||
|
||||
|
||||
def _int_seed(value, default=-1):
|
||||
try:
|
||||
seed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return seed if seed >= 0 else default
|
||||
|
||||
|
||||
def _seeded_prompt_variant_indices(variants, atlas_cue_seed=-1):
|
||||
seed = _int_seed(atlas_cue_seed)
|
||||
if seed < 0:
|
||||
return {}, seed
|
||||
indices = {}
|
||||
for variant in variants:
|
||||
key = str(variant.get("key") or "").strip()
|
||||
if not key:
|
||||
continue
|
||||
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||
if len(cue_sets) <= 1:
|
||||
continue
|
||||
rng = random.Random(f"sxcp_krea2_atlas_cue:{seed}:{key}")
|
||||
indices[key] = rng.randrange(len(cue_sets))
|
||||
return indices, seed
|
||||
|
||||
|
||||
def _normalized_prompt_variant_indices(value):
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
indices = {}
|
||||
for key, index in value.items():
|
||||
key_text = str(key or "").strip()
|
||||
if not key_text:
|
||||
continue
|
||||
try:
|
||||
indices[key_text] = int(index)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return indices
|
||||
|
||||
|
||||
def _summary_without_variant_metadata(summary):
|
||||
return "; ".join(
|
||||
part
|
||||
for part in (str(summary or "").split(";"))
|
||||
if part.strip()
|
||||
and not part.strip().startswith("variants=")
|
||||
and not part.strip().startswith("cue_seed=")
|
||||
and not part.strip().startswith("cue_indices=")
|
||||
).strip()
|
||||
|
||||
|
||||
def _merged_family_for_variant_filter(incoming_config, combine_mode, family):
|
||||
family = _variant_family(family)
|
||||
if combine_mode != "add":
|
||||
@@ -120,7 +172,7 @@ def _empty_or_incoming_config(incoming_config, combine_mode):
|
||||
return json.dumps(config, ensure_ascii=True, sort_keys=True)
|
||||
|
||||
|
||||
def _merge_variant_metadata(config_json, variants):
|
||||
def _merge_variant_metadata(config_json, variants, atlas_cue_seed=-1):
|
||||
config = json.loads(config_json)
|
||||
selected_keys = _selected_variant_keys(variants)
|
||||
existing_keys = config.get("krea2_variant_keys") or []
|
||||
@@ -133,10 +185,32 @@ def _merge_variant_metadata(config_json, variants):
|
||||
existing_statuses = config.get("krea2_variant_statuses") if isinstance(config.get("krea2_variant_statuses"), dict) else {}
|
||||
config["krea2_variant_statuses"] = {**existing_statuses, **selected_statuses}
|
||||
|
||||
base_summary = str(config.get("summary") or hardcore_position_summary(config))
|
||||
if variant_keys and "variants=" not in base_summary:
|
||||
base_summary = f"{base_summary}; variants={','.join(variant_keys)}"
|
||||
config["summary"] = base_summary
|
||||
existing_indices = _normalized_prompt_variant_indices(config.get("krea2_prompt_variant_indices"))
|
||||
seeded_indices, seed = _seeded_prompt_variant_indices(variants, atlas_cue_seed)
|
||||
prompt_variant_indices = {**existing_indices, **seeded_indices}
|
||||
if prompt_variant_indices:
|
||||
config["krea2_prompt_variant_indices"] = prompt_variant_indices
|
||||
if seeded_indices:
|
||||
config["krea2_prompt_variant_seed"] = seed
|
||||
config["krea2_prompt_variant_seed_axis"] = "atlas_cue_seed"
|
||||
|
||||
base_summary = _summary_without_variant_metadata(config.get("summary") or hardcore_position_summary(config))
|
||||
summary_parts = [base_summary] if base_summary else []
|
||||
if variant_keys:
|
||||
summary_parts.append("variants=" + ",".join(variant_keys))
|
||||
if seeded_indices:
|
||||
summary_parts.append(f"cue_seed={seed}")
|
||||
selected_indices = {
|
||||
key: prompt_variant_indices[key]
|
||||
for key in variant_keys
|
||||
if key in prompt_variant_indices
|
||||
}
|
||||
if selected_indices:
|
||||
summary_parts.append(
|
||||
"cue_indices="
|
||||
+ ",".join(f"{key}:{selected_indices[key]}" for key in variant_keys if key in selected_indices)
|
||||
)
|
||||
config["summary"] = "; ".join(part for part in summary_parts if part)
|
||||
return json.dumps(config, ensure_ascii=True, sort_keys=True)
|
||||
|
||||
|
||||
@@ -195,6 +269,7 @@ class SxCPKrea2PoseVariant:
|
||||
},
|
||||
"optional": {
|
||||
"hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
|
||||
"atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -210,7 +285,7 @@ class SxCPKrea2PoseVariant:
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "prompt_builder"
|
||||
|
||||
def build(self, variant_key, combine_mode="replace", hardcore_position_config=""):
|
||||
def build(self, variant_key, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1):
|
||||
variant = krea2_pose_variant_catalog.get_variant(variant_key)
|
||||
if not variant:
|
||||
empty = {
|
||||
@@ -228,12 +303,23 @@ class SxCPKrea2PoseVariant:
|
||||
family=family,
|
||||
selected_positions=positions,
|
||||
)
|
||||
config = _merge_variant_metadata(config, [variant], atlas_cue_seed=atlas_cue_seed)
|
||||
parsed_config = json.loads(config)
|
||||
prompt_cues = "; ".join(str(cue) for cue in variant.get("prompt_cues", []) if str(cue).strip())
|
||||
avoid_cues = "; ".join(str(cue) for cue in variant.get("avoid_cues", []) if str(cue).strip())
|
||||
summary = (
|
||||
f"variant={variant.get('key')}; status={variant.get('status')}; "
|
||||
f"family={family}; positions={','.join(positions) or 'none'}"
|
||||
)
|
||||
summary_parts = [
|
||||
f"variant={variant.get('key')}",
|
||||
f"status={variant.get('status')}",
|
||||
f"family={family}",
|
||||
f"positions={','.join(positions) or 'none'}",
|
||||
]
|
||||
if parsed_config.get("krea2_prompt_variant_seed") is not None:
|
||||
summary_parts.append(f"cue_seed={parsed_config.get('krea2_prompt_variant_seed')}")
|
||||
prompt_variant_indices = _normalized_prompt_variant_indices(parsed_config.get("krea2_prompt_variant_indices"))
|
||||
selected_index = prompt_variant_indices.get(str(variant.get("key") or ""))
|
||||
if selected_index is not None:
|
||||
summary_parts.append(f"cue_indices={variant.get('key')}:{selected_index}")
|
||||
summary = "; ".join(summary_parts)
|
||||
return (
|
||||
config,
|
||||
str(variant.get("key") or variant_key),
|
||||
@@ -252,6 +338,7 @@ class _SxCPKrea2POVVariantFilter:
|
||||
def INPUT_TYPES(cls):
|
||||
required = {
|
||||
"combine_mode": (["replace", "add"], {"default": "replace"}),
|
||||
"atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
|
||||
}
|
||||
for variant in _variants_for_action_family(cls.ACTION_FAMILY):
|
||||
required[_variant_input_key(variant.get("key"))] = ("BOOLEAN", {"default": False})
|
||||
@@ -274,7 +361,7 @@ class _SxCPKrea2POVVariantFilter:
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "prompt_builder"
|
||||
|
||||
def build(self, combine_mode="replace", hardcore_position_config="", **kwargs):
|
||||
def build(self, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1, **kwargs):
|
||||
variants = _selected_variant_rows(self.ACTION_FAMILY, kwargs)
|
||||
if not variants:
|
||||
config = _empty_or_incoming_config(hardcore_position_config or "", combine_mode)
|
||||
@@ -292,7 +379,7 @@ class _SxCPKrea2POVVariantFilter:
|
||||
family=family,
|
||||
selected_positions=positions,
|
||||
)
|
||||
config = _merge_variant_metadata(config, variants)
|
||||
config = _merge_variant_metadata(config, variants, atlas_cue_seed=atlas_cue_seed)
|
||||
parsed = json.loads(config)
|
||||
selected_keys = parsed.get("krea2_variant_keys") or []
|
||||
selected_positions = parsed.get("positions") or []
|
||||
|
||||
Reference in New Issue
Block a user