199 lines
5.8 KiB
Python
199 lines
5.8 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import row_normalization as row_normalization_policy
|
|
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
|
import row_normalization as row_normalization_policy
|
|
|
|
|
|
# Labels that may introduce a "Label: value" section inside a prompt string.
# The order of this tuple is part of its value; keep it stable.
DEFAULT_PROMPT_FIELD_LABELS = (
    # Who appears in the prompt.
    "Ages", "Body types", "Cast", "Cast descriptors", "Characters",
    # Where and what.
    "Scene", "Setting",
    "Pose", "Sexual pose", "Sexual scene",
    # Singular and plural spellings are both listed.
    "Facial expression", "Facial expressions",
    "Clothing", "Erotic outfit", "Prop/detail",
    # Framing.
    "Composition", "Role graph", "Camera", "Camera control",
    # Trailing instruction sections.
    "Use", "Avoid",
)
|
|
|
|
# Canonical input-hint identifiers.
INPUT_HINT_AUTO = "auto"
INPUT_HINT_METADATA = "metadata_json"
INPUT_HINT_PROMPT = "prompt"
INPUT_HINT_CAPTION_OR_PROMPT = "caption_or_prompt"

# Hint groupings: free-text hints, metadata-style hints, and every hint
# (metadata-style first, then free-text).
TEXT_INPUT_HINTS = (INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
METADATA_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA)
FORMATTER_INPUT_HINTS = METADATA_INPUT_HINTS + TEXT_INPUT_HINTS

# Alternate spellings mapped onto canonical hints.  Keys are lowercase; some
# contain a space rather than an underscore.
_INPUT_HINT_ALIASES = {
    # caption-flavoured spellings
    "caption": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_prompt": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_or_text": INPUT_HINT_CAPTION_OR_PROMPT,
    # metadata-flavoured spellings
    "metadata": INPUT_HINT_METADATA,
    "metadata json": INPUT_HINT_METADATA,
    # "source_json" resolves to auto-detection, not to the metadata hint
    "source_json": INPUT_HINT_AUTO,
    # plain-text spellings
    "source text": INPUT_HINT_PROMPT,
    "source_text": INPUT_HINT_PROMPT,
    "text": INPUT_HINT_PROMPT,
}
|
|
|
|
|
|
def prompt_field_labels() -> tuple[str, ...]:
    """Return the module's default prompt field labels.

    The returned object is ``DEFAULT_PROMPT_FIELD_LABELS`` itself, not a copy.
    """
    default_labels = DEFAULT_PROMPT_FIELD_LABELS
    return default_labels
|
|
|
|
|
|
def clean_text(value: Any) -> str:
    """Flatten *value* into one tidy line of text.

    ``None`` yields the empty string; every other value goes through
    ``str()`` first.  Newlines and other whitespace runs collapse to a single
    space, both ends are trimmed, and whitespace sitting directly in front of
    ``,``, ``.``, ``;`` or ``:`` is removed.

    Args:
        value: Any object, or ``None``.

    Returns:
        The normalized single-line string.
    """
    if value is None:
        return ""
    single_line = re.sub(r"\s+", " ", str(value).replace("\n", " ")).strip()
    # Pull punctuation back against the preceding word ("a , b" -> "a, b").
    return re.sub(r"\s+([,.;:])", r"\1", single_line)
|
|
|
|
|
|
def maybe_json(text: Any) -> dict[str, Any] | None:
    """Parse *text* as a JSON object, or return ``None``.

    The input is first normalized with ``clean_text`` (so whitespace inside
    the JSON, including inside string values, is collapsed before parsing).
    Only text whose first character is ``{`` is handed to the JSON parser.

    Args:
        text: Candidate JSON; any value accepted by ``clean_text``.

    Returns:
        The decoded ``dict``, or ``None`` when the text does not start with
        ``{``, fails to decode, or decodes to something other than a dict.
    """
    candidate = clean_text(text)
    if candidate[:1] != "{":
        return None
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, dict):
        return parsed
    return None
|
|
|
|
|
|
def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Run a metadata row through the row-normalization policy module.

    A shallow copy of *row* is what gets handed to the policy function, with
    the row's stripped ``"trigger"`` value (empty string when missing or
    falsy) passed as ``active_trigger``.  Rows whose ``"mode"`` equals
    ``"Insta/OF"`` go to ``normalize_pair_metadata``; all others go to
    ``sanitize_metadata_row_text``.  What those helpers do is defined in
    ``row_normalization`` and is not visible here.

    Args:
        row: Metadata mapping.

    Returns:
        Whatever the selected policy function returns.
    """
    working_row = dict(row)
    active_trigger = str(working_row.get("trigger") or "").strip()
    normalizer = (
        row_normalization_policy.normalize_pair_metadata
        if working_row.get("mode") == "Insta/OF"
        else row_normalization_policy.sanitize_metadata_row_text
    )
    return normalizer(working_row, active_trigger=active_trigger)
|
|
|
|
|
|
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
    """Map a loosely spelled input hint onto a canonical hint string.

    *value* is cleaned, lowercased and has ``-`` replaced by ``_`` before
    being looked up in the alias table.

    Args:
        value: Raw hint (any value accepted by ``clean_text``).
        text_hint: Hint to report for free-text inputs.  It is used only when
            it is itself one of ``TEXT_INPUT_HINTS``.

    Returns:
        ``"auto"`` or ``"metadata_json"`` when the hint resolves to one of
        those; for a free-text hint, *text_hint* if valid, otherwise the
        resolved hint; ``"auto"`` for anything unrecognized.
    """
    key = clean_text(value).lower().replace("-", "_")
    resolved = _INPUT_HINT_ALIASES.get(key, key)

    if resolved == INPUT_HINT_AUTO or resolved == INPUT_HINT_METADATA:
        return resolved
    if resolved not in TEXT_INPUT_HINTS:
        # Unknown spelling: fall back to auto-detection.
        return INPUT_HINT_AUTO
    if text_hint in TEXT_INPUT_HINTS:
        return text_hint
    return resolved
|
|
|
|
|
|
def input_hint_choices(*, text_hint: str = INPUT_HINT_PROMPT) -> list[str]:
    """List the selectable input hints: auto, metadata, and one text hint.

    Args:
        text_hint: Which free-text hint to offer.  Values outside
            ``TEXT_INPUT_HINTS`` are replaced by ``INPUT_HINT_PROMPT``.

    Returns:
        A new three-element list ``[auto, metadata_json, <text hint>]``.
    """
    if text_hint not in TEXT_INPUT_HINTS:
        text_hint = INPUT_HINT_PROMPT
    return [INPUT_HINT_AUTO, INPUT_HINT_METADATA, text_hint]
|
|
|
|
|
|
def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = METADATA_INPUT_HINTS,
    text_hint: str = INPUT_HINT_PROMPT,
) -> tuple[dict[str, Any] | None, str]:
    """Try to obtain a normalized metadata row from the supplied inputs.

    JSON parsing is attempted only when the normalized *input_hint* is one of
    *metadata_methods*.  In that case *metadata_json* is tried first and
    *source_text* second; the first one that parses to a JSON object wins.

    Args:
        source_text: Free-form source text, which may itself be JSON.
        metadata_json: Explicit metadata JSON text.
        input_hint: Raw hint; normalized with ``normalize_input_hint``.
        metadata_methods: Normalized hints for which JSON parsing is tried.
        text_hint: Forwarded to ``normalize_input_hint``.

    Returns:
        ``(row, method)`` where *method* is ``"metadata_json"`` or
        ``"source_json"`` depending on which input parsed, or
        ``(None, "text")`` when no row was produced.
    """
    resolved_hint = normalize_input_hint(input_hint, text_hint=text_hint)
    if resolved_hint not in metadata_methods:
        return None, "text"

    candidates = (
        (metadata_json, "metadata_json"),
        (source_text, "source_json"),
    )
    for raw, origin in candidates:
        parsed = maybe_json(raw)
        if parsed is not None:
            return normalize_input_metadata(parsed), origin
    return None, "text"
|
|
|
|
|
|
def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading trigger word from *text*.

    The text is cleaned with ``clean_text``.  Candidates are tried in order;
    the first one that prefixes the text (case-insensitively) and is followed
    immediately by ``,`` or ``.`` is cut off together with that punctuation
    mark, and the remainder is stripped of surrounding spaces and commas.

    Args:
        text: Prompt text (any value accepted by ``clean_text``).
        trigger_candidates: Trigger words to look for; blank ones are skipped.
        preserve_trigger: When true, no trigger is removed and the cleaned
            text is returned (after the ``remove_exact`` trim, if enabled).
        remove_exact: When true, surrounding spaces/commas are trimmed up
            front, and text that equals a trigger (ignoring case) becomes
            ``""``.

    Returns:
        The text without its trigger prefix, or the cleaned text when nothing
        matched.
    """
    cleaned = clean_text(text)
    if remove_exact:
        cleaned = cleaned.strip(" ,")
    if preserve_trigger:
        return cleaned

    lowered = cleaned.lower()
    for candidate in trigger_candidates:
        trigger = clean_text(candidate)
        if not trigger:
            continue
        needle = trigger.lower()
        if lowered.startswith((needle + ",", needle + ".")):
            # Skip the trigger plus the single punctuation mark after it.
            return cleaned[len(trigger) + 1 :].strip(" ,")
        if remove_exact and lowered == needle:
            return ""
    return cleaned
|
|
|
|
|
|
def split_avoid(text: Any) -> tuple[str, str]:
    """Split cleaned *text* at its first ``Avoid:`` marker.

    The marker is matched case-sensitively at a word boundary.

    Args:
        text: Prompt text (any value accepted by ``clean_text``).

    Returns:
        ``(before, avoid)``: the text ahead of the marker and the text after
        it, each stripped of surrounding spaces and periods.  When there is
        no marker the result is ``(cleaned_text, "")``.
    """
    cleaned = clean_text(text)
    found = re.search(r"\bAvoid:\s*(.*)$", cleaned)
    if found is None:
        return cleaned, ""
    before = cleaned[: found.start()].strip(" .")
    avoid = found.group(1).strip(" .")
    return before, avoid
|
|
|
|
|
|
def strip_prompt_field_labels(
    text: Any,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Delete ``Label:`` markers from *text*, keeping the values.

    Each occurrence of a known label that starts at a word boundary and is
    followed by a colon is removed along with any whitespace after the colon.
    Matching is case-sensitive.  Longer labels are tried first so that, for
    example, ``Camera control:`` is removed as a whole.

    Args:
        text: Prompt text (any value accepted by ``clean_text``).
        field_labels: Labels to strip.  Empty-string entries are ignored.

    Returns:
        The cleaned text without label markers.  If *field_labels* contains
        no usable label, the cleaned text is returned unchanged.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return ""

    # An empty alternative would turn the pattern into r"\b(?:):\s*", which
    # removes every colon that follows a word character.  Drop empty labels
    # and bail out when nothing is left to match.
    names = sorted((name for name in field_labels if name), key=len, reverse=True)
    if not names:
        return cleaned

    labels = "|".join(re.escape(name) for name in names)
    return clean_text(re.sub(rf"\b(?:{labels}):\s*", "", cleaned))
|
|
|
|
|
|
def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Extract the value that follows ``label:`` in a labelled prompt.

    The first case-sensitive occurrence of ``label:`` is used.  Its value
    runs until the nearest ``". <known label>:"``, ``". Use"``, ``". Avoid"``
    or the end of the text, whichever comes first.

    Args:
        text: Prompt text (any value accepted by ``clean_text``).
        label: The field label to look up, without the colon.
        field_labels: Labels that mark the start of the next field.

    Returns:
        The cleaned field value with trailing periods removed, or ``""`` when
        the text is empty or the label does not occur.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return ""

    labels = "|".join(map(re.escape, field_labels))
    found = re.search(
        rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)",
        cleaned,
    )
    if found is None:
        return ""
    return clean_text(found.group(1)).rstrip(".")
|
|
|
|
|
|
def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Read a value from *row*, falling back to the row's labelled prompt.

    Args:
        row: Metadata mapping.
        key: Key to read directly from *row*.
        labels: Prompt labels to try, in order, when ``row[key]`` is missing
            or cleans to an empty string.
        field_labels: Forwarded to ``prompt_field``.

    Returns:
        The cleaned ``row[key]`` if non-empty; otherwise the first non-empty
        ``prompt_field`` result for *labels* taken from ``row["prompt"]``;
        otherwise ``""``.
    """
    direct = clean_text(row.get(key, ""))
    if direct:
        return direct

    prompt = clean_text(row.get("prompt", ""))
    # Lazy generator: labels after the first hit are never evaluated.
    extracted = (
        prompt_field(prompt, label, field_labels=field_labels) for label in labels
    )
    return next((found for found in extracted if found), "")
|