# Source: ComfyUI-Ethanfel-Prompt-Bui…/row_expression.py (400 lines, 12 KiB, Python)

from __future__ import annotations
import random
import re
from typing import Any
try:
from . import category_library as category_policy
from . import character_slot as character_slot_policy
from . import pov_policy
except ImportError: # Allows local smoke tests with top-level imports.
import category_library as category_policy
import character_slot as character_slot_policy
import pov_policy
def clean_prompt_punctuation(text: str) -> str:
    """Normalize whitespace and stray punctuation in a prompt fragment.

    Steps, in order:
      1. collapse every whitespace run to one space and trim the ends;
      2. remove whitespace that precedes ``,`` ``.`` ``;`` or ``:``;
      3. collapse two or more consecutive commas into a single ``", "``;
      4. collapse any run of periods into a single period;
      5. turn a colon directly followed by a period into just the period.

    Args:
        text: Fragment to clean. Falsy values are treated as ``""``; other
            values are converted with ``str()``.

    Returns:
        The cleaned, stripped string.
    """
    text = re.sub(r"\s+", " ", str(text or "")).strip()
    text = re.sub(r"\s+([,.;:])", r"\1", text)
    text = re.sub(r"(?:,\s*){2,}", ", ", text)
    # Match the whole run of periods: a pairwise pattern leaves ".." behind
    # whenever the run has three or more periods.
    text = re.sub(r"\.(?:\s*\.)+", ".", text)
    text = re.sub(r":\s*\.", ".", text)
    return text.strip()
def strip_expression_text(text: str, expression: Any = "") -> str:
    """Remove expression wording from a prompt or caption string.

    Generic expression phrasings are removed first (``Facial expression:``
    sentences, "one with ... and the other with ..." and "a lively mix of
    expressions from ..." comma clauses, and "with a/an/the ... expression"
    phrases). Then every ``;``-separated piece of ``expression`` is removed
    where it appears as its own comma clause or after "with [a/an/the]".
    All matching is case-insensitive. The result is passed through
    ``clean_prompt_punctuation``.

    Args:
        text: Prompt or caption text; falsy values yield ``""``.
        expression: The expression text previously inserted, possibly several
            pieces joined with ``;``.

    Returns:
        The text with expression wording removed and punctuation cleaned.
    """
    cleaned = str(text or "")
    if not cleaned:
        return ""

    # (pattern, replacement) pairs applied in this exact order.
    generic_rules = (
        (r"\s*Facial expressions?:\s*[^.]*\.\s*", " "),
        (r",\s*one with [^,]+ and the other with [^,]+(?=,)", ""),
        (r",\s*a lively mix of expressions from [^,]+(?=,)", ""),
        (r"\s+with\s+(?:an?|the)\s+[^,]*expression(?=,)", ""),
    )
    for pattern, replacement in generic_rules:
        cleaned = re.sub(pattern, replacement, cleaned, flags=re.IGNORECASE)

    for raw_piece in str(expression or "").strip().split(";"):
        piece = raw_piece.strip()
        if not piece:
            continue
        literal = re.escape(piece)
        # The piece as a standalone comma clause that is followed by a comma.
        cleaned = re.sub(rf",\s*{literal}(?=,)", "", cleaned, flags=re.IGNORECASE)
        # The piece introduced by "with", optionally with an article.
        cleaned = re.sub(rf"\s+with\s+(?:an?|the)?\s*{literal}", "", cleaned, flags=re.IGNORECASE)

    return clean_prompt_punctuation(cleaned)
def disable_row_expression(row: dict[str, Any], source: str = "disabled") -> dict[str, Any]:
    """Strip expression wording from a row and mark expressions as disabled.

    The row is modified in place: ``prompt`` and ``caption`` are run through
    ``strip_expression_text`` using the row's current ``expression`` value,
    and all expression bookkeeping fields are reset.

    Args:
        row: Row dictionary to update.
        source: Value stored under ``expression_intensity_source``.

    Returns:
        The same ``row`` object, after mutation.
    """
    prior_expression = row.get("expression", "")
    for text_field in ("prompt", "caption"):
        row[text_field] = strip_expression_text(row.get(text_field, ""), prior_expression)

    row.update(
        {
            "expression": "",
            "shared_expression": "",
            "character_expressions": [],
            "character_expression_text": "",
            "expression_enabled": False,
            "expression_disabled": True,
            "expression_intensity": None,
            "expression_intensity_source": source,
        }
    )
    return row
def _clamped_float(value: Any, default: float = 0.5, min_value: float = 0.0, max_value: float = 1.0) -> float:
    """Convert ``value`` to a float limited to ``[min_value, max_value]``.

    Args:
        value: Anything accepted by ``float()``.
        default: Returned when ``value`` cannot be converted or is NaN.
        min_value: Lower bound of the result.
        max_value: Upper bound of the result.

    Returns:
        The clamped float, or ``default`` for unusable input.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    # NaN is the only float unequal to itself. Comparisons against NaN are
    # always False, so min()/max() would otherwise return max_value for it.
    if number != number:
        return default
    return max(min_value, min(max_value, number))
def _entry_text(entry: Any) -> str:
    """Return the text of a pool entry by delegating to ``category_policy._entry_text``."""
    resolve_text = category_policy._entry_text
    return resolve_text(entry)
def expression_intensity_hint(entry: Any) -> float:
    """Estimate the intensity of an expression entry on a 0.0-1.0 scale.

    If ``entry`` is a dict carrying ``expression_intensity`` or ``intensity``
    (checked in that order), that value is clamped to ``[0.0, 1.0]`` and
    returned. Otherwise the entry's lower-cased text is searched for keyword
    substrings, checking tiers from strongest to weakest:

      * any high term   -> 0.9
      * any medium term -> 0.62
      * any low term    -> 0.25
      * no match        -> 0.5

    Args:
        entry: A pool entry (dict or any value ``_entry_text`` accepts).

    Returns:
        The intensity hint as a float.
    """
    if isinstance(entry, dict):
        for key in ("expression_intensity", "intensity"):
            if key in entry:
                return _clamped_float(entry[key], 0.5)
    text = _entry_text(entry).lower()
    # NOTE(review): "raw " and "wild " carry a trailing space, presumably so
    # they only match when followed by another word — confirm the intent.
    high_terms = (
        "ahegao",
        "orgasm",
        "climax",
        "drool",
        "drooling",
        "tongue out",
        "eyes rolled",
        "fucked-out",
        "cum-smeared",
        "saliva",
        "gagging",
        "slack jaw",
        "jaw slack",
        "slack-jawed",
        "sex-drunk",
        "overwhelmed",
        "strained",
        "messy",
        "panting",
        "trembling",
        "shaking",
        "wide open mouth",
        "raw ",
        "wild ",
        "dazed",
        "spent",
    )
    if any(term in text for term in high_terms):
        return 0.9
    medium_terms = (
        "seductive",
        "teasing",
        "lustful",
        "aroused",
        "bedroom",
        "dominant",
        "predatory",
        "stern",
        "strict",
        "smirk",
        "parted lips",
        "open-mouthed",
        "heated",
        "hungry",
        "inviting",
        "sensual",
        "fetish",
        "commanding",
        "flushed",
        "moan",
    )
    # "control" is matched separately with a negative lookahead: as a plain
    # substring it would also match "controlled", which belongs to the low
    # tier below and could otherwise never be reached.
    if any(term in text for term in medium_terms) or re.search(r"control(?!led)", text):
        return 0.62
    low_terms = (
        "neutral",
        "quiet",
        "calm",
        "reserved",
        "relaxed",
        "candid",
        "closed-mouth",
        "thoughtful",
        "controlled",
        "focused",
        "steady",
        "bitten-lip",
        "braced",
        "held breath",
        "concentrated",
        "aloof",
        "bored",
        "tired",
        "unfocused",
        "contented",
        "fashion",
        "soft",
        "sleepy",
        "fresh-faced",
    )
    if any(term in text for term in low_terms):
        return 0.25
    return 0.5
def expression_entries_for_intensity(entries: list[Any], expression_intensity: float) -> list[Any]:
    """Re-weight expression entries by closeness to a target intensity.

    Each entry's hint from ``expression_intensity_hint`` is compared with the
    clamped target. The absolute distance selects a multiplier (4.0 up to
    0.18, 1.4 up to 0.35, 0.35 up to 0.55, otherwise 0.05). Dict entries are
    shallow-copied with ``weight`` set to their non-negative base weight times
    the multiplier; other entries become ``{"text": ..., "weight": ...}``.

    Args:
        entries: Candidate expression entries.
        expression_intensity: Desired intensity; clamped to ``[0.0, 1.0]``
            with 0.5 as the fallback.

    Returns:
        A new list of weighted dicts, or ``entries`` itself when that list
        would be empty.
    """
    target = _clamped_float(expression_intensity, 0.5)
    # (largest distance allowed, multiplier), scanned from tightest to loosest.
    distance_bands = ((0.18, 4.0), (0.35, 1.4), (0.55, 0.35))

    reweighted: list[Any] = []
    for entry in entries:
        gap = abs(target - expression_intensity_hint(entry))
        factor = next((multiplier for limit, multiplier in distance_bands if gap <= limit), 0.05)

        if not isinstance(entry, dict):
            reweighted.append({"text": _entry_text(entry), "weight": factor})
            continue

        clone = dict(entry)
        try:
            base = float(clone.get("weight", 1.0))
        except (TypeError, ValueError):
            base = 1.0
        clone["weight"] = max(0.0, base) * factor
        reweighted.append(clone)

    return reweighted if reweighted else entries
def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of ``values``.

    Raises:
        ZeroDivisionError: If ``values`` is empty.
    """
    total = sum(values)
    count = len(values)
    return total / count
def cast_expression_intensity_override(
    fallback: float,
    label_map: dict[str, dict[str, Any]],
    women_count: int,
    men_count: int,
    expression_phase: str = "",
) -> tuple[float | None, str]:
    """Pick an expression intensity from character slots, else use ``fallback``.

    Labels ``Woman A``, ``Woman B``, ... and ``Man A``, ``Man B``, ... are
    generated from the counts and looked up in ``label_map``. Slots reported
    as POV by ``pov_policy.slot_is_pov`` are skipped. For the remaining
    labels, ``character_slot_policy.slot_expression_intensity_for_phase``
    supplies an optional intensity for ``expression_phase``.

    Args:
        fallback: Intensity returned when no slot supplies one.
        label_map: Mapping from character label to slot dict.
        women_count: Number of ``Woman`` labels to generate (negative -> 0).
        men_count: Number of ``Man`` labels to generate (negative -> 0).
        expression_phase: Phase name forwarded to the slot policy.

    Returns:
        ``(intensity, source)`` where ``source`` is one of
        ``"character_slot:<label>"``, ``"character_slots:<group>"``,
        ``"character_slots:cast"``, ``"character_slots:disabled"`` (with
        intensity ``None``) or ``"input"`` (with ``fallback``).
    """
    # Women are examined before men; each group is (name, generated labels).
    groups: list[tuple[str, list[str]]] = [
        ("women", [f"Woman {chr(ord('A') + index)}" for index in range(max(0, women_count))]),
        ("men", [f"Man {chr(ord('A') + index)}" for index in range(max(0, men_count))]),
    ]
    all_values: list[float] = []
    matching_slots: list[dict[str, Any]] = []
    for group_name, labels in groups:
        values: list[float] = []
        value_labels: list[str] = []
        for label in labels:
            slot = label_map.get(label)
            if pov_policy.slot_is_pov(slot):
                continue
            # Only labels that resolve to a truthy slot count toward the
            # "all disabled" check at the end.
            if slot:
                matching_slots.append(slot)
            # NOTE(review): slot may be None here; this relies on the slot
            # policy accepting a missing slot — confirm.
            value = character_slot_policy.slot_expression_intensity_for_phase(slot, expression_phase)
            if value is not None:
                values.append(value)
                value_labels.append(label)
                all_values.append(value)
        # The first group with any value decides the result: a single value is
        # attributed to its label, several values are averaged per group.
        if values:
            if len(values) == 1:
                return values[0], f"character_slot:{value_labels[0]}"
            return _mean(values), f"character_slots:{group_name}"
    # NOTE(review): the indentation of this block was reconstructed. With the
    # group check nested inside the loop as above, any collected value returns
    # early, so all_values looks always empty here — confirm intended nesting.
    if all_values:
        return _mean(all_values), "character_slots:cast"
    # Every resolved non-POV slot has expressions switched off.
    if matching_slots and all(not character_slot_policy.slot_expression_enabled(slot) for slot in matching_slots):
        return None, "character_slots:disabled"
    return fallback, "input"
def _weighted_choice(rng: random.Random, items: list[Any]) -> Any:
    """Pick one item at random, honouring optional ``weight`` values.

    Dict items contribute ``item["weight"]`` (default 1.0); every other item
    has weight 1.0. Negative weights count as 0 and unparsable weights count
    as 1.0. When all weights are zero the choice is uniform over ``items``.
    Items with zero weight are never returned while any item has a positive
    weight.

    Args:
        rng: Random source; ``random()`` is called once, or ``randrange()``
            once when every weight is zero.
        items: Candidates to choose from.

    Returns:
        The chosen element of ``items`` (the original object, not a copy).

    Raises:
        ValueError: If ``items`` is empty.
    """
    if not items:
        raise ValueError("Cannot choose from an empty list")
    weights: list[float] = []
    for item in items:
        weight = item.get("weight", 1.0) if isinstance(item, dict) else 1.0
        try:
            weights.append(max(0.0, float(weight)))
        except (TypeError, ValueError):
            weights.append(1.0)
    total = sum(weights)
    if total <= 0:
        return items[rng.randrange(len(items))]
    pick = rng.random() * total
    running = 0.0
    last_eligible = items[-1]
    for item, weight in zip(items, weights):
        # rng.random() may return exactly 0.0; without this guard a leading
        # zero-weight item would satisfy ``pick <= running`` and be selected.
        if weight <= 0.0:
            continue
        running += weight
        last_eligible = item
        if pick <= running:
            return item
    # Only reachable through floating-point rounding; stay on an item that
    # actually carries weight instead of blindly taking items[-1].
    return last_eligible
def _choose_text(rng: random.Random, items: list[Any]) -> str:
    """Pick one entry with ``_weighted_choice`` and return its text via ``_entry_text``."""
    chosen_entry = _weighted_choice(rng, items)
    return _entry_text(chosen_entry)
def character_expression_entries(
    rng: random.Random,
    expression_pool: list[Any],
    fallback_intensity: float,
    label_map: dict[str, dict[str, Any]],
    women_count: int,
    men_count: int,
    expression_phase: str = "",
) -> list[str]:
    """Choose one expression per eligible character and describe it.

    Labels ``Woman A``... then ``Man A``... are generated from the counts. A
    label is skipped when it has no slot in ``label_map``, when
    ``pov_policy.slot_is_pov`` reports the slot as POV, or when
    ``character_slot_policy.slot_expression_enabled`` is false for it. For
    the rest, the pool is re-weighted for the slot's intensity (or
    ``fallback_intensity`` when the slot has none), filtered with
    ``category_policy.compatible_entries``, and sampled. Up to five draws are
    made looking for a text not already used by an earlier character.

    Args:
        rng: Random source used for every draw.
        expression_pool: Candidate expression entries.
        fallback_intensity: Intensity used when a slot supplies none.
        label_map: Mapping from character label to slot dict.
        women_count: Number of ``Woman`` labels (negative -> 0).
        men_count: Number of ``Man`` labels (negative -> 0).
        expression_phase: Phase name forwarded to the slot policy.

    Returns:
        Strings of the form ``"<label> has <expression>"`` in label order.
    """

    def _labels(prefix: str, count: int) -> list[str]:
        return [f"{prefix} {chr(ord('A') + offset)}" for offset in range(max(0, count))]

    cast_labels = _labels("Woman", women_count) + _labels("Man", men_count)
    described: list[str] = []
    seen: set[str] = set()

    for cast_label in cast_labels:
        slot = label_map.get(cast_label)
        if (
            not slot
            or pov_policy.slot_is_pov(slot)
            or not character_slot_policy.slot_expression_enabled(slot)
        ):
            continue

        slot_intensity = character_slot_policy.slot_expression_intensity_for_phase(slot, expression_phase)
        effective_intensity = fallback_intensity if slot_intensity is None else slot_intensity
        candidates = category_policy.compatible_entries(
            expression_entries_for_intensity(expression_pool, effective_intensity),
            women_count,
            men_count,
        )
        if not candidates:
            continue

        picked = ""
        attempts_left = 5
        while attempts_left > 0:
            attempts_left -= 1
            drawn = _choose_text(rng, candidates)
            if drawn not in seen:
                picked = drawn
                break
        # No unused text was drawn (or the drawn text was empty): draw once
        # more and accept whatever comes back, even a repeat.
        if not picked:
            picked = _choose_text(rng, candidates)

        seen.add(picked)
        described.append(f"{cast_label} has {picked}")

    return described
def sanitize_character_expression_text_for_action(
    expression_text: str,
    role_graph: Any,
    item: Any,
    axis_values: Any = None,
) -> str:
    """Drop per-character expression clauses that conflict with the action.

    ``expression_text`` is split on ``;`` into clauses. A lower-cased context
    string is built from ``role_graph``, ``item`` and, when ``axis_values`` is
    a dict, its values. Filters only apply when the context mentions both a
    ``woman <letter>`` and a ``man <letter>``:

      * outercourse keywords -> every ``Man X has ...`` clause is removed;
      * woman-gives-oral keywords -> ``Man X has ...`` clauses that mention a
        mouth-related term are removed;
      * man-gives-oral keywords -> ``Woman X has ...`` clauses that mention a
        mouth-related term are removed.

    Args:
        expression_text: ``;``-separated clauses; falsy values yield ``""``.
        role_graph: Action/role description, converted with ``str()``.
        item: Further action description, converted with ``str()``.
        axis_values: Optional dict whose values are added to the context;
            non-dict values are ignored.

    Returns:
        The surviving clauses joined with ``"; "``.
    """
    cleaned = str(expression_text or "").strip()
    if not cleaned:
        return ""

    context_parts = [role_graph, item]
    if isinstance(axis_values, dict):
        context_parts.extend(axis_values.values())
    context = " ".join(str(part or "").lower() for part in context_parts)

    outercourse_terms = (
        "boobjob",
        "titjob",
        "breast sex",
        "breasts tightly",
        "testicle",
        "balls-licking",
        "balls licking",
        "penis-licking",
        "penis licking",
        "handjob",
        "hand job",
        "footjob",
    )
    woman_oral_terms = (
        "takes man",
        "penis in her mouth",
        "mouth at penis level",
        "fellatio",
        "blowjob",
        "deepthroat",
        "penis sucking",
        "lips wrapped",
    )
    man_oral_terms = (
        "mouth on her pussy",
        "mouth on woman",
        "mouth pressed to her pussy",
        "cunnilingus",
        "pussy licking",
        "tongue on pussy",
    )
    mouth_expression_terms = ("mouth", "oral", "tongue", "lips", "gagging", "saliva")

    def context_mentions(terms: tuple) -> bool:
        return any(term in context for term in terms)

    # All three scenarios require a labelled woman and a labelled man.
    mixed_pair = (
        re.search(r"\bwoman [a-z]\b", context) is not None
        and re.search(r"\bman [a-z]\b", context) is not None
    )
    drop_all_men = mixed_pair and context_mentions(outercourse_terms)
    drop_men_mouth = mixed_pair and context_mentions(woman_oral_terms)
    drop_women_mouth = mixed_pair and context_mentions(man_oral_terms)

    kept: list[str] = []
    for raw_clause in cleaned.split(";"):
        clause = raw_clause.strip()
        if not clause:
            continue
        about_man = re.match(r"^Man [A-Z] has\b", clause) is not None
        about_woman = re.match(r"^Woman [A-Z] has\b", clause) is not None
        lowered = clause.lower()
        mouth_related = any(term in lowered for term in mouth_expression_terms)

        if about_man and drop_all_men:
            continue
        if about_man and drop_men_mouth and mouth_related:
            continue
        if about_woman and drop_women_mouth and mouth_related:
            continue
        kept.append(clause)

    return "; ".join(kept)