Extract caption text policy

This commit is contained in:
2026-06-27 11:58:18 +02:00
parent 2605fae3eb
commit f1567118b4
5 changed files with 396 additions and 189 deletions
+36 -183
View File
@@ -1,21 +1,18 @@
from __future__ import annotations
import re
from typing import Any
try:
from . import caption_metadata_routes
from . import caption_policy
from . import caption_text_policy
from . import formatter_input as input_policy
from . import krea_cast as cast_policy
from . import route_metadata as route_metadata_policy
from .prompt_hygiene import sanitize_prose_text
except ImportError: # Allows local smoke tests with `python -c`.
import caption_metadata_routes
import caption_policy
import caption_text_policy
import formatter_input as input_policy
import krea_cast as cast_policy
import route_metadata as route_metadata_policy
from prompt_hygiene import sanitize_prose_text
@@ -23,125 +20,86 @@ OLD_TRIGGER = caption_policy.OLD_TRIGGER
DEFAULT_TRIGGER = caption_policy.DEFAULT_TRIGGER
STYLE_TAILS = caption_policy.STYLE_TAILS
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
PROMPT_FIELD_LABELS = caption_text_policy.PROMPT_FIELD_LABELS
ITEM_LABELS = caption_policy.ITEM_LABELS
ACTION_FAMILY_CAPTION_LABELS = caption_policy.ACTION_FAMILY_CAPTION_LABELS
POSITION_FAMILY_CAPTION_LABELS = caption_policy.POSITION_FAMILY_CAPTION_LABELS
def _clean_text(value: Any) -> str:
return input_policy.clean_text(value)
return caption_text_policy.clean_text(value)
def _is_false(value: Any) -> bool:
if isinstance(value, bool):
return value is False
if isinstance(value, str):
return value.strip().lower() in ("false", "0", "no", "off")
return False
return caption_text_policy.is_false(value)
def _expression_disabled(row: dict[str, Any]) -> bool:
return bool(row.get("expression_disabled")) or _is_false(row.get("expression_enabled", True))
return caption_text_policy.expression_disabled(row)
def _cap_first(text: str) -> str:
text = _clean_text(text).strip(" ,")
return text[:1].upper() + text[1:] if text else ""
return caption_text_policy.cap_first(text)
def _article(noun_phrase: str) -> str:
word = noun_phrase.lstrip().lower()
if word.startswith("hour") or word[:1] in "aeiou":
return "an"
return "a"
return caption_text_policy.article(noun_phrase)
def _sentence(text: str) -> str:
text = _clean_text(text).strip(" ,;")
if not text:
return ""
if text[-1] not in ".!?":
text += "."
return _cap_first(text)
return caption_text_policy.sentence(text)
def _join_sentences(parts: list[str]) -> str:
return " ".join(part for part in (_sentence(part) for part in parts) if part)
return caption_text_policy.join_sentences(parts)
def _formatter_hint_parts(row: dict[str, Any]) -> list[str]:
hints: list[str] = []
if not isinstance(row, dict):
return hints
for hint in route_metadata_policy.row_formatter_hints(row, "caption"):
hint = _clean_text(hint).strip(" .")
if hint and hint not in hints:
hints.append(hint)
return hints
return caption_text_policy.formatter_hint_parts(row)
def _append_formatter_hints(prose: str, row: dict[str, Any]) -> str:
hints = _formatter_hint_parts(row)
if not hints:
return prose
return _join_sentences([prose, *hints])
return caption_text_policy.append_formatter_hints(prose, row)
def _human_join(parts: list[str]) -> str:
parts = [part for part in (_clean_text(part) for part in parts) if part]
if len(parts) <= 1:
return "".join(parts)
if len(parts) == 2:
return f"{parts[0]} and {parts[1]}"
return f"{', '.join(parts[:-1])}, and {parts[-1]}"
return caption_text_policy.human_join(parts)
def _metadata_action_label(row: dict[str, Any], default: str = "sexual pose") -> str:
return caption_policy.metadata_action_label(row, default)
return caption_text_policy.metadata_action_label(row, default)
def _prompt_cast_descriptors(text: str) -> str:
return cast_policy.prompt_cast_descriptors(text)
return caption_text_policy.prompt_cast_descriptors(text)
def _cast_entries(text: str) -> list[tuple[str, str]]:
return cast_policy.cast_entries(text)
return caption_text_policy.cast_entries(text)
def _natural_cast_descriptor_text(text: str) -> str:
return cast_policy.natural_cast_descriptor_text(text)
return caption_text_policy.natural_cast_descriptor_text(text)
def _cast_labels(text: str) -> list[str]:
return cast_policy.cast_labels(text)
return caption_text_policy.cast_labels(text)
def _natural_label_text(text: Any, labels: list[str]) -> str:
return cast_policy.natural_label_text(text, labels, capitalize_sentence_starts=False)
return caption_text_policy.natural_label_text(text, labels)
def _strip_style_tail(text: str) -> str:
return caption_policy.strip_style_tail(text)
return caption_text_policy.strip_style_tail(text)
def _remove_trigger(text: str, trigger: str) -> str:
return input_policy.strip_trigger_prefix(
text,
(trigger, OLD_TRIGGER, DEFAULT_TRIGGER),
remove_exact=True,
)
return caption_text_policy.remove_trigger(text, trigger)
def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str:
text = _join_sentences([text]) if "." not in text else _clean_text(text)
trigger = _clean_text(trigger or DEFAULT_TRIGGER)
if not include_trigger or not trigger:
return text
if text.lower().startswith(trigger.lower() + "."):
return text
return f"{trigger}. {text}"
return caption_text_policy.with_trigger(text, trigger, include_trigger)
def _maybe_json(text: str) -> dict[str, Any] | None:
@@ -153,164 +111,59 @@ def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> t
def _prompt_field(text: str, label: str) -> str:
return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
return caption_text_policy.prompt_field(text, label)
def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS)
return caption_text_policy.row_value(row, key, labels)
def _field_from_any_prompt(text: str, labels: tuple[str, ...]) -> str:
for label in labels:
value = input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS)
if value:
return value
return ""
return caption_text_policy.field_from_any_prompt(text, labels)
def _normalize_composition(text: str) -> str:
return caption_policy.normalize_composition(text)
return caption_text_policy.normalize_composition(text)
def _clean_clothing(text: str) -> str:
return caption_policy.clean_clothing(text)
return caption_text_policy.clean_clothing(text)
def _body_phrase(body: Any, figure_note: Any = "") -> str:
body = _clean_text(body)
figure_note = _clean_text(figure_note)
if not body:
return figure_note
if not figure_note:
return f"{body} figure"
if "figure" in figure_note.lower():
return f"{body} build and {figure_note}"
return f"{body} figure with {figure_note}"
return caption_text_policy.body_phrase(body, figure_note)
def _single_caption_front(row: dict[str, Any]) -> dict[str, str]:
caption = _clean_text(row.get("caption"))
if not caption:
return {}
caption = _remove_trigger(_strip_style_tail(caption), _clean_text(row.get("trigger")) or DEFAULT_TRIGGER)
caption = _remove_trigger(caption, OLD_TRIGGER)
subject = _clean_text(row.get("primary_subject"))
age = _clean_text(row.get("age_band") or row.get("age"))
body_phrase = _clean_text(row.get("body_phrase"))
if not body_phrase:
body = _clean_text(row.get("body_type") or row.get("body"))
figure = _clean_text(row.get("figure"))
body_phrase = _body_phrase(body, figure)
front = f"{subject}, {age}, {body_phrase}, "
if subject in ("woman", "man") and age and body_phrase and caption.startswith(front):
try:
skin, hair, eyes, _rest = caption[len(front) :].split(", ", 3)
except ValueError:
return {}
else:
pieces = [piece.strip() for piece in caption.split(", ", 6)]
if len(pieces) < 7:
return {}
subject, age, body_phrase, skin, hair, eyes, _rest = pieces
if subject not in ("woman", "man"):
return {}
return {
"caption_subject": subject,
"caption_age": age,
"caption_body_phrase": body_phrase,
"caption_skin": skin,
"caption_hair": hair,
"caption_eyes": eyes,
}
return caption_text_policy.single_caption_front(row)
def _pose_clause(pose: str) -> str:
pose = _clean_text(pose)
if not pose:
return ""
first = pose.split(None, 1)[0].lower()
if first.endswith("ing") or first in ("seated", "reclined", "posed"):
return pose
return f"posing in {pose}"
return caption_text_policy.pose_clause(pose)
def _age_subject(age: str, subject: str) -> str:
age = _clean_text(age)
subject = _clean_text(subject) or "person"
if not age:
return f"An adult {subject}"
clean_age = re.sub(r"\s+adults?$", "", age).strip()
if "year-old" in clean_age:
return f"A {clean_age} adult {subject}"
if re.search(r"\d", clean_age):
poss = "her" if subject == "woman" else "his"
return f"An adult {subject} in {poss} {clean_age}"
return f"An adult {clean_age} {subject}"
return caption_text_policy.age_subject(age, subject)
def _clean_age_phrase(age: str) -> str:
age = _clean_text(age)
age = re.sub(r"\s+adults?$", "", age).strip()
return age.replace("-year-old", " years old")
return caption_text_policy.clean_age_phrase(age)
def _subject_phrase_from_counts(row: dict[str, Any]) -> str:
subject = _clean_text(row.get("subject_phrase"))
if subject:
return subject
try:
women = int(row.get("women_count") or 0)
men = int(row.get("men_count") or 0)
except (TypeError, ValueError):
return _clean_text(row.get("primary_subject")) or "adult scene"
parts = []
if women:
parts.append(f"{women} adult {'woman' if women == 1 else 'women'}")
if men:
parts.append(f"{men} adult {'man' if men == 1 else 'men'}")
if not parts:
return _clean_text(row.get("primary_subject")) or "adult scene"
return " and ".join(parts)
return caption_text_policy.subject_phrase_from_counts(row)
def _verb_for_row(row: dict[str, Any]) -> str:
try:
return "is" if int(row.get("person_count") or 0) == 1 else "are"
except (TypeError, ValueError):
return "are"
return caption_text_policy.verb_for_row(row)
def _detail_allows(level: str, dense_only: bool = False) -> bool:
return caption_policy.detail_allows(level, dense_only=dense_only)
return caption_text_policy.detail_allows(level, dense_only=dense_only)
def _caption_metadata_route_dependencies() -> caption_metadata_routes.CaptionMetadataRouteDependencies:
return caption_metadata_routes.CaptionMetadataRouteDependencies(
item_labels=ITEM_LABELS,
clean_text=_clean_text,
row_value=_row_value,
field_row_value=lambda row, key: _row_value(row, key),
clean_clothing=_clean_clothing,
normalize_composition=_normalize_composition,
expression_disabled=_expression_disabled,
detail_allows=_detail_allows,
join_sentences=_join_sentences,
human_join=_human_join,
article=_article,
cap_first=_cap_first,
body_phrase=_body_phrase,
single_caption_front=_single_caption_front,
pose_clause=_pose_clause,
age_subject=_age_subject,
clean_age_phrase=_clean_age_phrase,
subject_phrase_from_counts=_subject_phrase_from_counts,
verb_for_row=_verb_for_row,
metadata_action_label=_metadata_action_label,
natural_cast_descriptor_text=_natural_cast_descriptor_text,
cast_labels=_cast_labels,
natural_label_text=_natural_label_text,
metadata_to_prose=_metadata_to_prose,
)
return caption_text_policy.metadata_route_dependencies(_metadata_to_prose)
def _caption_metadata_route_request(