Add Krea2 POV routing and eval tooling
This commit is contained in:
@@ -8,9 +8,12 @@ try:
|
||||
from .hardcore_position_config import (
|
||||
build_hardcore_action_filter_json,
|
||||
build_hardcore_position_pool_json,
|
||||
empty_hardcore_position_config,
|
||||
hardcore_position_family_choices,
|
||||
hardcore_position_focus_choices,
|
||||
hardcore_position_key_choices,
|
||||
hardcore_position_summary,
|
||||
parse_hardcore_position_config,
|
||||
)
|
||||
except ImportError: # Allows local smoke tests from the repository root.
|
||||
import krea2_eval_log
|
||||
@@ -18,9 +21,12 @@ except ImportError: # Allows local smoke tests from the repository root.
|
||||
from hardcore_position_config import (
|
||||
build_hardcore_action_filter_json,
|
||||
build_hardcore_position_pool_json,
|
||||
empty_hardcore_position_config,
|
||||
hardcore_position_family_choices,
|
||||
hardcore_position_focus_choices,
|
||||
hardcore_position_key_choices,
|
||||
hardcore_position_summary,
|
||||
parse_hardcore_position_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,6 +40,19 @@ def _choice_input_key(prefix, choice):
|
||||
return f"{prefix}_{key}"
|
||||
|
||||
|
||||
def _variant_input_key(variant_key):
|
||||
return _choice_input_key("include", str(variant_key or "").removeprefix("pov_"))
|
||||
|
||||
|
||||
def _unique_extend(values):
|
||||
selected = []
|
||||
for value in values:
|
||||
text = str(value or "").strip()
|
||||
if text and text not in selected:
|
||||
selected.append(text)
|
||||
return selected
|
||||
|
||||
|
||||
def _variant_family(value):
|
||||
family = str(value or "any")
|
||||
if family == "penetration":
|
||||
@@ -46,6 +65,90 @@ def _variant_positions(variant):
|
||||
return [str(key) for key in variant.get("position_keys", []) if str(key) in valid]
|
||||
|
||||
|
||||
def _variants_for_action_family(action_family):
|
||||
return krea2_pose_variant_catalog.variants(action_family=action_family)
|
||||
|
||||
|
||||
def _selected_variant_rows(action_family, kwargs):
|
||||
return [
|
||||
variant
|
||||
for variant in _variants_for_action_family(action_family)
|
||||
if bool(kwargs.get(_variant_input_key(variant.get("key")), False))
|
||||
]
|
||||
|
||||
|
||||
def _join_variant_cues(variants, cue_key):
|
||||
cues = []
|
||||
for variant in variants:
|
||||
cues.extend(str(cue) for cue in variant.get(cue_key, []) if str(cue).strip())
|
||||
return "; ".join(_unique_extend(cues))
|
||||
|
||||
|
||||
def _selected_variant_positions(variants):
|
||||
positions = []
|
||||
for variant in variants:
|
||||
positions.extend(_variant_positions(variant))
|
||||
return _unique_extend(positions)
|
||||
|
||||
|
||||
def _selected_variant_keys(variants):
|
||||
return [str(variant.get("key")) for variant in variants if variant.get("key")]
|
||||
|
||||
|
||||
def _merged_family_for_variant_filter(incoming_config, combine_mode, family):
|
||||
family = _variant_family(family)
|
||||
if combine_mode != "add":
|
||||
return family
|
||||
incoming = parse_hardcore_position_config(incoming_config)
|
||||
incoming_family = _variant_family(incoming.get("family"))
|
||||
incoming_positions = incoming.get("positions") or []
|
||||
if not incoming.get("enabled") or (not incoming_positions and incoming_family == "any"):
|
||||
return family
|
||||
if incoming_family == family:
|
||||
return family
|
||||
return "any"
|
||||
|
||||
|
||||
def _empty_or_incoming_config(incoming_config, combine_mode):
|
||||
if combine_mode == "add" and incoming_config:
|
||||
config = parse_hardcore_position_config(incoming_config)
|
||||
else:
|
||||
config = empty_hardcore_position_config()
|
||||
config["summary"] = hardcore_position_summary(config)
|
||||
return json.dumps(config, ensure_ascii=True, sort_keys=True)
|
||||
|
||||
|
||||
def _merge_variant_metadata(config_json, variants):
|
||||
config = json.loads(config_json)
|
||||
selected_keys = _selected_variant_keys(variants)
|
||||
existing_keys = config.get("krea2_variant_keys") or []
|
||||
if not isinstance(existing_keys, list):
|
||||
existing_keys = [existing_keys]
|
||||
variant_keys = _unique_extend([*existing_keys, *selected_keys])
|
||||
config["krea2_variant_keys"] = variant_keys
|
||||
|
||||
selected_statuses = {str(variant.get("key")): str(variant.get("status") or "") for variant in variants if variant.get("key")}
|
||||
existing_statuses = config.get("krea2_variant_statuses") if isinstance(config.get("krea2_variant_statuses"), dict) else {}
|
||||
config["krea2_variant_statuses"] = {**existing_statuses, **selected_statuses}
|
||||
|
||||
prompt_cues = _unique_extend(
|
||||
[*(config.get("krea2_prompt_cues") or []), *(_join_variant_cues(variants, "prompt_cues").split("; ") if variants else [])]
|
||||
)
|
||||
avoid_cues = _unique_extend(
|
||||
[*(config.get("krea2_avoid_cues") or []), *(_join_variant_cues(variants, "avoid_cues").split("; ") if variants else [])]
|
||||
)
|
||||
if prompt_cues:
|
||||
config["krea2_prompt_cues"] = prompt_cues
|
||||
if avoid_cues:
|
||||
config["krea2_avoid_cues"] = avoid_cues
|
||||
|
||||
base_summary = str(config.get("summary") or hardcore_position_summary(config))
|
||||
if variant_keys and "variants=" not in base_summary:
|
||||
base_summary = f"{base_summary}; variants={','.join(variant_keys)}"
|
||||
config["summary"] = base_summary
|
||||
return json.dumps(config, ensure_ascii=True, sort_keys=True)
|
||||
|
||||
|
||||
class SxCPHardcorePositionPool:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -142,6 +245,104 @@ class SxCPKrea2PoseVariant:
|
||||
)
|
||||
|
||||
|
||||
class _SxCPKrea2POVVariantFilter:
|
||||
ACTION_FAMILY = ""
|
||||
POSITION_FAMILY = "any"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
required = {
|
||||
"combine_mode": (["replace", "add"], {"default": "replace"}),
|
||||
}
|
||||
for variant in _variants_for_action_family(cls.ACTION_FAMILY):
|
||||
required[_variant_input_key(variant.get("key"))] = ("BOOLEAN", {"default": False})
|
||||
return {
|
||||
"required": required,
|
||||
"optional": {
|
||||
"hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING", "STRING", "STRING", "STRING", "STRING")
|
||||
RETURN_NAMES = (
|
||||
"hardcore_position_config",
|
||||
"selected_variant_keys",
|
||||
"selected_positions",
|
||||
"prompt_cues",
|
||||
"summary",
|
||||
"variants_json",
|
||||
)
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "prompt_builder"
|
||||
|
||||
def build(self, combine_mode="replace", hardcore_position_config="", **kwargs):
|
||||
variants = _selected_variant_rows(self.ACTION_FAMILY, kwargs)
|
||||
if not variants:
|
||||
config = _empty_or_incoming_config(hardcore_position_config or "", combine_mode)
|
||||
return config, "", "", "", json.loads(config).get("summary", ""), "[]"
|
||||
|
||||
positions = _selected_variant_positions(variants)
|
||||
family = _merged_family_for_variant_filter(
|
||||
hardcore_position_config or "",
|
||||
combine_mode,
|
||||
self.POSITION_FAMILY or self.ACTION_FAMILY,
|
||||
)
|
||||
config = build_hardcore_position_pool_json(
|
||||
hardcore_position_config=hardcore_position_config or "",
|
||||
combine_mode=combine_mode,
|
||||
family=family,
|
||||
selected_positions=positions,
|
||||
)
|
||||
config = _merge_variant_metadata(config, variants)
|
||||
parsed = json.loads(config)
|
||||
selected_keys = parsed.get("krea2_variant_keys") or []
|
||||
selected_positions = parsed.get("positions") or []
|
||||
prompt_cues = _join_variant_cues(variants, "prompt_cues")
|
||||
return (
|
||||
config,
|
||||
",".join(str(key) for key in selected_keys),
|
||||
",".join(str(position) for position in selected_positions),
|
||||
prompt_cues,
|
||||
str(parsed.get("summary") or ""),
|
||||
json.dumps(variants, ensure_ascii=True, sort_keys=True),
|
||||
)
|
||||
|
||||
|
||||
class SxCPKrea2POVPenetrationFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "penetration"
|
||||
POSITION_FAMILY = "penetration"
|
||||
|
||||
|
||||
class SxCPKrea2POVOralFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "oral"
|
||||
POSITION_FAMILY = "oral"
|
||||
|
||||
|
||||
class SxCPKrea2POVOutercourseFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "outercourse"
|
||||
POSITION_FAMILY = "outercourse"
|
||||
|
||||
|
||||
class SxCPKrea2POVManualFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "manual"
|
||||
POSITION_FAMILY = "manual"
|
||||
|
||||
|
||||
class SxCPKrea2POVToyFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "toy"
|
||||
POSITION_FAMILY = "any"
|
||||
|
||||
|
||||
class SxCPKrea2POVClimaxFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "climax"
|
||||
POSITION_FAMILY = "climax"
|
||||
|
||||
|
||||
class SxCPKrea2POVInteractionFilter(_SxCPKrea2POVVariantFilter):
|
||||
ACTION_FAMILY = "interaction"
|
||||
POSITION_FAMILY = "interaction"
|
||||
|
||||
|
||||
class SxCPKrea2VariantEvidence:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -258,6 +459,13 @@ NODE_CLASS_MAPPINGS = {
|
||||
"SxCPHardcorePositionPool": SxCPHardcorePositionPool,
|
||||
"SxCPHardcoreActionFilter": SxCPHardcoreActionFilter,
|
||||
"SxCPKrea2PoseVariant": SxCPKrea2PoseVariant,
|
||||
"SxCPKrea2POVPenetrationFilter": SxCPKrea2POVPenetrationFilter,
|
||||
"SxCPKrea2POVOralFilter": SxCPKrea2POVOralFilter,
|
||||
"SxCPKrea2POVOutercourseFilter": SxCPKrea2POVOutercourseFilter,
|
||||
"SxCPKrea2POVManualFilter": SxCPKrea2POVManualFilter,
|
||||
"SxCPKrea2POVToyFilter": SxCPKrea2POVToyFilter,
|
||||
"SxCPKrea2POVClimaxFilter": SxCPKrea2POVClimaxFilter,
|
||||
"SxCPKrea2POVInteractionFilter": SxCPKrea2POVInteractionFilter,
|
||||
"SxCPKrea2VariantEvidence": SxCPKrea2VariantEvidence,
|
||||
}
|
||||
|
||||
@@ -265,5 +473,12 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SxCPHardcorePositionPool": "SxCP Hardcore Position Pool",
|
||||
"SxCPHardcoreActionFilter": "SxCP Hardcore Action Filter",
|
||||
"SxCPKrea2PoseVariant": "SxCP Krea2 Pose Variant",
|
||||
"SxCPKrea2POVPenetrationFilter": "SxCP Krea2 POV Penetration Filter",
|
||||
"SxCPKrea2POVOralFilter": "SxCP Krea2 POV Oral Filter",
|
||||
"SxCPKrea2POVOutercourseFilter": "SxCP Krea2 POV Outercourse Filter",
|
||||
"SxCPKrea2POVManualFilter": "SxCP Krea2 POV Manual Filter",
|
||||
"SxCPKrea2POVToyFilter": "SxCP Krea2 POV Toy Filter",
|
||||
"SxCPKrea2POVClimaxFilter": "SxCP Krea2 POV Climax Filter",
|
||||
"SxCPKrea2POVInteractionFilter": "SxCP Krea2 POV Interaction Filter",
|
||||
"SxCPKrea2VariantEvidence": "SxCP Krea2 Variant Evidence",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user