Files
ComfyUI-Ethanfel-Prompt-Bui…/formatter_input.py

225 lines
6.5 KiB
Python

from __future__ import annotations
import json
import re
from typing import Any
try:
from . import row_normalization as row_normalization_policy
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
import row_normalization as row_normalization_policy
# Labels that may appear as "Label: value" segments inside prompt text.
# prompt_field() uses them as segment terminators and
# strip_prompt_field_labels() removes them from free text.
DEFAULT_PROMPT_FIELD_LABELS = (
    "Ages",
    "Body types",
    "Cast",
    "Cast descriptors",
    "Characters",
    "Softcore setup",
    "Hardcore setup",
    "POV participant",
    "Body exposure",
    "Scene",
    "Setting",
    "Pose",
    "Sexual pose",
    "Sexual scene",
    "Facial expression",
    "Facial expressions",
    "Clothing",
    "Clothing state",
    "Visual clothing state",
    "Outfit",
    "Erotic outfit",
    "Teaser outfit detail",
    "Softcore visual reference",
    "Visible remaining styling",
    "Prop/detail",
    "Composition",
    "Role graph",
    "Camera",
    "Camera control",
    "Use",
    "Avoid",
)
# Canonical input-hint identifiers.
INPUT_HINT_AUTO = "auto"
INPUT_HINT_METADATA = "metadata_json"
INPUT_HINT_PROMPT = "prompt"
INPUT_HINT_CAPTION_OR_PROMPT = "caption_or_prompt"

# Hint groups: metadata-style hints, free-text hints, and their union
# (metadata hints first, then text hints).
METADATA_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA)
TEXT_INPUT_HINTS = (INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
FORMATTER_INPUT_HINTS = METADATA_INPUT_HINTS + TEXT_INPUT_HINTS

# Alternative spellings mapped onto canonical hints. normalize_input_hint()
# lower-cases the value and turns "-" into "_" before looking it up here.
_INPUT_HINT_ALIASES = {
    "caption": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_prompt": INPUT_HINT_CAPTION_OR_PROMPT,
    "caption_or_text": INPUT_HINT_CAPTION_OR_PROMPT,
    "metadata": INPUT_HINT_METADATA,
    "metadata json": INPUT_HINT_METADATA,
    "source_json": INPUT_HINT_AUTO,
    "source text": INPUT_HINT_PROMPT,
    "source_text": INPUT_HINT_PROMPT,
    "text": INPUT_HINT_PROMPT,
}
def prompt_field_labels() -> tuple[str, ...]:
    """Return the module-level default tuple of prompt field labels."""
    default_labels = DEFAULT_PROMPT_FIELD_LABELS
    return default_labels
def clean_text(value: Any) -> str:
    """Convert *value* to a single-line, whitespace-normalised string.

    ``None`` becomes the empty string; anything else goes through ``str()``.
    Runs of whitespace (including newlines) collapse to one space, the result
    is trimmed, and whitespace directly before ``,``, ``.``, ``;`` or ``:``
    is removed.
    """
    raw = str(value) if value is not None else ""
    single_line = raw.replace("\n", " ")
    collapsed = re.sub(r"\s+", " ", single_line).strip()
    # Pull punctuation back against the preceding word ("a , b" -> "a, b").
    return re.sub(r"\s+([,.;:])", r"\1", collapsed)
def maybe_json(text: Any) -> dict[str, Any] | None:
    """Parse *text* as a JSON object, returning ``None`` when it is not one.

    The input is first normalised with ``clean_text``, so whitespace runs are
    collapsed over the whole payload (string values included) before parsing.
    Only text that starts with ``{`` is attempted; decode failures and
    non-dict results both yield ``None``.
    """
    candidate = clean_text(text)
    if candidate[:1] != "{":
        return None
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, dict):
        return parsed
    return None
def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Normalise a metadata row through the row-normalization policy module.

    A shallow copy of *row* is handed to the policy, so the caller's dict is
    not rebound here. The row's ``trigger`` value (stringified and stripped,
    empty when missing or falsy) is passed on as ``active_trigger``. Rows
    recognised by ``is_pair_metadata`` go to ``normalize_pair_metadata``; all
    others go to ``sanitize_metadata_row_text``.
    """
    working = dict(row)
    active = str(working.get("trigger") or "").strip()
    normalizer = (
        row_normalization_policy.normalize_pair_metadata
        if is_pair_metadata(working)
        else row_normalization_policy.sanitize_metadata_row_text
    )
    return normalizer(working, active_trigger=active)
def is_pair_metadata(row: Any) -> bool:
    """Tell whether *row* carries both a softcore and a hardcore side.

    A side counts as present when its ``*_row`` entry is a dict, or when its
    ``*_prompt`` or ``*_caption`` entry is non-empty after ``clean_text``.
    Anything that is not a dict is never pair metadata.
    """
    if not isinstance(row, dict):
        return False

    # (nested-row key, prompt key, caption key) for each side of the pair.
    sides = (
        ("softcore_row", "softcore_prompt", "softcore_caption"),
        ("hardcore_row", "hardcore_prompt", "hardcore_caption"),
    )
    for row_key, prompt_key, caption_key in sides:
        if isinstance(row.get(row_key), dict):
            continue
        if clean_text(row.get(prompt_key)):
            continue
        if clean_text(row.get(caption_key)):
            continue
        # This side has neither a nested row nor any text.
        return False
    return True
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
    """Map a user-supplied hint onto one of the canonical input hints.

    The value is cleaned, lower-cased, has ``-`` replaced by ``_`` and is then
    resolved through ``_INPUT_HINT_ALIASES``.

    Returns:
        ``INPUT_HINT_AUTO`` or ``INPUT_HINT_METADATA`` when the value resolves
        to one of those. For a text hint, returns *text_hint* if that is
        itself a valid text hint, otherwise the resolved hint.
        Unrecognised values fall back to ``INPUT_HINT_AUTO``.
    """
    key = clean_text(value).lower().replace("-", "_")
    resolved = _INPUT_HINT_ALIASES.get(key, key)
    if resolved == INPUT_HINT_AUTO or resolved == INPUT_HINT_METADATA:
        return resolved
    if resolved not in TEXT_INPUT_HINTS:
        return INPUT_HINT_AUTO
    # Any text-style hint collapses to the caller's preferred text hint.
    if text_hint in TEXT_INPUT_HINTS:
        return text_hint
    return resolved
def input_hint_choices(*, text_hint: str = INPUT_HINT_PROMPT) -> list[str]:
    """List the selectable hints: auto, metadata, and one text hint.

    *text_hint* is used as the third entry when it is a valid text hint;
    otherwise ``INPUT_HINT_PROMPT`` takes its place.
    """
    text_choice = INPUT_HINT_PROMPT
    if text_hint in TEXT_INPUT_HINTS:
        text_choice = text_hint
    return [INPUT_HINT_AUTO, INPUT_HINT_METADATA, text_choice]
def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = METADATA_INPUT_HINTS,
    text_hint: str = INPUT_HINT_PROMPT,
) -> tuple[dict[str, Any] | None, str]:
    """Try to obtain a normalised metadata row from the node inputs.

    The hint is normalised first. When it is one of *metadata_methods*,
    ``metadata_json`` and then ``source_text`` are tried as JSON objects; the
    first that parses is normalised and returned.

    Returns:
        ``(row, "metadata_json")`` or ``(row, "source_json")`` naming the
        input that produced the row, or ``(None, "text")`` when the hint is
        not a metadata method or neither input is a JSON object.
    """
    resolved_hint = normalize_input_hint(input_hint, text_hint=text_hint)
    if resolved_hint not in metadata_methods:
        return None, "text"

    # The explicit metadata input takes precedence over JSON in the source.
    candidates = ((metadata_json, "metadata_json"), (source_text, "source_json"))
    for raw, origin in candidates:
        parsed = maybe_json(raw)
        if parsed is None:
            continue
        return normalize_input_metadata(parsed), origin
    return None, "text"
def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading trigger word from *text*.

    Args:
        text: Value to process; it is passed through ``clean_text`` first.
        trigger_candidates: Triggers tried in order; empty ones are skipped.
        preserve_trigger: When true, return the cleaned text without
            removing any trigger.
        remove_exact: When true, trim surrounding spaces and commas up front
            and return ``""`` if the text is exactly a trigger.

    Returns:
        The text after the first trigger that is followed directly by ``,``
        or ``.`` (compared case-insensitively), trimmed of spaces and commas;
        otherwise the cleaned text unchanged.
    """
    body = clean_text(text)
    if remove_exact:
        body = body.strip(" ,")
    if preserve_trigger:
        return body

    lowered = body.lower()
    for candidate in trigger_candidates:
        needle = clean_text(candidate)
        if not needle:
            continue
        folded = needle.lower()
        if lowered.startswith((folded + ",", folded + ".")):
            # Skip the trigger plus its single delimiter character.
            return body[len(needle) + 1 :].strip(" ,")
        if remove_exact and lowered == folded:
            return ""
    return body
def split_avoid(text: Any) -> tuple[str, str]:
    """Split cleaned *text* at its first ``Avoid:`` marker.

    Returns:
        ``(before, after)`` with surrounding spaces and periods trimmed from
        both parts, or ``(cleaned_text, "")`` when there is no marker. The
        marker match is case-sensitive.
    """
    cleaned = clean_text(text)
    found = re.search(r"\bAvoid:\s*(.*)$", cleaned)
    if found is None:
        return cleaned, ""
    head = cleaned[: found.start()]
    tail = found.group(1)
    return head.strip(" ."), tail.strip(" .")
def strip_prompt_field_labels(
    text: Any,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Remove ``Label:`` prefixes from *text*, keeping the field values.

    Args:
        text: Value to process; it is passed through ``clean_text`` first.
        field_labels: Labels whose ``Label:`` occurrences (plus following
            whitespace) are deleted. Empty label names are ignored.

    Returns:
        The cleaned text with every listed label removed, or ``""`` for
        empty input. With no usable labels the cleaned text is returned
        unchanged.
    """
    text = clean_text(text)
    if not text:
        return ""
    # Longest first so a short label never shadows a longer one that starts
    # with it ("Clothing" vs "Clothing state").
    names = sorted((name for name in field_labels if name), key=len, reverse=True)
    if not names:
        # An empty alternation would turn the pattern into r"\b(?:):\s*",
        # which deletes every colon that follows a word character.
        return text
    labels = "|".join(re.escape(name) for name in names)
    return clean_text(re.sub(rf"\b(?:{labels}):\s*", "", text))
def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Extract the value of the ``label:`` field from labelled prompt text.

    The value runs from just after ``label:`` up to the next ``". <Label>:"``
    for any label in *field_labels*, the next ``". Use"`` / ``". Avoid"``
    sentence, or the end of the text, whichever comes first.

    Args:
        text: Value to search; it is passed through ``clean_text`` first.
        label: Field label to look for (case-sensitive, without the colon).
        field_labels: Labels that terminate the value.

    Returns:
        The cleaned value without trailing periods, or ``""`` when the text
        or label is empty or the label does not occur.
    """
    text = clean_text(text)
    if not text or not label:
        # An empty label would reduce the pattern to "first colon anywhere".
        return ""
    labels = "|".join(re.escape(name) for name in field_labels)
    # (?<!\w) keeps the label from matching inside a longer word (e.g.
    # "cast:" inside "Recast:"), mirroring the \b anchor used by
    # strip_prompt_field_labels.
    pattern = (
        rf"(?<!\w){re.escape(label)}:\s*(.*?)"
        rf"(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
    )
    match = re.search(pattern, text)
    if not match:
        return ""
    return clean_text(match.group(1)).rstrip(".")
def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Look up *key* in *row*, falling back to labelled fields in its prompt.

    Args:
        row: Metadata row to read from.
        key: Row key tried first; its cleaned value wins when non-empty.
        labels: Prompt field labels tried in order against ``row["prompt"]``
            when the direct value is empty.
        field_labels: Labels forwarded to ``prompt_field`` as terminators.

    Returns:
        The first non-empty value found, or ``""``.
    """
    direct = clean_text(row.get(key, ""))
    if direct:
        return direct

    prompt_text = clean_text(row.get("prompt", ""))
    # Lazy generator: labels after the first hit are never evaluated.
    extracted = (
        prompt_field(prompt_text, label, field_labels=field_labels)
        for label in labels
    )
    return next((found for found in extracted if found), "")