Extract row expression policy

This commit is contained in:
2026-06-27 08:56:35 +02:00
parent e5822e42f8
commit 3d9dbdc95d
5 changed files with 412 additions and 222 deletions
+23 -217
View File
@@ -41,6 +41,7 @@ try:
from . import pov_policy
from . import row_normalization as row_policy
from . import row_camera as row_camera_policy
from . import row_expression as row_expression_policy
from . import row_location as row_location_policy
from . import row_pools as row_pool_policy
from . import seed_config as seed_policy
@@ -85,6 +86,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
import pov_policy
import row_normalization as row_policy
import row_camera as row_camera_policy
import row_expression as row_expression_policy
import row_location as row_location_policy
import row_pools as row_pool_policy
import seed_config as seed_policy
@@ -1135,44 +1137,15 @@ def _format(template: str, context: dict[str, Any]) -> str:
def _clean_prompt_punctuation(text: str) -> str:
text = re.sub(r"\s+", " ", str(text or "")).strip()
text = re.sub(r"\s+([,.;:])", r"\1", text)
text = re.sub(r"(?:,\s*){2,}", ", ", text)
text = re.sub(r"\.\s*\.", ".", text)
text = re.sub(r":\s*\.", ".", text)
return text.strip()
return row_expression_policy.clean_prompt_punctuation(text)
def _strip_expression_text(text: str, expression: Any = "") -> str:
text = str(text or "")
if not text:
return ""
text = re.sub(r"\s*Facial expressions?:\s*[^.]*\.\s*", " ", text, flags=re.IGNORECASE)
text = re.sub(r",\s*one with [^,]+ and the other with [^,]+(?=,)", "", text, flags=re.IGNORECASE)
text = re.sub(r",\s*a lively mix of expressions from [^,]+(?=,)", "", text, flags=re.IGNORECASE)
text = re.sub(r"\s+with\s+(?:an?|the)\s+[^,]*expression(?=,)", "", text, flags=re.IGNORECASE)
expression_text = str(expression or "").strip()
if expression_text:
for part in [piece.strip() for piece in expression_text.split(";") if piece.strip()]:
escaped = re.escape(part)
text = re.sub(rf",\s*{escaped}(?=,)", "", text, flags=re.IGNORECASE)
text = re.sub(rf"\s+with\s+(?:an?|the)?\s*{escaped}", "", text, flags=re.IGNORECASE)
return _clean_prompt_punctuation(text)
return row_expression_policy.strip_expression_text(text, expression)
def _disable_row_expression(row: dict[str, Any], source: str = "disabled") -> dict[str, Any]:
previous_expression = row.get("expression", "")
row["prompt"] = _strip_expression_text(row.get("prompt", ""), previous_expression)
row["caption"] = _strip_expression_text(row.get("caption", ""), previous_expression)
row["expression"] = ""
row["shared_expression"] = ""
row["character_expressions"] = []
row["character_expression_text"] = ""
row["expression_enabled"] = False
row["expression_disabled"] = True
row["expression_intensity"] = None
row["expression_intensity_source"] = source
return row
return row_expression_policy.disable_row_expression(row, source)
def _prepend_trigger(prompt: str, trigger: str, enabled: bool) -> str:
@@ -1837,10 +1810,6 @@ def _slot_effective_figure(
return character_slot_policy.slot_effective_figure(slot, subject_type, fallback_figure)
def _mean(values: list[float]) -> float:
return sum(values) / len(values)
def _cast_expression_intensity_override(
fallback: float,
label_map: dict[str, dict[str, Any]],
@@ -1848,35 +1817,13 @@ def _cast_expression_intensity_override(
men_count: int,
expression_phase: str = "",
) -> tuple[float | None, str]:
groups: list[tuple[str, list[str]]] = [
("women", [f"Woman {chr(ord('A') + index)}" for index in range(max(0, women_count))]),
("men", [f"Man {chr(ord('A') + index)}" for index in range(max(0, men_count))]),
]
all_values: list[float] = []
matching_slots: list[dict[str, Any]] = []
for group_name, labels in groups:
values: list[float] = []
value_labels: list[str] = []
for label in labels:
slot = label_map.get(label)
if _slot_is_pov(slot):
continue
if slot:
matching_slots.append(slot)
value = _slot_expression_intensity_for_phase(slot, expression_phase)
if value is not None:
values.append(value)
value_labels.append(label)
all_values.append(value)
if values:
if len(values) == 1:
return values[0], f"character_slot:{value_labels[0]}"
return _mean(values), f"character_slots:{group_name}"
if all_values:
return _mean(all_values), "character_slots:cast"
if matching_slots and all(not _slot_expression_enabled(slot) for slot in matching_slots):
return None, "character_slots:disabled"
return fallback, "input"
return row_expression_policy.cast_expression_intensity_override(
fallback,
label_map,
women_count,
men_count,
expression_phase,
)
def _character_expression_entries(
@@ -1888,41 +1835,15 @@ def _character_expression_entries(
men_count: int,
expression_phase: str = "",
) -> list[str]:
labels = [
*[f"Woman {chr(ord('A') + index)}" for index in range(max(0, women_count))],
*[f"Man {chr(ord('A') + index)}" for index in range(max(0, men_count))],
]
expressions: list[str] = []
used: set[str] = set()
for label in labels:
slot = label_map.get(label)
if not slot:
continue
if _slot_is_pov(slot):
continue
if not _slot_expression_enabled(slot):
continue
intensity = _slot_expression_intensity_for_phase(slot, expression_phase)
if intensity is None:
intensity = fallback_intensity
entries = _compatible_entries(
_expression_entries_for_intensity(expression_pool, intensity),
women_count,
men_count,
)
if not entries:
continue
choice = ""
for _attempt in range(5):
candidate = _choose_text(rng, entries)
if candidate not in used:
choice = candidate
break
if not choice:
choice = _choose_text(rng, entries)
used.add(choice)
expressions.append(f"{label} has {choice}")
return expressions
return row_expression_policy.character_expression_entries(
rng,
expression_pool,
fallback_intensity,
label_map,
women_count,
men_count,
expression_phase,
)
def _sanitize_character_expression_text_for_action(
@@ -2521,126 +2442,11 @@ def _expression_pool(category: dict[str, Any], subcategory: dict[str, Any], item
def _expression_intensity_hint(entry: Any) -> float:
if isinstance(entry, dict):
for key in ("expression_intensity", "intensity"):
if key in entry:
return _clamped_float(entry[key], 0.5)
text = _entry_text(entry).lower()
high_terms = (
"ahegao",
"orgasm",
"climax",
"drool",
"drooling",
"tongue out",
"eyes rolled",
"fucked-out",
"cum-smeared",
"saliva",
"gagging",
"slack jaw",
"jaw slack",
"slack-jawed",
"sex-drunk",
"overwhelmed",
"strained",
"messy",
"panting",
"trembling",
"shaking",
"wide open mouth",
"raw ",
"wild ",
"dazed",
"spent",
)
if any(term in text for term in high_terms):
return 0.9
medium_terms = (
"seductive",
"teasing",
"lustful",
"aroused",
"bedroom",
"dominant",
"predatory",
"control",
"stern",
"strict",
"smirk",
"parted lips",
"open-mouthed",
"heated",
"hungry",
"inviting",
"sensual",
"fetish",
"commanding",
"flushed",
"moan",
)
if any(term in text for term in medium_terms):
return 0.62
low_terms = (
"neutral",
"quiet",
"calm",
"reserved",
"relaxed",
"candid",
"closed-mouth",
"thoughtful",
"controlled",
"focused",
"steady",
"bitten-lip",
"braced",
"held breath",
"concentrated",
"aloof",
"bored",
"tired",
"unfocused",
"contented",
"fashion",
"soft",
"sleepy",
"fresh-faced",
)
if any(term in text for term in low_terms):
return 0.25
return 0.5
return row_expression_policy.expression_intensity_hint(entry)
def _expression_entries_for_intensity(entries: list[Any], expression_intensity: float) -> list[Any]:
target = _clamped_float(expression_intensity, 0.5)
weighted: list[Any] = []
for entry in entries:
entry_intensity = _expression_intensity_hint(entry)
distance = abs(target - entry_intensity)
if distance <= 0.18:
intensity_weight = 4.0
elif distance <= 0.35:
intensity_weight = 1.4
elif distance <= 0.55:
intensity_weight = 0.35
else:
intensity_weight = 0.05
if isinstance(entry, dict):
adjusted = dict(entry)
try:
base_weight = float(adjusted.get("weight", 1.0))
except (TypeError, ValueError):
base_weight = 1.0
adjusted["weight"] = max(0.0, base_weight) * intensity_weight
weighted.append(adjusted)
else:
weighted.append({"text": _entry_text(entry), "weight": intensity_weight})
return weighted or entries
return row_expression_policy.expression_entries_for_intensity(entries, expression_intensity)
def _pose_pool(category: dict[str, Any], subcategory: dict[str, Any], item: Any, subject_type: str, poses: str) -> list[Any]: