Extract formatter input parsing policy

This commit is contained in:
2026-06-27 01:22:07 +02:00
parent b54b8b9421
commit 4c45d96472
7 changed files with 239 additions and 159 deletions
+132
View File
@@ -0,0 +1,132 @@
from __future__ import annotations
import json
import re
from typing import Any
DEFAULT_PROMPT_FIELD_LABELS = (
"Ages",
"Body types",
"Cast",
"Cast descriptors",
"Characters",
"Scene",
"Setting",
"Pose",
"Sexual pose",
"Sexual scene",
"Facial expression",
"Facial expressions",
"Clothing",
"Erotic outfit",
"Prop/detail",
"Composition",
"Role graph",
"Camera",
"Camera control",
"Use",
"Avoid",
)
def clean_text(value: Any) -> str:
text = "" if value is None else str(value)
text = text.replace("\n", " ")
text = re.sub(r"\s+", " ", text).strip()
text = re.sub(r"\s+([,.;:])", r"\1", text)
return text
def maybe_json(text: Any) -> dict[str, Any] | None:
text = clean_text(text)
if not text.startswith("{"):
return None
try:
value = json.loads(text)
except json.JSONDecodeError:
return None
return value if isinstance(value, dict) else None
def row_from_inputs(
source_text: str,
metadata_json: str,
input_hint: str,
*,
metadata_methods: tuple[str, ...] = ("auto", "metadata_json"),
) -> tuple[dict[str, Any] | None, str]:
if input_hint in metadata_methods:
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
row = maybe_json(text)
if row is not None:
return row, method
return None, "text"
def strip_trigger_prefix(
text: Any,
trigger_candidates: tuple[str, ...] | list[str],
*,
preserve_trigger: bool = False,
remove_exact: bool = False,
) -> str:
text = clean_text(text)
if remove_exact:
text = text.strip(" ,")
if preserve_trigger:
return text
for trigger in trigger_candidates:
trigger = clean_text(trigger)
if not trigger:
continue
if text.lower().startswith(trigger.lower() + ","):
return text[len(trigger) + 1 :].strip(" ,")
if text.lower().startswith(trigger.lower() + "."):
return text[len(trigger) + 1 :].strip(" ,")
if remove_exact and text.lower() == trigger.lower():
return ""
return text
def split_avoid(text: Any) -> tuple[str, str]:
text = clean_text(text)
match = re.search(r"\bAvoid:\s*(.*)$", text)
if not match:
return text, ""
return text[: match.start()].strip(" ."), match.group(1).strip(" .")
def prompt_field(
text: Any,
label: str,
*,
field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
text = clean_text(text)
if not text:
return ""
labels = "|".join(re.escape(name) for name in field_labels)
pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
match = re.search(pattern, text)
if not match:
return ""
return clean_text(match.group(1)).rstrip(".")
def row_value(
row: dict[str, Any],
key: str,
labels: tuple[str, ...] = (),
*,
field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
value = clean_text(row.get(key, ""))
if value:
return value
prompt = clean_text(row.get("prompt", ""))
for label in labels:
value = prompt_field(prompt, label, field_labels=field_labels)
if value:
return value
return ""