"""Helpers for turning prompt text and JSON metadata into normalized rows.

Covers input-hint normalization, JSON row detection, trigger-prefix removal,
and extraction of ``Label: value`` fields from prompt strings.
"""

from __future__ import annotations

import json
import re
from typing import Any

try:
    from . import row_normalization as row_normalization_policy
except ImportError:
    # Allows local smoke tests with `python tools/prompt_smoke.py`.
    import row_normalization as row_normalization_policy

# Labels treated as ``Label: value`` section headers inside prompt text.
DEFAULT_PROMPT_FIELD_LABELS = (
    "Ages",
    "Body types",
    "Cast",
    "Cast descriptors",
    "Characters",
    "Softcore setup",
    "Hardcore setup",
    "POV participant",
    "Body exposure",
    "Scene",
    "Setting",
    "Pose",
    "Sexual pose",
    "Sexual scene",
    "Facial expression",
    "Facial expressions",
    "Clothing",
    "Clothing state",
    "Visual clothing state",
    "Outfit",
    "Erotic outfit",
    "Teaser outfit detail",
    "Softcore visual reference",
    "Visible remaining styling",
    "Prop/detail",
    "Composition",
    "Role graph",
    "Camera",
    "Camera control",
    "Use",
    "Avoid",
)

# Canonical input-hint values.
INPUT_HINT_AUTO = "auto"
INPUT_HINT_METADATA = "metadata_json"
INPUT_HINT_PROMPT = "prompt"
INPUT_HINT_CAPTION_OR_PROMPT = "caption_or_prompt"

# Hints that describe free text rather than JSON metadata.
TEXT_INPUT_HINTS = (INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
FORMATTER_INPUT_HINTS = (
    INPUT_HINT_AUTO,
    INPUT_HINT_METADATA,
    INPUT_HINT_PROMPT,
    INPUT_HINT_CAPTION_OR_PROMPT,
)
# Hints under which `row_from_inputs` attempts JSON parsing by default.
METADATA_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA)

# Alternate spellings mapped onto canonical hints. Keys are compared after
# `normalize_input_hint` has lowercased the value and replaced "-" with "_".
_INPUT_HINT_ALIASES = {
    "caption": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_prompt": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_or_text": INPUT_HINT_CAPTION_OR_PROMPT,
    "metadata": INPUT_HINT_METADATA,
    "metadata json": INPUT_HINT_METADATA,
    "source_json": INPUT_HINT_AUTO,
    "source text": INPUT_HINT_PROMPT,
    "source_text": INPUT_HINT_PROMPT,
    "text": INPUT_HINT_PROMPT,
}


def prompt_field_labels() -> tuple[str, ...]:
    """Return the default tuple of prompt field labels."""
    return DEFAULT_PROMPT_FIELD_LABELS


def clean_text(value: Any) -> str:
    """Return *value* as a single-line, whitespace-normalized string.

    ``None`` becomes ``""``; any other value goes through ``str()``. Runs of
    whitespace collapse to one space, the ends are stripped, and whitespace
    directly before ``,``, ``.``, ``;`` or ``:`` is removed.
    """
    text = "" if value is None else str(value)
    text = text.replace("\n", " ")
    text = re.sub(r"\s+", " ", text).strip()
    text = re.sub(r"\s+([,.;:])", r"\1", text)
    return text


def maybe_json(text: Any) -> dict[str, Any] | None:
    """Parse *text* as a JSON object, or return ``None``.

    The text is passed through `clean_text` before parsing. ``None`` is
    returned when the cleaned text does not start with ``{``, when it is not
    valid JSON, or when the decoded value is not a dict.
    """
    text = clean_text(text)
    if not text.startswith("{"):
        return None
    try:
        value = json.loads(text)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, dict) else None


def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Normalize a metadata row through the row-normalization policy module.

    A shallow copy of *row* is handed to ``normalize_pair_metadata`` when
    `is_pair_metadata` accepts it, otherwise to
    ``sanitize_metadata_row_text``. The stripped ``"trigger"`` value (or
    ``""``) is passed along as ``active_trigger``.
    """
    row = dict(row)
    trigger = str(row.get("trigger") or "").strip()
    if is_pair_metadata(row):
        return row_normalization_policy.normalize_pair_metadata(row, active_trigger=trigger)
    return row_normalization_policy.sanitize_metadata_row_text(row, active_trigger=trigger)


def is_pair_metadata(row: Any) -> bool:
    """Return whether *row* is a dict carrying both a softcore and a hardcore side.

    A side counts as present when its ``<side>_row`` value is a dict, or when
    its ``<side>_prompt`` or ``<side>_caption`` value is non-empty after
    `clean_text`.
    """
    if not isinstance(row, dict):
        return False
    soft_side = (
        isinstance(row.get("softcore_row"), dict)
        or bool(clean_text(row.get("softcore_prompt")))
        or bool(clean_text(row.get("softcore_caption")))
    )
    hard_side = (
        isinstance(row.get("hardcore_row"), dict)
        or bool(clean_text(row.get("hardcore_prompt")))
        or bool(clean_text(row.get("hardcore_caption")))
    )
    return soft_side and hard_side


def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
    """Map a user-supplied hint onto a canonical input-hint value.

    The value is cleaned, lowercased, has ``-`` replaced by ``_``, and is then
    resolved through the alias table.

    Returns:
        ``"auto"`` or ``"metadata_json"`` when the hint resolves to one of
        those. For a text hint, *text_hint* when that is itself a valid text
        hint, otherwise the resolved hint. ``"auto"`` for anything
        unrecognized.
    """
    hint = clean_text(value).lower().replace("-", "_")
    hint = _INPUT_HINT_ALIASES.get(hint, hint)
    if hint in (INPUT_HINT_AUTO, INPUT_HINT_METADATA):
        return hint
    if hint in TEXT_INPUT_HINTS:
        return text_hint if text_hint in TEXT_INPUT_HINTS else hint
    return INPUT_HINT_AUTO


def input_hint_choices(*, text_hint: str = INPUT_HINT_PROMPT) -> list[str]:
    """Return the selectable hints: auto, metadata, and one text hint.

    An invalid *text_hint* falls back to ``"prompt"``.
    """
    text_hint = text_hint if text_hint in TEXT_INPUT_HINTS else INPUT_HINT_PROMPT
    return [INPUT_HINT_AUTO, INPUT_HINT_METADATA, text_hint]


def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = METADATA_INPUT_HINTS,
    text_hint: str = INPUT_HINT_PROMPT,
) -> tuple[dict[str, Any] | None, str]:
    """Build a normalized metadata row from whichever input holds a JSON object.

    JSON parsing is attempted only when the normalized *input_hint* is in
    *metadata_methods*. *metadata_json* is tried before *source_text*.

    Returns:
        ``(row, "metadata_json")`` or ``(row, "source_json")`` for the first
        input that parses as a JSON object, with the row passed through
        `normalize_input_metadata`. Otherwise ``(None, "text")``.
    """
    input_hint = normalize_input_hint(input_hint, text_hint=text_hint)
    if input_hint in metadata_methods:
        for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
            row = maybe_json(text)
            if row is not None:
                return normalize_input_metadata(row), method
    return None, "text"


def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading ``<trigger>,`` or ``<trigger>.`` from *text*.

    Args:
        text: Value to clean with `clean_text` and strip.
        trigger_candidates: Triggers tried in order; blank ones are skipped.
            Matching is case-insensitive.
        preserve_trigger: When true, return the cleaned text without removing
            any trigger.
        remove_exact: When true, strip surrounding spaces and commas first,
            and return ``""`` if the whole text equals a trigger.

    Returns:
        The text after the first matching trigger prefix, stripped of
        surrounding spaces and commas, or the cleaned text if nothing matched.
    """
    text = clean_text(text)
    if remove_exact:
        text = text.strip(" ,")
    if preserve_trigger:
        return text
    # The text does not change inside the loop, so lowercase it once.
    lowered = text.lower()
    for trigger in trigger_candidates:
        trigger = clean_text(trigger)
        if not trigger:
            continue
        trigger_lower = trigger.lower()
        if lowered.startswith((trigger_lower + ",", trigger_lower + ".")):
            # Skip the trigger plus the single separator character after it.
            return text[len(trigger) + 1 :].strip(" ,")
        if remove_exact and lowered == trigger_lower:
            return ""
    return text


def split_avoid(text: Any) -> tuple[str, str]:
    """Split cleaned *text* at the first ``Avoid:`` marker.

    Returns:
        ``(body, avoid)`` with surrounding spaces and periods stripped from
        both parts, or ``(text, "")`` when there is no ``Avoid:`` marker.
    """
    text = clean_text(text)
    match = re.search(r"\bAvoid:\s*(.*)$", text)
    if not match:
        return text, ""
    return text[: match.start()].strip(" ."), match.group(1).strip(" .")


def strip_prompt_field_labels(
    text: Any,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Remove ``Label:`` prefixes for the given labels from *text*.

    Longer labels are tried first so that, for example, ``Cast descriptors:``
    wins over ``Cast:``. Blank labels are ignored; when none remain the
    cleaned text is returned unchanged.
    """
    text = clean_text(text)
    if not text:
        return ""
    # An empty alternative would reduce the pattern to a bare ":" match and
    # strip every colon that follows a word character, so drop blank labels.
    usable_labels = [name for name in field_labels if name]
    if not usable_labels:
        return text
    labels = "|".join(re.escape(name) for name in sorted(usable_labels, key=len, reverse=True))
    return clean_text(re.sub(rf"\b(?:{labels}):\s*", "", text))


def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Extract the value that follows ``<label>:`` in *text*.

    The value runs until the next ``. <known label>:``, ``. Use``,
    ``. Avoid``, or the end of the text. The label match is case-sensitive.

    Returns:
        The cleaned value with trailing periods removed, or ``""`` when the
        text is empty or the label is not found.
    """
    text = clean_text(text)
    if not text:
        return ""
    labels = "|".join(re.escape(name) for name in field_labels)
    # Lazy capture plus lookahead: stop at the earliest following section
    # boundary without consuming it.
    pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
    match = re.search(pattern, text)
    if not match:
        return ""
    return clean_text(match.group(1)).rstrip(".")


def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Return ``row[key]`` cleaned, falling back to fields in ``row["prompt"]``.

    When the cleaned value for *key* is empty, each label in *labels* is
    looked up in the row's prompt with `prompt_field`, and the first non-empty
    result is returned. Returns ``""`` when nothing is found.
    """
    value = clean_text(row.get(key, ""))
    if value:
        return value
    prompt = clean_text(row.get("prompt", ""))
    for label in labels:
        value = prompt_field(prompt, label, field_labels=field_labels)
        if value:
            return value
    return ""