"""Helpers for reading prompt/caption text and metadata JSON rows."""

from __future__ import annotations

import json
import re
from typing import Any

try:
    from . import row_normalization as row_normalization_policy
except ImportError:  # Allows local smoke tests with `python tools/prompt_smoke.py`.
    import row_normalization as row_normalization_policy

# Labels recognised as "Label: value" fields inside a prompt string.
DEFAULT_PROMPT_FIELD_LABELS = (
    "Ages",
    "Body types",
    "Cast",
    "Cast descriptors",
    "Characters",
    "Scene",
    "Setting",
    "Pose",
    "Sexual pose",
    "Sexual scene",
    "Facial expression",
    "Facial expressions",
    "Clothing",
    "Erotic outfit",
    "Prop/detail",
    "Composition",
    "Role graph",
    "Camera",
    "Camera control",
    "Use",
    "Avoid",
)

INPUT_HINT_AUTO = "auto"
INPUT_HINT_METADATA = "metadata_json"
INPUT_HINT_PROMPT = "prompt"
INPUT_HINT_CAPTION_OR_PROMPT = "caption_or_prompt"

TEXT_INPUT_HINTS = (INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
FORMATTER_INPUT_HINTS = (
    INPUT_HINT_AUTO,
    INPUT_HINT_METADATA,
    INPUT_HINT_PROMPT,
    INPUT_HINT_CAPTION_OR_PROMPT,
)
METADATA_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA)

# Keys are compared after lower-casing and "-" -> "_" replacement
# (see normalize_input_hint), so spaces are significant here.
_INPUT_HINT_ALIASES = {
    "caption": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_prompt": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_or_text": INPUT_HINT_CAPTION_OR_PROMPT,
    "metadata": INPUT_HINT_METADATA,
    "metadata json": INPUT_HINT_METADATA,
    "source_json": INPUT_HINT_AUTO,
    "source text": INPUT_HINT_PROMPT,
    "source_text": INPUT_HINT_PROMPT,
    "text": INPUT_HINT_PROMPT,
}


def prompt_field_labels() -> tuple[str, ...]:
    """Return the default tuple of prompt field labels."""
    return DEFAULT_PROMPT_FIELD_LABELS


def clean_text(value: Any) -> str:
    """Return *value* as a single-line string with normalised whitespace.

    ``None`` becomes ``""``; anything else is passed through ``str()``. Runs of
    whitespace (including newlines) collapse to one space, the result is
    stripped, and whitespace directly before ``,``, ``.``, ``;`` or ``:`` is
    removed.
    """
    text = "" if value is None else str(value)
    # "\s+" already covers newlines, so no separate newline replacement is needed.
    text = re.sub(r"\s+", " ", text).strip()
    return re.sub(r"\s+([,.;:])", r"\1", text)


def maybe_json(text: Any) -> dict[str, Any] | None:
    """Parse *text* as a JSON object, returning ``None`` when it is not one.

    The text goes through :func:`clean_text` before parsing, so whitespace is
    normalised across the whole document, including inside string values.
    Returns ``None`` if the cleaned text does not start with ``{``, fails to
    decode, or decodes to something other than a dict.
    """
    text = clean_text(text)
    if not text.startswith("{"):
        return None
    try:
        value = json.loads(text)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, dict) else None


def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Normalise a metadata row through the ``row_normalization`` policy module.

    Works on a shallow copy of *row*. The stripped ``"trigger"`` value is
    passed as ``active_trigger``. Rows recognised by :func:`is_pair_metadata`
    go to ``normalize_pair_metadata``; all others go to
    ``sanitize_metadata_row_text``.
    """
    row = dict(row)
    trigger = str(row.get("trigger") or "").strip()
    if is_pair_metadata(row):
        return row_normalization_policy.normalize_pair_metadata(row, active_trigger=trigger)
    return row_normalization_policy.sanitize_metadata_row_text(row, active_trigger=trigger)


def is_pair_metadata(row: Any) -> bool:
    """Return True when *row* is a dict carrying both a softcore and a hardcore side.

    A side counts as present when its ``*_row`` value is a dict, or its
    ``*_prompt`` or ``*_caption`` value is non-empty after :func:`clean_text`.
    """
    if not isinstance(row, dict):
        return False
    soft_side = (
        isinstance(row.get("softcore_row"), dict)
        or bool(clean_text(row.get("softcore_prompt")))
        or bool(clean_text(row.get("softcore_caption")))
    )
    hard_side = (
        isinstance(row.get("hardcore_row"), dict)
        or bool(clean_text(row.get("hardcore_prompt")))
        or bool(clean_text(row.get("hardcore_caption")))
    )
    return soft_side and hard_side


def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
    """Map a free-form input hint onto one of the canonical hint strings.

    *value* is cleaned, lower-cased, has ``-`` replaced by ``_``, and is then
    resolved through ``_INPUT_HINT_ALIASES``. ``auto`` and ``metadata_json``
    are returned as-is. Either text hint is returned as *text_hint* when that
    is itself a valid text hint, otherwise unchanged. Anything unrecognised
    falls back to ``auto``.
    """
    hint = clean_text(value).lower().replace("-", "_")
    hint = _INPUT_HINT_ALIASES.get(hint, hint)
    if hint in (INPUT_HINT_AUTO, INPUT_HINT_METADATA):
        return hint
    if hint in TEXT_INPUT_HINTS:
        return text_hint if text_hint in TEXT_INPUT_HINTS else hint
    return INPUT_HINT_AUTO


def input_hint_choices(*, text_hint: str = INPUT_HINT_PROMPT) -> list[str]:
    """Return the selectable hints: ``auto``, ``metadata_json`` and one text hint.

    An invalid *text_hint* is replaced by ``prompt``.
    """
    text_hint = text_hint if text_hint in TEXT_INPUT_HINTS else INPUT_HINT_PROMPT
    return [INPUT_HINT_AUTO, INPUT_HINT_METADATA, text_hint]


def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = METADATA_INPUT_HINTS,
    text_hint: str = INPUT_HINT_PROMPT,
) -> tuple[dict[str, Any] | None, str]:
    """Try to obtain a metadata row from the given inputs.

    When the normalised *input_hint* is one of *metadata_methods*,
    *metadata_json* is tried first and *source_text* second. The first that
    parses as a JSON object is normalised and returned with the method name
    ``"metadata_json"`` or ``"source_json"`` respectively.

    Returns:
        ``(row, method)`` on success, otherwise ``(None, "text")``.
    """
    input_hint = normalize_input_hint(input_hint, text_hint=text_hint)
    if input_hint in metadata_methods:
        for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
            row = maybe_json(text)
            if row is not None:
                return normalize_input_metadata(row), method
    return None, "text"


def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading ``"<trigger>,"`` or ``"<trigger>."`` from *text*.

    Matching is case-insensitive and uses the first candidate that matches.

    Args:
        text: Value to clean (via :func:`clean_text`) and strip.
        trigger_candidates: Trigger strings to look for; blank ones are skipped.
        preserve_trigger: When True, return the cleaned text without removing
            any trigger.
        remove_exact: When True, strip surrounding spaces/commas from the text
            first and return ``""`` if the text equals a trigger exactly.

    Returns:
        The text with the trigger prefix removed, or the cleaned text when no
        candidate matches.
    """
    text = clean_text(text)
    if remove_exact:
        text = text.strip(" ,")
    if preserve_trigger:
        return text
    lowered = text.lower()  # loop-invariant, so computed once
    for trigger in trigger_candidates:
        trigger = clean_text(trigger)
        if not trigger:
            continue
        needle = trigger.lower()
        if lowered.startswith((needle + ",", needle + ".")):
            # +1 also drops the separator that follows the trigger.
            return text[len(trigger) + 1 :].strip(" ,")
        if remove_exact and lowered == needle:
            return ""
    return text


def split_avoid(text: Any) -> tuple[str, str]:
    """Split cleaned *text* at the first ``Avoid:`` into ``(body, avoid)``.

    Both parts have surrounding spaces and periods stripped. When there is no
    ``Avoid:`` marker the result is ``(cleaned_text, "")``.
    """
    text = clean_text(text)
    match = re.search(r"\bAvoid:\s*(.*)$", text)
    if not match:
        return text, ""
    return text[: match.start()].strip(" ."), match.group(1).strip(" .")


def strip_prompt_field_labels(
    text: Any,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Remove every ``"<Label>:"`` marker from *text*, keeping the values.

    Labels are tried longest first so that e.g. ``Camera control:`` is removed
    whole rather than leaving ``control:`` behind.
    """
    text = clean_text(text)
    if not text:
        return ""
    # An empty alternative would make the pattern match (and strip) every colon
    # that follows a word character, so empty names are dropped, and with no
    # usable labels there is nothing to strip.
    names = [name for name in field_labels if name]
    if not names:
        return text
    labels = "|".join(re.escape(name) for name in sorted(names, key=len, reverse=True))
    return clean_text(re.sub(rf"\b(?:{labels}):\s*", "", text))


def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Return the value of the first ``"<label>: value"`` field in *text*.

    The value runs until the next ``". <Label>:"`` for any of *field_labels*,
    a ``". Use"`` / ``". Avoid"`` section, or the end of the text. Trailing
    periods are removed. Returns ``""`` when *text* or *label* is empty or the
    field is absent.
    """
    text = clean_text(text)
    if not text or not label:
        return ""
    labels = "|".join(re.escape(name) for name in field_labels)
    # (?<!\w) keeps the label from matching in the middle of a longer word.
    pattern = (
        rf"(?<!\w){re.escape(label)}:\s*(.*?)"
        rf"(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
    )
    match = re.search(pattern, text)
    if not match:
        return ""
    return clean_text(match.group(1)).rstrip(".")


def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Return ``row[key]`` cleaned, falling back to fields parsed from ``row["prompt"]``.

    When the direct value is empty, each label in *labels* is looked up in the
    row's prompt via :func:`prompt_field`, and the first non-empty result is
    returned. Returns ``""`` when nothing is found.
    """
    value = clean_text(row.get(key, ""))
    if value:
        return value
    prompt = clean_text(row.get("prompt", ""))
    for label in labels:
        value = prompt_field(prompt, label, field_labels=field_labels)
        if value:
            return value
    return ""