Files
ComfyUI-Ethanfel-Prompt-Bui…/krea_pair_formatter.py
T

391 lines
16 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
import re
from typing import Any, Callable
try:
from . import krea2_pose_variant_catalog
except ImportError: # Allows local smoke tests with top-level imports.
import krea2_pose_variant_catalog
# Case-insensitive substrings; expression clauses containing any of them are
# dropped by _filter_expression_for_krea2_variant when an oral-contact variant
# is active on the row.
MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool")
TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft"
SIDE_PROFILE_ORAL_VARIANT = "pov_blowjob_side_profile_oral"
# Krea2 variant keys treated as oral-contact poses. The named constants above
# are reused (rather than repeating their literals) so renaming either key
# cannot silently desynchronize this set.
ORAL_CONTACT_VARIANTS = frozenset(
    (
        TOP_VIEW_ORAL_VARIANT,
        SIDE_PROFILE_ORAL_VARIANT,
        "pov_blowjob_laying_frontal_oral",
        "pov_blowjob_sitting_upright_oral",
    )
)
@dataclass(frozen=True)
class KreaPairFormatRequest:
    """Immutable input bundle for ``format_insta_pair_result``.

    The generated constructor takes ``row``, ``detail_level`` and
    ``style_mode`` in that positional order.
    """

    # Paired source row: holds the softcore/hardcore sub-rows, options,
    # shared cast descriptors and camera directives read by the formatter.
    row: dict[str, Any]
    # Compared against "concise" to decide whether style phrases are emitted.
    detail_level: str
    # Passed through unchanged to ``deps.style_phrase``.
    style_mode: str
@dataclass(frozen=True)
class KreaPairPrompts:
    """The four prompt strings produced for one soft/hard pair."""

    soft_prompt: str
    soft_negative: str
    hard_prompt: str
    hard_negative: str

    def as_tuple(self) -> tuple[str, str, str, str]:
        """Return ``(soft_prompt, soft_negative, hard_prompt, hard_negative)``."""
        ordered = (
            self.soft_prompt,
            self.soft_negative,
            self.hard_prompt,
            self.hard_negative,
        )
        return ordered
@dataclass(frozen=True)
class KreaPairFormatDependencies:
    """Helper callables injected into ``format_insta_pair_result``.

    The formatter reaches its text normalization and phrasing helpers through
    these fields rather than importing them. Field order defines the
    positional constructor and must not change.
    """

    # General text normalization.
    clean: Callable[[Any], str]
    prompt_cast_descriptors: Callable[[str], str]
    # Camera, camera-scene and style phrasing.
    pair_camera_phrase: Callable[[Any, Any, dict[str, Any]], str]
    camera_scene_phrase: Callable[[dict[str, Any]], str]
    style_phrase: Callable[[dict[str, Any], str], str]
    # Hardcore-row sanitizers and action wording.
    sanitize_hardcore_environment_anchors: Callable[[Any], str]
    sanitize_hardcore_axis_values: Callable[[Any], Any]
    sanitize_scene_text_for_cast: Callable[[Any, list[str]], str]
    normalize_hardcore_detail_density: Callable[[Any], str]
    row_action_family: Callable[[Any], str]
    hardcore_action_sentence: Callable[[str, str, str, Any, str, str], str]
    pov_action_phrase: Callable[[str, list[str], str, str, str, Any, str], str]
    # POV / cast label handling.
    pov_labels_from_value: Callable[[Any], list[str]]
    merge_labels: Callable[..., list[str]]
    cast_prose_omit: Callable[[str, list[str]], tuple[str, list[str]]]
    label_join: Callable[[list[str]], str]
    filter_pov_labeled_clauses: Callable[[Any, list[str]], str]
    natural_label_text: Callable[[Any, list[str]], str]
    # Expression phrasing.
    expression_disabled: Callable[[dict[str, Any]], bool]
    expression_phrase: Callable[[Any], str]
    # POV camera, presence, clothing and composition phrasing.
    pov_camera_phrase: Callable[[list[str]], str]
    pov_soft_camera_phrase: Callable[[list[str]], str]
    pov_composition_text: Callable[[Any, list[str]], str]
    softcore_cast_presence_phrase: Callable[..., str]
    natural_clothing_state: Callable[[Any, str], str]
    composition_phrase: Callable[..., str]
    # Final assembly of prompt paragraphs and negative prompts.
    paragraph: Callable[[list[str]], str]
    combine_negative: Callable[..., str]
def _list_values(value: Any) -> list[str]:
if isinstance(value, list):
return [str(item) for item in value if str(item).strip()]
if isinstance(value, str) and value.strip():
return [part.strip() for part in value.split(",") if part.strip()]
return []
def _krea2_variant_keys(row: dict[str, Any]) -> list[str]:
    """Collect the row's Krea2 variant keys, de-duplicated in first-seen order.

    Keys are read from ``row["hardcore_position_config"]["krea2_variant_keys"]``
    first, then from ``row["item_axis_values"]["krea2_variant_keys"]``; each
    value is normalized with ``_list_values``. A container that is missing or
    not a dict contributes nothing.
    """
    collected: list[str] = []
    seen: set[str] = set()
    for container_key in ("hardcore_position_config", "item_axis_values"):
        container = row.get(container_key)
        if not isinstance(container, dict):
            continue
        for variant_key in _list_values(container.get("krea2_variant_keys")):
            if variant_key in seen:
                continue
            seen.add(variant_key)
            collected.append(variant_key)
    return collected
def _has_krea2_variant(row: dict[str, Any], key: str) -> bool:
    """Return True when ``key`` appears among the row's Krea2 variant keys."""
    return any(candidate == key for candidate in _krea2_variant_keys(row))
def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool:
    """Return True when any of the row's variant keys is in ``ORAL_CONTACT_VARIANTS``."""
    return not ORAL_CONTACT_VARIANTS.isdisjoint(_krea2_variant_keys(row))
def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool:
    """Return True when at least one of the row's variant keys resolves in the catalog.

    A key counts only if ``krea2_pose_variant_catalog.get_variant`` returns a
    truthy value for it; lookup stops at the first such key.
    """
    for variant_key in _krea2_variant_keys(row):
        if krea2_pose_variant_catalog.get_variant(variant_key):
            return True
    return False
def _restores_krea2_prompt_axis(row: dict[str, Any], axis_name: str) -> bool:
    """Return True when ``axis_name`` is listed as a restored prompt axis.

    Consults ``row["hardcore_position_config"]["restore_prompt_axes"]`` and
    ``row["item_axis_values"]["restored_prompt_axes"]`` (each normalized with
    ``_list_values``). A container that is missing or not a dict contributes
    nothing.
    """
    # NOTE(review): the two sources use different key spellings
    # ("restore_prompt_axes" vs "restored_prompt_axes"); presumably this
    # mirrors two different schemas — confirm it is intentional.
    lookups = (
        ("hardcore_position_config", "restore_prompt_axes"),
        ("item_axis_values", "restored_prompt_axes"),
    )
    restored_axes: list[str] = []
    for container_key, axes_key in lookups:
        container = row.get(container_key)
        source = container if isinstance(container, dict) else {}
        restored_axes.extend(_list_values(source.get(axes_key)))
    return axis_name in restored_axes
def _side_profile_hidden_lower_clothing_clause(clause: str) -> bool:
lower = clause.lower()
if "below the hips" in lower or "lower body" in lower:
return True
return any(
term in lower
for term in (
"panty",
"panties",
"brief",
"briefs",
"thong",
"shorts",
"jeans",
"trousers",
"pants",
"skirt",
)
) and any(term in lower for term in ("pulled aside", "removed", "lowered", "visible"))
def _side_profile_visible_clothing_clause(clause: str) -> str:
text = clause.strip(" .")
if not text:
return ""
lower = text.lower()
if lower.startswith("woman a's "):
text = text[len("Woman A's "):]
elif lower.startswith("the woman's "):
text = text[len("the woman's "):]
elif lower.startswith("her "):
text = text[len("her "):]
text = re.sub(r"\s+(?:remain|remains)\s+visible\s+from\s+the\s+same\s+outfit$", " from the same outfit", text, flags=re.IGNORECASE)
text = re.sub(r"\s+(?:remain|remains)\s+visible$", "", text, flags=re.IGNORECASE)
text = re.sub(r",\s+((?:a|an|the)\s+)", r" and \1", text, count=1, flags=re.IGNORECASE)
if not text.lower().startswith(("a ", "an ", "the ")):
text = f"the {text}"
return f"the woman wears {text}"
def _krea2_atlas_clothing_text(row: dict[str, Any], text: Any) -> str:
    """Filter and rewrite ';'-separated clothing clauses for Krea2 atlas variants.

    The stripped input is returned unchanged when it is empty or when the row
    has no atlas variant. Otherwise each clause is stripped of spaces/periods,
    and then:

    * empty clauses and clauses starting with "pov foreground clothing cue:" or
      "pov foreground body cue:" (case-insensitive) are dropped;
    * when the side-profile oral variant is active, clauses matched by
      ``_side_profile_hidden_lower_clothing_clause`` are dropped (one fixed
      summary clause is prepended instead), and the remaining clauses are
      rewritten with ``_side_profile_visible_clothing_clause``.

    Surviving clauses are joined with "; ".
    """
    clothing = str(text or "").strip()
    if not clothing or not _has_krea2_atlas_variant(row):
        return clothing
    side_profile = _has_krea2_variant(row, SIDE_PROFILE_ORAL_VARIANT)
    pov_cue_prefixes = ("pov foreground clothing cue:", "pov foreground body cue:")
    result: list[str] = []
    dropped_hidden_lower = False
    for raw_clause in clothing.split(";"):
        piece = raw_clause.strip(" .")
        if not piece or piece.lower().startswith(pov_cue_prefixes):
            continue
        if not side_profile:
            result.append(piece)
            continue
        if _side_profile_hidden_lower_clothing_clause(piece):
            dropped_hidden_lower = True
            continue
        result.append(_side_profile_visible_clothing_clause(piece))
    # Only ever set on the side-profile path, so no extra variant check is needed.
    if dropped_hidden_lower:
        result.insert(0, "The woman's lower garments are pulled aside out of frame")
    return "; ".join(result)
def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool:
    """Return True when a resolved variant describes a top-down camera geometry.

    For each of the row's variant keys, the catalog entry's
    ``canonical_geometry`` and ``prompt_cues`` are lower-cased and searched for
    "top-down", "top view", "top-view", "nadir" or "overhead".

    Keys whose ``krea2_pose_variant_catalog.get_variant`` result is falsy are
    skipped, matching how ``_has_krea2_atlas_variant`` treats them, so a
    ``None`` result no longer raises ``AttributeError`` on ``.get``.
    """
    top_down_terms = ("top-down", "top view", "top-view", "nadir", "overhead")
    for key in _krea2_variant_keys(row):
        variant = krea2_pose_variant_catalog.get_variant(key)
        if not variant:
            continue
        cues = variant.get("prompt_cues") or []
        if isinstance(cues, str):
            # Treat a single cue string as one cue instead of iterating its characters.
            cues = [cues]
        geometry = " ".join(
            [str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in cues]]
        ).lower()
        if any(term in geometry for term in top_down_terms):
            return True
    return False
def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any:
    """Drop mouth-related expression clauses when an oral-contact variant is active.

    Returns ``expression`` untouched when the row has no oral-contact variant,
    or when it has no non-blank ';'-separated clauses. Otherwise it returns the
    clauses that contain none of ``MOUTH_EXPRESSION_TERMS`` (case-insensitive
    substring match), joined with "; ". The result may be an empty string.
    """
    if not _has_krea2_oral_contact_variant(row):
        return expression
    pieces = [piece.strip() for piece in str(expression or "").split(";")]
    pieces = [piece for piece in pieces if piece]
    if not pieces:
        return expression

    def _mentions_mouth(clause: str) -> bool:
        lowered = clause.lower()
        return any(term in lowered for term in MOUTH_EXPRESSION_TERMS)

    return "; ".join(piece for piece in pieces if not _mentions_mouth(piece))
def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str:
    """Return the camera-scene text, or "" when it conflicts with the row's variant.

    The text is dropped for any Krea2 atlas variant. It is also dropped when the
    row has an oral-contact or top-down variant and the text mentions
    "eye-level" (case-insensitive).
    """
    scene_text = str(camera_scene or "")
    # NOTE(review): the second clause is only evaluated when no variant key
    # resolved to a truthy catalog entry, so _has_krea2_top_down_variant (which
    # inspects catalog entries) is unlikely to match there — confirm whether
    # that check is still needed.
    drop_scene = _has_krea2_atlas_variant(row) or (
        (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row))
        and "eye-level" in scene_text.lower()
    )
    return "" if drop_scene else scene_text
def _pair_join_clean(values: list[Any], clean: Callable[[Any], str]) -> str:
    """Apply ``clean`` once per value and join the non-empty results with "; "."""
    return "; ".join(text for text in (clean(value) for value in values) if text)


def _pair_partner_styling(row: dict[str, Any], deps: KreaPairFormatDependencies) -> tuple[str, str]:
    """Return ``(outfit_text, pose)`` read from ``row["softcore_partner_styling"]``.

    Both values are "" when that field is not a dict. ``outfit_text`` is ""
    when its ``outfits`` entry is not a list.
    """
    styling = row.get("softcore_partner_styling")
    if not isinstance(styling, dict):
        return "", ""
    outfits = styling.get("outfits")
    outfit_text = _pair_join_clean(outfits, deps.clean) if isinstance(outfits, list) else ""
    return outfit_text, deps.clean(styling.get("pose"))


def format_insta_pair_result(request: KreaPairFormatRequest, deps: KreaPairFormatDependencies) -> KreaPairPrompts:
    """Build the softcore/hardcore prompt pair for one paired row.

    ``request.row`` is read for ``softcore_row`` / ``hardcore_row`` sub-dicts,
    an ``options`` dict, shared cast descriptors and camera directives. Any of
    these that are missing or not dicts are treated as empty. All text
    normalization and phrasing goes through ``deps``. This function chooses
    which phrases appear and in what order, then passes each ordered list to
    ``deps.paragraph``.

    Behavior notes:

    * With POV labels present, the hardcore camera phrase is cleared. The
      softcore camera phrase is also cleared when the softcore cast mirrors
      the hardcore cast (``options["softcore_cast"] == "same_as_hardcore"``).
    * A Krea2 atlas variant on the hardcore row clears the hardcore
      composition, expression and style phrases. It also clears the clothing
      phrase unless the variant restores the ``clothing_detail`` axis.
    * Style phrases are omitted when ``detail_level == "concise"``.

    Returns:
        KreaPairPrompts with the soft/hard prompts and their negatives.
    """
    row = request.row
    include_style = request.detail_level != "concise"
    style_mode = request.style_mode

    descriptor = deps.clean(row.get("shared_descriptor"))
    cast_descriptors = row.get("shared_cast_descriptors")
    if isinstance(cast_descriptors, list):
        cast_descriptor_text = _pair_join_clean(cast_descriptors, deps.clean)
    else:
        cast_descriptor_text = deps.clean(cast_descriptors)
    cast_descriptor_text = deps.prompt_cast_descriptors(cast_descriptor_text)

    soft = row.get("softcore_row") if isinstance(row.get("softcore_row"), dict) else {}
    hard = row.get("hardcore_row") if isinstance(row.get("hardcore_row"), dict) else {}
    soft_camera = deps.pair_camera_phrase(row.get("softcore_camera_directive"), row.get("softcore_camera_config"), soft)
    hard_camera = deps.pair_camera_phrase(row.get("hardcore_camera_directive"), row.get("hardcore_camera_config"), hard)
    soft_camera_scene = deps.camera_scene_phrase(soft) or deps.clean(row.get("softcore_camera_scene_directive"))
    hard_camera_scene = _filter_camera_scene_for_krea2_variant(
        hard,
        deps.camera_scene_phrase(hard) or deps.clean(row.get("hardcore_camera_scene_directive")),
    )
    soft_style = deps.style_phrase(soft, style_mode)
    hard_style = deps.style_phrase(hard, style_mode)

    options = row.get("options") if isinstance(row.get("options"), dict) else {}
    # Whether the softcore prompt reuses the hardcore cast (and so its POV labels).
    same_soft_cast = options.get("softcore_cast") == "same_as_hardcore"
    same_room = options.get("continuity") == "same_creator_same_room"
    hard_scene = hard.get("scene_text") or (soft.get("scene_text") if same_room else "")
    hard_composition = deps.sanitize_hardcore_environment_anchors(hard.get("composition"))
    hard_source_composition = deps.sanitize_hardcore_environment_anchors(hard.get("source_composition") or hard_composition)

    pov_labels = deps.merge_labels(
        deps.pov_labels_from_value(row.get("pov_character_labels")),
        deps.pov_labels_from_value(soft.get("pov_character_labels")),
        deps.pov_labels_from_value(hard.get("pov_character_labels")),
    )
    soft_pov_labels = pov_labels if same_soft_cast else []
    if pov_labels:
        hard_camera = ""
        if same_soft_cast:
            soft_camera = ""

    soft_cast_descriptor_text = cast_descriptor_text if same_soft_cast else f"Woman A: {descriptor}"
    soft_cast_prose, soft_labels = deps.cast_prose_omit(soft_cast_descriptor_text, soft_pov_labels)
    hard_cast_prose, hard_labels = deps.cast_prose_omit(cast_descriptor_text, pov_labels)
    soft_labels = deps.merge_labels(soft_labels, soft_pov_labels)
    hard_labels = deps.merge_labels(hard_labels, pov_labels)

    hard_item = deps.sanitize_scene_text_for_cast(
        deps.sanitize_hardcore_environment_anchors(hard.get("item")),
        hard_labels,
    )
    hard_role_graph = deps.sanitize_scene_text_for_cast(
        deps.sanitize_hardcore_environment_anchors(hard.get("source_role_graph") or hard.get("role_graph")),
        hard_labels,
    )
    hard_item = deps.natural_label_text(hard_item, hard_labels)
    hard_role_graph = deps.natural_label_text(hard_role_graph, hard_labels)
    hard_axis_values = deps.sanitize_hardcore_axis_values(hard.get("item_axis_values"))
    hard_detail_density = deps.normalize_hardcore_detail_density(
        hard.get("hardcore_detail_density") or row.get("hardcore_detail_density") or options.get("hardcore_detail_density")
    )
    hard_action = deps.hardcore_action_sentence(
        hard_role_graph,
        hard_item,
        hard_source_composition,
        hard_axis_values,
        hard_detail_density,
        deps.row_action_family(hard) or deps.row_action_family(row),
    )
    hard_action = deps.pov_action_phrase(
        hard_action,
        pov_labels,
        hard_role_graph,
        hard_item,
        hard_source_composition,
        hard_axis_values,
        hard_detail_density,
    )

    hard_has_atlas_variant = _has_krea2_atlas_variant(hard)
    hard_output_composition = "" if hard_has_atlas_variant else deps.pov_composition_text(hard_composition, pov_labels)
    # Atlas variants suppress clothing unless the "clothing_detail" axis is restored.
    show_hard_clothing = not hard_has_atlas_variant or _restores_krea2_prompt_axis(hard, "clothing_detail")
    hard_clothing = deps.natural_label_text(
        deps.filter_pov_labeled_clauses(
            deps.natural_clothing_state(row.get("hardcore_clothing_state"), hard_action),
            pov_labels,
        ),
        hard_labels,
    )
    hard_clothing = _krea2_atlas_clothing_text(hard, hard_clothing)

    soft_output_composition = deps.pov_composition_text(soft.get("composition"), soft_pov_labels)
    soft_cast_presence = deps.softcore_cast_presence_phrase(
        same_cast=same_soft_cast,
        pov_labels=soft_pov_labels,
        cast_label=deps.label_join(soft_labels),
        woman_label="the woman",
    )

    partner_outfit_text, partner_pose = _pair_partner_styling(row, deps)
    partner_outfit_text = deps.filter_pov_labeled_clauses(partner_outfit_text, pov_labels)
    if pov_labels:
        partner_pose = ""
    partner_outfit_text = deps.natural_label_text(partner_outfit_text, soft_labels)

    soft_expression = ""
    if not deps.expression_disabled(soft):
        soft_expression = deps.natural_label_text(
            deps.filter_pov_labeled_clauses(
                deps.clean(soft.get("character_expression_text")) or deps.clean(soft.get("expression")),
                pov_labels,
            ),
            soft_labels,
        )

    hard_expression = ""
    if not deps.expression_disabled(hard) and not hard_has_atlas_variant:
        hard_expression_source = _filter_expression_for_krea2_variant(
            hard,
            deps.clean(hard.get("character_expression_text")) or deps.clean(hard.get("expression")),
        )
        hard_expression = deps.natural_label_text(
            deps.filter_pov_labeled_clauses(hard_expression_source, pov_labels),
            hard_labels,
        )

    soft_item = deps.clean(soft.get("item"))
    soft_item_label = deps.clean(soft.get("softcore_item_prompt_label"))
    soft_item_phrase = ""
    if soft_item:
        soft_item_phrase = f"body exposure: {soft_item}" if soft_item_label == "Body exposure" else f"wearing {soft_item}"

    soft_parts = [
        soft_cast_prose,
        soft_cast_presence,
        partner_outfit_text,
        partner_pose,
        deps.pov_soft_camera_phrase(pov_labels) if same_soft_cast else "",
        soft_item_phrase,
        f"{soft.get('pose')}" if soft.get("pose") else "",
        deps.expression_phrase(soft_expression),
        f"in {soft.get('scene_text')}" if soft.get("scene_text") else "",
        soft_camera_scene,
        deps.composition_phrase(soft_output_composition),
        soft_camera,
        soft_style if include_style else "",
    ]
    hard_parts = [
        hard_cast_prose,
        hard_action,
        deps.pov_camera_phrase(pov_labels),
        hard_clothing if show_hard_clothing else "",
        f"set in {hard_scene}" if hard_scene else "",
        hard_camera_scene,
        deps.expression_phrase(hard_expression),
        deps.composition_phrase(hard_output_composition, hard_action, detail_density=hard_detail_density),
        hard_camera,
        # Style is omitted for atlas variants and for concise detail.
        hard_style if include_style and not hard_has_atlas_variant else "",
    ]
    return KreaPairPrompts(
        soft_prompt=deps.paragraph(soft_parts),
        soft_negative=deps.combine_negative(row.get("softcore_negative_prompt")),
        hard_prompt=deps.paragraph(hard_parts),
        hard_negative=deps.combine_negative(row.get("hardcore_negative_prompt")),
    )
def format_insta_pair(request: KreaPairFormatRequest, deps: KreaPairFormatDependencies) -> tuple[str, str, str, str]:
    """Tuple-returning wrapper around ``format_insta_pair_result``.

    Returns ``(soft_prompt, soft_negative, hard_prompt, hard_negative)``.
    """
    prompts = format_insta_pair_result(request, deps)
    return prompts.as_tuple()