# Source listing: ComfyUI-Ethanfel-Prompt-Bui…/node_hardcore_position.py
# Reported size: 632 lines, 23 KiB, Python.

from __future__ import annotations
import json
import random
try:
from . import krea2_eval_log
from . import krea2_pose_variant_catalog
from .hardcore_position_config import (
build_hardcore_action_filter_json,
build_hardcore_position_pool_json,
empty_hardcore_position_config,
hardcore_position_family_choices,
hardcore_position_focus_choices,
hardcore_position_key_choices,
hardcore_position_summary,
normalize_restore_prompt_axes,
parse_hardcore_position_config,
)
except ImportError: # Allows local smoke tests from the repository root.
import krea2_eval_log
import krea2_pose_variant_catalog
from hardcore_position_config import (
build_hardcore_action_filter_json,
build_hardcore_position_pool_json,
empty_hardcore_position_config,
hardcore_position_family_choices,
hardcore_position_focus_choices,
hardcore_position_key_choices,
hardcore_position_summary,
normalize_restore_prompt_axes,
parse_hardcore_position_config,
)
# Type tag used in the INPUT_TYPES / RETURN_TYPES declarations of the node
# classes in this module for the hardcore position config value.
SXCP_HARDCORE_POSITION_CONFIG = "SXCP_HARDCORE_POSITION_CONFIG"
def _choice_input_key(prefix, choice):
key = "".join(char if char.isalnum() else "_" for char in str(choice).lower()).strip("_")
while "__" in key:
key = key.replace("__", "_")
return f"{prefix}_{key}"
def _variant_input_key(variant_key):
    """Return the ``include_*`` input name for a catalog variant key.

    A falsy key is treated as an empty string, and a leading ``pov_`` is
    removed before the key is slugged by ``_choice_input_key``.
    """
    bare_key = str(variant_key or "")
    bare_key = bare_key.removeprefix("pov_")
    return _choice_input_key("include", bare_key)
def _unique_extend(values):
selected = []
for value in values:
text = str(value or "").strip()
if text and text not in selected:
selected.append(text)
return selected
def _variant_family(value):
    """Map a family name to one of the position family choices.

    ``"penetration"`` is rewritten to ``"penetrative"``; a falsy value, or a
    name that is not in ``hardcore_position_family_choices()``, yields ``"any"``.
    """
    name = str(value or "any")
    name = {"penetration": "penetrative"}.get(name, name)
    if name in hardcore_position_family_choices():
        return name
    return "any"
def _variant_positions(variant):
    """Return the variant's position keys that are valid position choices.

    Keys are stringified and kept in catalog order; any key that is not in
    ``hardcore_position_key_choices()`` is dropped.
    """
    known_keys = set(hardcore_position_key_choices())
    # `or []` also covers a "position_keys" entry that is present but null;
    # dict.get's default only applies when the key is missing altogether.
    raw_keys = variant.get("position_keys") or []
    return [str(key) for key in raw_keys if str(key) in known_keys]
def _variants_for_action_family(action_family):
    """Return the catalog's variants for the given action family."""
    catalog = krea2_pose_variant_catalog
    return catalog.variants(action_family=action_family)
def _selected_variant_rows(action_family, kwargs):
    """Return the variants of ``action_family`` whose toggle in ``kwargs`` is truthy.

    The toggle name for each variant comes from ``_variant_input_key``; a
    missing toggle counts as off.
    """
    chosen = []
    for row in _variants_for_action_family(action_family):
        toggle_name = _variant_input_key(row.get("key"))
        if kwargs.get(toggle_name, False):
            chosen.append(row)
    return chosen
def _join_variant_cues(variants, cue_key):
    """Join the ``cue_key`` cues of all ``variants`` into one ``"; "`` string.

    Cues are stringified, blank cues are skipped, and duplicates are removed
    (first occurrence wins) via ``_unique_extend``.
    """
    cues = []
    for variant in variants:
        # `or []` tolerates a cue list that is present but null, which would
        # otherwise raise TypeError when iterated.
        raw_cues = variant.get(cue_key) or []
        cues.extend(str(cue) for cue in raw_cues if str(cue).strip())
    return "; ".join(_unique_extend(cues))
def _selected_variant_positions(variants):
    """Return the de-duplicated valid position keys across all ``variants``."""
    gathered = [
        position
        for row in variants
        for position in _variant_positions(row)
    ]
    return _unique_extend(gathered)
def _selected_variant_keys(variants):
return [str(variant.get("key")) for variant in variants if variant.get("key")]
def _int_seed(value, default=-1):
try:
seed = int(value)
except (TypeError, ValueError):
return default
return seed if seed >= 0 else default
def _seeded_prompt_variant_indices(variants, atlas_cue_seed=-1):
    """Pick a prompt cue-set index per variant from ``atlas_cue_seed``.

    Returns ``(indices, seed)``, where ``seed`` is the normalized seed from
    ``_int_seed`` and ``indices`` maps variant key to the chosen cue-set index.
    ``indices`` is empty when the seed is negative. Variants without a key, or
    with at most one cue set, get no entry.
    """
    seed = _int_seed(atlas_cue_seed)
    picks = {}
    if seed < 0:
        return picks, seed
    for row in variants:
        key = str(row.get("key") or "").strip()
        if key:
            options = krea2_pose_variant_catalog.prompt_cue_sets(row)
            if len(options) > 1:
                # A dedicated RNG seeded from (seed, key) is created per
                # variant, so each pick is drawn independently of the others.
                rng = random.Random(f"sxcp_krea2_atlas_cue:{seed}:{key}")
                picks[key] = rng.randrange(len(options))
    return picks, seed
def _normalized_prompt_variant_indices(value):
if not isinstance(value, dict):
return {}
indices = {}
for key, index in value.items():
key_text = str(key or "").strip()
if not key_text:
continue
try:
indices[key_text] = int(index)
except (TypeError, ValueError):
continue
return indices
def _summary_without_variant_metadata(summary):
return "; ".join(
part
for part in (str(summary or "").split(";"))
if part.strip()
and not part.strip().startswith("variants=")
and not part.strip().startswith("cue_seed=")
and not part.strip().startswith("cue_indices=")
).strip()
def _merged_family_for_variant_filter(incoming_config, combine_mode, family):
    """Choose the family to use when a variant filter merges into a config.

    Outside ``"add"`` mode the normalized ``family`` is returned as is. In
    ``"add"`` mode it is also returned when the incoming config is disabled,
    has neither positions nor a specific family, or has the same family;
    otherwise the families disagree and ``"any"`` is returned.
    """
    family = _variant_family(family)
    if combine_mode == "add":
        incoming = parse_hardcore_position_config(incoming_config)
        incoming_family = _variant_family(incoming.get("family"))
        has_positions = bool(incoming.get("positions") or [])
        constrains = incoming.get("enabled") and (has_positions or incoming_family != "any")
        if constrains and incoming_family != family:
            return "any"
    return family
def _empty_or_incoming_config(incoming_config, combine_mode):
    """Return config JSON for the case where nothing was selected.

    In ``"add"`` mode with a non-empty incoming config, that config is parsed
    and passed through; otherwise an empty config is used. The summary is
    recomputed in both cases before serializing.
    """
    keep_incoming = combine_mode == "add" and bool(incoming_config)
    config = (
        parse_hardcore_position_config(incoming_config)
        if keep_incoming
        else empty_hardcore_position_config()
    )
    config["summary"] = hardcore_position_summary(config)
    return json.dumps(config, ensure_ascii=True, sort_keys=True)
def _merge_variant_metadata(config_json, variants, atlas_cue_seed=-1):
    """Record the selected variants in a config and refresh its summary.

    Merges into the parsed ``config_json``: the variant keys (existing first,
    de-duplicated), the variant statuses (current selection overrides), and
    the prompt cue-set indices (freshly seeded ones override existing ones).
    The seed and its axis name are stored only when seeding produced at least
    one index. Returns the updated config as a JSON string.
    """
    config = json.loads(config_json)

    # Variant keys: previously recorded keys first, then the new selection.
    prior_keys = config.get("krea2_variant_keys") or []
    if not isinstance(prior_keys, list):
        prior_keys = [prior_keys]
    variant_keys = _unique_extend(prior_keys + _selected_variant_keys(variants))
    config["krea2_variant_keys"] = variant_keys

    # Statuses: the current selection wins over an earlier entry for a key.
    prior_statuses = config.get("krea2_variant_statuses")
    merged_statuses = dict(prior_statuses) if isinstance(prior_statuses, dict) else {}
    for row in variants:
        if row.get("key"):
            merged_statuses[str(row.get("key"))] = str(row.get("status") or "")
    config["krea2_variant_statuses"] = merged_statuses

    # Cue-set indices: freshly seeded picks override stored ones.
    prompt_variant_indices = _normalized_prompt_variant_indices(
        config.get("krea2_prompt_variant_indices")
    )
    seeded_indices, seed = _seeded_prompt_variant_indices(variants, atlas_cue_seed)
    prompt_variant_indices.update(seeded_indices)
    if prompt_variant_indices:
        config["krea2_prompt_variant_indices"] = prompt_variant_indices
    if seeded_indices:
        config["krea2_prompt_variant_seed"] = seed
        config["krea2_prompt_variant_seed_axis"] = "atlas_cue_seed"

    # Rebuild the summary: base text without old metadata, then new metadata.
    base_summary = _summary_without_variant_metadata(
        config.get("summary") or hardcore_position_summary(config)
    )
    summary_parts = [base_summary] if base_summary else []
    if variant_keys:
        summary_parts.append("variants=" + ",".join(variant_keys))
    if seeded_indices:
        summary_parts.append(f"cue_seed={seed}")
    index_labels = [
        f"{key}:{prompt_variant_indices[key]}"
        for key in variant_keys
        if key in prompt_variant_indices
    ]
    if index_labels:
        summary_parts.append("cue_indices=" + ",".join(index_labels))
    config["summary"] = "; ".join(summary_parts)
    return json.dumps(config, ensure_ascii=True, sort_keys=True)
def _variant_notes(variants):
return "; ".join(
f"{variant.get('key')} ({variant.get('status') or 'unknown'})"
for variant in variants
if variant.get("key")
)
class SxCPHardcorePositionPool:
    """Node that builds a position-pool config from per-position toggles."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: combine mode, family, and one toggle per position key."""
        required = {
            "combine_mode": (["replace", "add"], {"default": "replace"}),
            "family": (hardcore_position_family_choices(), {"default": "any"}),
        }
        required.update(
            (_choice_input_key("include", choice), ("BOOLEAN", {"default": False}))
            for choice in hardcore_position_key_choices()
        )
        optional = {
            "hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING")
    RETURN_NAMES = ("hardcore_position_config", "summary")
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(self, combine_mode="replace", family="any", hardcore_position_config="", **kwargs):
        """Return ``(config_json, summary)`` for the toggled positions.

        ``kwargs`` carries the ``include_*`` toggles; a missing toggle counts
        as off.
        """
        selected = []
        for choice in hardcore_position_key_choices():
            if kwargs.get(_choice_input_key("include", choice), False):
                selected.append(choice)
        config_json = build_hardcore_position_pool_json(
            hardcore_position_config=hardcore_position_config or "",
            combine_mode=combine_mode,
            family=family,
            selected_positions=selected,
        )
        summary = json.loads(config_json).get("summary", "")
        return config_json, summary
class SxCPKrea2PoseVariant:
    """Node that applies one Krea2 pose-variant catalog entry to a position config.

    Outputs, in order: the updated config JSON, the variant key, the variant's
    prompt cues and avoid cues (each joined with ``"; "``), a one-line
    summary, and the catalog entry serialized as JSON.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs; the variant dropdown is filled from the catalog."""
        keys = krea2_pose_variant_catalog.variant_keys()
        # A placeholder keeps the dropdown non-empty when the catalog has no keys.
        choices = keys or ["missing_catalog_variant"]
        return {
            "required": {
                "variant_key": (choices, {"default": choices[0]}),
                "combine_mode": (["replace", "add"], {"default": "replace"}),
            },
            "optional": {
                "hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
                "atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
            },
        }

    RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING", "STRING", "STRING", "STRING", "STRING")
    RETURN_NAMES = (
        "hardcore_position_config",
        "variant_key",
        "prompt_cues",
        "avoid_cues",
        "summary",
        "variant_json",
    )
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(self, variant_key, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1):
        """Build the six outputs for ``variant_key``.

        When the catalog has no such variant, the incoming config is passed
        through unchanged, the cue outputs are empty, and the JSON output
        describes the missing variant.
        """
        variant = krea2_pose_variant_catalog.get_variant(variant_key)
        if not variant:
            empty = {
                "key": str(variant_key or ""),
                "status": "missing",
                "summary": "missing Krea2 pose variant",
            }
            return (
                hardcore_position_config or "",
                str(variant_key or ""),
                "",
                "",
                empty["summary"],
                json.dumps(empty, ensure_ascii=True, sort_keys=True),
            )

        positions = _variant_positions(variant)
        family = _variant_family(variant.get("action_family") or variant.get("family"))
        config = build_hardcore_position_pool_json(
            hardcore_position_config=hardcore_position_config or "",
            combine_mode=combine_mode,
            family=family,
            selected_positions=positions,
        )
        config = _merge_variant_metadata(config, [variant], atlas_cue_seed=atlas_cue_seed)
        parsed_config = json.loads(config)

        # `or []` tolerates a cue list that is present but null, which would
        # otherwise raise TypeError when iterated.
        prompt_cues = "; ".join(
            str(cue) for cue in (variant.get("prompt_cues") or []) if str(cue).strip()
        )
        avoid_cues = "; ".join(
            str(cue) for cue in (variant.get("avoid_cues") or []) if str(cue).strip()
        )

        summary_parts = [
            f"variant={variant.get('key')}",
            f"status={variant.get('status')}",
            f"family={family}",
            f"positions={','.join(positions) or 'none'}",
        ]
        if parsed_config.get("krea2_prompt_variant_seed") is not None:
            summary_parts.append(f"cue_seed={parsed_config.get('krea2_prompt_variant_seed')}")
        prompt_variant_indices = _normalized_prompt_variant_indices(
            parsed_config.get("krea2_prompt_variant_indices")
        )
        selected_index = prompt_variant_indices.get(str(variant.get("key") or ""))
        if selected_index is not None:
            summary_parts.append(f"cue_indices={variant.get('key')}:{selected_index}")
        summary = "; ".join(summary_parts)

        return (
            config,
            str(variant.get("key") or variant_key),
            prompt_cues,
            avoid_cues,
            summary,
            json.dumps(variant, ensure_ascii=True, sort_keys=True),
        )
class _SxCPKrea2POVVariantFilter:
    """Shared implementation for the per-family POV variant filter nodes.

    Subclasses set ``ACTION_FAMILY``, the catalog family whose variants become
    toggles, and ``POSITION_FAMILY``, the family used for the position pool.
    An empty ``POSITION_FAMILY`` falls back to ``ACTION_FAMILY``.
    """

    ACTION_FAMILY = ""
    POSITION_FAMILY = "any"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: combine mode, cue seed, and one toggle per variant."""
        required = {
            "combine_mode": (["replace", "add"], {"default": "replace"}),
            "atlas_cue_seed": ("INT", {"default": -1, "min": -1, "max": 0xFFFFFFFF, "step": 1}),
        }
        required.update(
            (_variant_input_key(row.get("key")), ("BOOLEAN", {"default": False}))
            for row in _variants_for_action_family(cls.ACTION_FAMILY)
        )
        optional = {
            "hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING", "STRING", "STRING", "STRING", "STRING")
    RETURN_NAMES = (
        "hardcore_position_config",
        "selected_variant_keys",
        "selected_positions",
        "selected_variant_notes",
        "summary",
        "variants_json",
    )
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(self, combine_mode="replace", hardcore_position_config="", atlas_cue_seed=-1, **kwargs):
        """Build the six outputs from the toggled variants in ``kwargs``.

        With no variant toggled, the config is the incoming one (``"add"``
        mode) or an empty one, the list outputs are empty strings, and the
        variants JSON is ``"[]"``.
        """
        incoming = hardcore_position_config or ""
        variants = _selected_variant_rows(self.ACTION_FAMILY, kwargs)
        if not variants:
            config = _empty_or_incoming_config(incoming, combine_mode)
            return config, "", "", "", json.loads(config).get("summary", ""), "[]"

        positions = _selected_variant_positions(variants)
        family = _merged_family_for_variant_filter(
            incoming,
            combine_mode,
            self.POSITION_FAMILY or self.ACTION_FAMILY,
        )
        config = build_hardcore_position_pool_json(
            hardcore_position_config=incoming,
            combine_mode=combine_mode,
            family=family,
            selected_positions=positions,
        )
        config = _merge_variant_metadata(config, variants, atlas_cue_seed=atlas_cue_seed)

        # The key and position outputs are read back from the merged config,
        # not taken from the local selection.
        parsed = json.loads(config)
        merged_keys = parsed.get("krea2_variant_keys") or []
        merged_positions = parsed.get("positions") or []
        return (
            config,
            ",".join(str(key) for key in merged_keys),
            ",".join(str(position) for position in merged_positions),
            _variant_notes(variants),
            str(parsed.get("summary") or ""),
            json.dumps(variants, ensure_ascii=True, sort_keys=True),
        )
class SxCPKrea2POVPenetrationFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``penetration`` action family."""

    ACTION_FAMILY = "penetration"
    POSITION_FAMILY = "penetration"
class SxCPKrea2POVOralFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``oral`` action family."""

    ACTION_FAMILY = "oral"
    POSITION_FAMILY = "oral"
class SxCPKrea2POVOutercourseFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``outercourse`` action family."""

    ACTION_FAMILY = "outercourse"
    POSITION_FAMILY = "outercourse"
class SxCPKrea2POVManualFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``manual`` action family."""

    ACTION_FAMILY = "manual"
    POSITION_FAMILY = "manual"
class SxCPKrea2POVToyFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``toy`` action family."""

    ACTION_FAMILY = "toy"
    # Unlike the sibling filters, the position family is left as "any".
    POSITION_FAMILY = "any"
class SxCPKrea2POVClimaxFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``climax`` action family."""

    ACTION_FAMILY = "climax"
    POSITION_FAMILY = "climax"
class SxCPKrea2POVInteractionFilter(_SxCPKrea2POVVariantFilter):
    """POV variant filter bound to the ``interaction`` action family."""

    ACTION_FAMILY = "interaction"
    POSITION_FAMILY = "interaction"
class SxCPKrea2POVPromptRestore:
    """Node that marks groups of prompt axes to be restored in a position config."""

    # Prompt axes contributed by each restore toggle.
    CLOTHING_AXES = ["clothing_detail"]
    FACE_EXPRESSION_AXES = ["face_detail", "expression_detail", "mouth_detail", "reaction_detail"]
    BODY_TOUCH_AXES = ["body_contact", "hand_detail", "touch_detail", "foreplay_detail"]
    CAMERA_PRESENTATION_AXES = ["performance_act", "visibility"]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: four restore toggles, the relax flag, and the config."""
        toggle_defaults = (
            ("restore_clothing_detail", True),
            ("restore_face_expression_detail", True),
            ("restore_body_touch_detail", False),
            ("restore_camera_presentation_detail", False),
            ("relax_non_pose_axis_conflicts", True),
        )
        return {
            "required": {
                name: ("BOOLEAN", {"default": default})
                for name, default in toggle_defaults
            },
            "optional": {
                "hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
            },
        }

    RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING")
    RETURN_NAMES = ("hardcore_position_config", "summary")
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(
        self,
        restore_clothing_detail=True,
        restore_face_expression_detail=True,
        restore_body_touch_detail=False,
        restore_camera_presentation_detail=False,
        relax_non_pose_axis_conflicts=True,
        hardcore_position_config="",
    ):
        """Return ``(config_json, summary)`` with the restore axes applied.

        Every enabled group adds its axes and sets ``allow_interaction``; all
        groups except the camera/presentation one also set ``allow_foreplay``.
        The config is marked enabled when at least one axis remains after
        normalization.
        """
        config = parse_hardcore_position_config(hardcore_position_config)
        # (toggle, axes to restore, whether the group also allows foreplay)
        groups = (
            (restore_clothing_detail, self.CLOTHING_AXES, True),
            (restore_face_expression_detail, self.FACE_EXPRESSION_AXES, True),
            (restore_body_touch_detail, self.BODY_TOUCH_AXES, True),
            (restore_camera_presentation_detail, self.CAMERA_PRESENTATION_AXES, False),
        )
        axes: list[str] = []
        for enabled, group_axes, allows_foreplay in groups:
            if not enabled:
                continue
            axes.extend(group_axes)
            if allows_foreplay:
                config["allow_foreplay"] = True
            config["allow_interaction"] = True

        config["restore_prompt_axes"] = normalize_restore_prompt_axes(axes)
        if config["restore_prompt_axes"]:
            config["enabled"] = True
        config["relax_non_pose_axis_conflicts"] = bool(relax_non_pose_axis_conflicts)
        config["summary"] = hardcore_position_summary(config)
        serialized = json.dumps(config, ensure_ascii=True, sort_keys=True)
        return serialized, str(config["summary"])
class SxCPKrea2VariantEvidence:
    """Node that looks up logged evaluation evidence for a Krea2 pose variant."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: variant dropdown, result filter, optional key override."""
        keys = krea2_pose_variant_catalog.variant_keys()
        # A placeholder keeps the dropdown non-empty when the catalog has no keys.
        choices = keys or ["missing_catalog_variant"]
        return {
            "required": {
                "variant_key": (choices, {"default": choices[0]}),
                "result": (["accepted", "rejected", "inconclusive", "any"], {"default": "accepted"}),
            },
            "optional": {
                "variant_key_in": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "INT", "STRING")
    RETURN_NAMES = (
        "summary",
        "baseline_image_path",
        "candidate_image_path",
        "evidence_json",
        "seed",
        "decision",
    )
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(self, variant_key, result="accepted", variant_key_in=""):
        """Return the first matching evidence entry for the variant.

        A non-empty ``variant_key_in`` takes precedence over ``variant_key``,
        and ``result == "any"`` disables the result filter. When nothing
        matches, the path and decision outputs are empty and the seed is
        ``-1``. The seed is also ``-1`` when the entry's seed is not an int.
        """
        key = str(variant_key_in or variant_key or "").strip()
        entries = krea2_eval_log.entries_for_variant(
            key,
            result=None if result == "any" else result,
        )
        if not entries:
            empty = {
                "variant_key": key,
                "result": result,
                "summary": "no Krea2 eval evidence found",
            }
            return empty["summary"], "", "", json.dumps(empty, ensure_ascii=True, sort_keys=True), -1, ""

        entry = entries[0]
        seed = entry.get("seed")
        summary = "; ".join(
            [
                f"evidence={entry.get('id')}",
                f"variant={entry.get('variant_key')}",
                f"seed={seed}",
                f"result={entry.get('result')}",
                f"decision={entry.get('decision')}",
            ]
        )
        return (
            summary,
            str(entry.get("baseline_image") or ""),
            str(entry.get("candidate_image") or ""),
            json.dumps(entry, ensure_ascii=True, sort_keys=True),
            int(seed) if isinstance(seed, int) else -1,
            str(entry.get("decision") or ""),
        )
class SxCPHardcoreActionFilter:
    """Node that applies a focus and per-action allow flags to a position config."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: focus choice, the allow_* toggles, and the config."""
        allow_defaults = (
            ("allow_toys", False),
            ("allow_double", False),
            ("allow_penetration", True),
            ("allow_foreplay", True),
            ("allow_interaction", True),
            ("allow_manual", True),
            ("allow_oral", True),
            ("allow_outercourse", True),
            ("allow_anal", True),
            ("allow_climax", True),
        )
        required = {
            "focus": (hardcore_position_focus_choices(), {"default": "keep_pool"}),
        }
        required.update(
            (name, ("BOOLEAN", {"default": default}))
            for name, default in allow_defaults
        )
        optional = {
            "hardcore_position_config": (SXCP_HARDCORE_POSITION_CONFIG,),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = (SXCP_HARDCORE_POSITION_CONFIG, "STRING")
    RETURN_NAMES = ("hardcore_position_config", "summary")
    FUNCTION = "build"
    CATEGORY = "prompt_builder"

    def build(
        self,
        focus,
        allow_toys,
        allow_double,
        allow_penetration,
        allow_foreplay,
        allow_interaction,
        allow_manual,
        allow_oral,
        allow_outercourse,
        allow_anal,
        allow_climax,
        hardcore_position_config="",
    ):
        """Return ``(config_json, summary)`` from ``build_hardcore_action_filter_json``.

        All arguments are forwarded unchanged, except that a falsy incoming
        config is passed as an empty string.
        """
        allow_flags = {
            "allow_toys": allow_toys,
            "allow_double": allow_double,
            "allow_penetration": allow_penetration,
            "allow_foreplay": allow_foreplay,
            "allow_interaction": allow_interaction,
            "allow_manual": allow_manual,
            "allow_oral": allow_oral,
            "allow_outercourse": allow_outercourse,
            "allow_anal": allow_anal,
            "allow_climax": allow_climax,
        }
        config_json = build_hardcore_action_filter_json(
            hardcore_position_config=hardcore_position_config or "",
            focus=focus,
            **allow_flags,
        )
        summary = json.loads(config_json).get("summary", "")
        return config_json, summary
# Registry mapping each node identifier to the class that implements it.
# The private base class _SxCPKrea2POVVariantFilter is not listed; only its
# concrete per-family subclasses are.
NODE_CLASS_MAPPINGS = {
    "SxCPHardcorePositionPool": SxCPHardcorePositionPool,
    "SxCPHardcoreActionFilter": SxCPHardcoreActionFilter,
    "SxCPKrea2PoseVariant": SxCPKrea2PoseVariant,
    "SxCPKrea2POVPenetrationFilter": SxCPKrea2POVPenetrationFilter,
    "SxCPKrea2POVOralFilter": SxCPKrea2POVOralFilter,
    "SxCPKrea2POVOutercourseFilter": SxCPKrea2POVOutercourseFilter,
    "SxCPKrea2POVManualFilter": SxCPKrea2POVManualFilter,
    "SxCPKrea2POVToyFilter": SxCPKrea2POVToyFilter,
    "SxCPKrea2POVClimaxFilter": SxCPKrea2POVClimaxFilter,
    "SxCPKrea2POVInteractionFilter": SxCPKrea2POVInteractionFilter,
    "SxCPKrea2POVPromptRestore": SxCPKrea2POVPromptRestore,
    "SxCPKrea2VariantEvidence": SxCPKrea2VariantEvidence,
}
# Human-readable display name for each node identifier; the keys are the
# same identifiers used in NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SxCPHardcorePositionPool": "SxCP Hardcore Position Pool",
    "SxCPHardcoreActionFilter": "SxCP Hardcore Action Filter",
    "SxCPKrea2PoseVariant": "SxCP Krea2 Pose Variant",
    "SxCPKrea2POVPenetrationFilter": "SxCP Krea2 POV Penetration Filter",
    "SxCPKrea2POVOralFilter": "SxCP Krea2 POV Oral Filter",
    "SxCPKrea2POVOutercourseFilter": "SxCP Krea2 POV Outercourse Filter",
    "SxCPKrea2POVManualFilter": "SxCP Krea2 POV Manual Filter",
    "SxCPKrea2POVToyFilter": "SxCP Krea2 POV Toy Filter",
    "SxCPKrea2POVClimaxFilter": "SxCP Krea2 POV Climax Filter",
    "SxCPKrea2POVInteractionFilter": "SxCP Krea2 POV Interaction Filter",
    "SxCPKrea2POVPromptRestore": "SxCP Krea2 POV Prompt Restore",
    "SxCPKrea2VariantEvidence": "SxCP Krea2 Variant Evidence",
}