Extract formatter input parsing policy
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Field labels recognised in "Label: value" style prompt text. This tuple is
# the default ``field_labels`` for prompt_field() and row_value(), where each
# entry marks a point at which a preceding field's value stops.
DEFAULT_PROMPT_FIELD_LABELS: tuple[str, ...] = (
    "Ages",
    "Body types",
    "Cast",
    "Cast descriptors",
    "Characters",
    "Scene",
    "Setting",
    "Pose",
    "Sexual pose",
    "Sexual scene",
    "Facial expression",
    "Facial expressions",
    "Clothing",
    "Erotic outfit",
    "Prop/detail",
    "Composition",
    "Role graph",
    "Camera",
    "Camera control",
    "Use",
    "Avoid",
)
|
||||
|
||||
|
||||
def clean_text(value: Any) -> str:
    """Normalise *value* into a trimmed, single-line string.

    ``None`` becomes ``""``; any other value is converted with ``str()``.
    Every run of whitespace (newlines included) collapses to one space, the
    ends are trimmed, and whitespace sitting directly before ``,``, ``.``,
    ``;`` or ``:`` is removed.
    """
    if value is None:
        raw = ""
    else:
        raw = str(value)
    single_line = raw.replace("\n", " ")
    collapsed = re.sub(r"\s+", " ", single_line).strip()
    # Pull punctuation back against the preceding token: "a , b" -> "a, b".
    return re.sub(r"\s+([,.;:])", r"\1", collapsed)
|
||||
|
||||
|
||||
def maybe_json(text: Any) -> dict[str, Any] | None:
    """Return *text* decoded as a JSON object, or ``None`` if it is not one.

    The input is first normalised with ``clean_text``. Anything that does not
    begin with ``{``, fails to decode, or decodes to a non-dict yields ``None``.
    """
    candidate = clean_text(text)
    # Cheap pre-filter: only text that opens like a JSON object is decoded.
    if candidate[:1] != "{":
        return None
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, dict):
        return parsed
    return None
|
||||
|
||||
|
||||
def row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
    *,
    metadata_methods: tuple[str, ...] = ("auto", "metadata_json"),
) -> tuple[dict[str, Any] | None, str]:
    """Pick a metadata row out of the inputs and report where it came from.

    When *input_hint* is one of *metadata_methods*, *metadata_json* is tried
    first and *source_text* second; the first one that ``maybe_json`` decodes
    to a dict is returned with ``"metadata_json"`` or ``"source_json"``
    respectively. In every other case the result is ``(None, "text")``.
    """
    if input_hint not in metadata_methods:
        return None, "text"

    candidates = (
        (metadata_json, "metadata_json"),
        (source_text, "source_json"),
    )
    for raw, method_name in candidates:
        parsed = maybe_json(raw)
        if parsed is not None:
            return parsed, method_name
    return None, "text"
|
||||
|
||||
|
||||
def strip_trigger_prefix(
    text: Any,
    trigger_candidates: tuple[str, ...] | list[str],
    *,
    preserve_trigger: bool = False,
    remove_exact: bool = False,
) -> str:
    """Remove a leading trigger phrase from *text*.

    The text is normalised with ``clean_text``. With *remove_exact* set,
    surrounding spaces and commas are trimmed first. With *preserve_trigger*
    set, the text is returned at that point without looking at any trigger.

    Otherwise each candidate (also normalised; empty ones are skipped) is
    compared case-insensitively, in order:

    * if the text starts with the trigger followed by ``,`` or ``.``, the
      remainder is returned with surrounding spaces/commas trimmed;
    * if *remove_exact* is set and the text equals the trigger, ``""`` is
      returned.

    When no candidate applies the text is returned unchanged.
    """
    cleaned = clean_text(text)
    if remove_exact:
        cleaned = cleaned.strip(" ,")
    if preserve_trigger:
        return cleaned

    lowered = cleaned.lower()
    for candidate in trigger_candidates:
        needle = clean_text(candidate)
        if not needle:
            continue
        needle_lower = needle.lower()
        if lowered.startswith((needle_lower + ",", needle_lower + ".")):
            # Skip the trigger and the single separator character after it.
            return cleaned[len(needle) + 1 :].strip(" ,")
        if remove_exact and lowered == needle_lower:
            return ""
    return cleaned
|
||||
|
||||
|
||||
def split_avoid(text: Any) -> tuple[str, str]:
    """Split normalised *text* at its first ``Avoid:`` marker.

    Returns ``(before, avoid)`` where *before* is the text ahead of the marker
    and *avoid* is everything after it, both trimmed of surrounding spaces and
    periods. Without a marker the result is ``(cleaned_text, "")``.
    """
    cleaned = clean_text(text)
    found = re.search(r"\bAvoid:\s*(.*)$", cleaned)
    if found is None:
        return cleaned, ""
    before = cleaned[: found.start()].strip(" .")
    avoid = found.group(1).strip(" .")
    return before, avoid
|
||||
|
||||
|
||||
def prompt_field(
    text: Any,
    label: str,
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Extract the value that follows ``<label>:`` in *text*.

    The text is normalised with ``clean_text``. The value runs from just after
    the label's colon up to the nearest of: ``". "`` followed by one of
    *field_labels* and a colon, ``". Use"``, ``". Avoid"``, or the end of the
    text. The captured value is normalised again and trailing periods are
    dropped. Returns ``""`` for empty text or when the label is not found.
    """
    cleaned = clean_text(text)
    if not cleaned:
        return ""
    # Escape both the wanted label and the boundary labels so they match
    # literally (e.g. the "/" in "Prop/detail").
    labels = "|".join(map(re.escape, field_labels))
    pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)"
    found = re.search(pattern, cleaned)
    if found is None:
        return ""
    return clean_text(found.group(1)).rstrip(".")
|
||||
|
||||
|
||||
def row_value(
    row: dict[str, Any],
    key: str,
    labels: tuple[str, ...] = (),
    *,
    field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS,
) -> str:
    """Look up *key* in *row*, falling back to fields parsed from its prompt.

    The normalised ``row[key]`` is returned when it is non-empty. Otherwise
    each entry of *labels* is tried in order against ``row["prompt"]`` via
    ``prompt_field`` and the first non-empty result wins. Returns ``""`` when
    nothing is found.
    """
    direct = clean_text(row.get(key, ""))
    if direct:
        return direct

    prompt = clean_text(row.get("prompt", ""))
    # Lazy generator: later labels are not parsed once one yields a value.
    extracted = (
        prompt_field(prompt, label, field_labels=field_labels) for label in labels
    )
    return next((value for value in extracted if value), "")
|
||||
Reference in New Issue
Block a user