129 lines
4.4 KiB
Python
129 lines
4.4 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import formatter_input as input_policy
|
|
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
|
import formatter_input as input_policy
|
|
|
|
|
|
def _clean(value: Any) -> str:
    """Normalize *value* to text using the shared formatter input policy.

    Single indirection over ``input_policy.clean_text`` so every helper in this
    module applies the same normalization to its inputs.
    """
    normalize = input_policy.clean_text
    return normalize(value)
|
|
|
|
|
|
def _with_indefinite_article(text: str) -> str:
    """Prefix cleaned *text* with "a" or "an" unless it is empty or already has one.

    The article is picked by a first-letter heuristic: "an" before a vowel
    letter (a, e, i, o, u), otherwise "a". Pronunciation-based exceptions
    ("an hour", "a unicorn") are not handled. Text already starting with
    "a " or "an " (any case) is returned as-is.
    """
    cleaned = _clean(text)
    already_prefixed = cleaned.lower().startswith(("a ", "an "))
    if cleaned and not already_prefixed:
        first_letter = cleaned[:1].lower()
        return ("an " if first_letter in "aeiou" else "a ") + cleaned
    return cleaned
|
|
|
|
|
|
def prompt_cast_descriptors(text: str) -> str:
    """Return cleaned cast text with the verbose primary-creator tag shortened.

    Every occurrence of ``"Woman A / primary creator:"`` becomes ``"Woman A:"``,
    so the label has the plain ``"<Woman|Man> <Letter>:"`` shape that
    ``cast_entries`` matches.
    """
    verbose_tag, short_tag = "Woman A / primary creator:", "Woman A:"
    cleaned = _clean(text)
    return cleaned.replace(verbose_tag, short_tag)
|
|
|
|
|
|
def cast_entries(text: str) -> list[tuple[str, str]]:
    """Parse semicolon-separated cast descriptors into ``(label, descriptor)`` pairs.

    The text first goes through ``prompt_cast_descriptors``. Each ``;``-separated
    segment is cleaned and matched against ``"Woman X: ..."`` / ``"Man X: ..."``
    (X a single uppercase letter). Matching segments yield one pair, in input
    order, with the descriptor cleaned. Non-matching segments are skipped silently.
    """
    label_pattern = r"^((?:Woman|Man) [A-Z]):\s*(.+)$"
    pairs: list[tuple[str, str]] = []
    for segment in prompt_cast_descriptors(text).split(";"):
        found = re.match(label_pattern, _clean(segment))
        if found is None:
            continue
        label, descriptor = found.groups()
        pairs.append((label, _clean(descriptor)))
    return pairs
|
|
|
|
|
|
def cast_labels(text: str) -> list[str]:
    """Return only the labels (e.g. ``"Woman A"``) parsed by ``cast_entries``, in order."""
    return [entry[0] for entry in cast_entries(text)]
|
|
|
|
|
|
def natural_cast_descriptor_text(text: str) -> str:
    """Render cast descriptors as a short natural-language phrase.

    - No parseable entries: the cleaned input text is returned.
    - Exactly one entry labelled ``"Woman A"`` or ``"Man A"``: its descriptor
      with an indefinite article, e.g. ``"An elderly man"``.
    - Exactly ``"Woman A"`` and ``"Man A"`` (either order):
      ``"<A/An> <woman descriptor> alongside <a/an> <man descriptor>"``.
    - Anything else: one ``"<label> is <descriptor>."`` sentence per entry.

    Articles come from ``_with_indefinite_article``, the same helper
    ``cast_prose`` uses. Vowel-initial descriptors therefore get "an", and
    descriptors that already begin with "a "/"an " are not prefixed again.
    The first letter of the resulting phrase is upper-cased.
    """
    entries = cast_entries(text)
    if not entries:
        return _clean(text)

    labels = [label for label, _descriptor in entries]

    if labels in (["Woman A"], ["Man A"]):
        phrase = _with_indefinite_article(entries[0][1])
    elif len(labels) == 2 and set(labels) == {"Woman A", "Man A"}:
        by_label = dict(entries)
        woman = _with_indefinite_article(by_label["Woman A"])
        man = _with_indefinite_article(by_label["Man A"])
        phrase = f"{woman} alongside {man}"
    else:
        return " ".join(f"{label} is {descriptor}." for label, descriptor in entries)

    # The article helper emits lowercase "a"/"an", and a descriptor's own
    # leading article may be lowercase too, so sentence-case the phrase here.
    return phrase[:1].upper() + phrase[1:]
|
|
|
|
|
|
def label_join(labels: list[str]) -> str:
    """Join cast labels into a readable English noun phrase.

    Each label is cleaned with ``_clean`` exactly once, and labels that clean
    to an empty string are dropped. Then:

    - no labels left: ``"the named adults"``
    - exactly the labels ``"Woman A"`` and ``"Man A"`` (duplicates ignored):
      ``"the woman and man"``
    - a single ``"Woman A"`` / ``"Man A"``: ``"the woman"`` / ``"the man"``
    - any other single label: that label unchanged
    - two labels: ``"X and Y"``
    - three or more: ``"X, Y, and Z"`` (serial comma)
    """
    # Normalize once, then filter out empties.
    cleaned = [name for name in (_clean(label) for label in labels) if name]

    if not cleaned:
        return "the named adults"
    if set(cleaned) == {"Woman A", "Man A"}:
        return "the woman and man"
    if len(cleaned) == 1:
        only = cleaned[0]
        return {"Woman A": "the woman", "Man A": "the man"}.get(only, only)
    if len(cleaned) == 2:
        return f"{cleaned[0]} and {cleaned[1]}"
    return f"{', '.join(cleaned[:-1])}, and {cleaned[-1]}"
|
|
|
|
|
|
def natural_label_text(text: Any, labels: list[str], *, capitalize_sentence_starts: bool = True) -> str:
    """Swap ``"Woman A"`` / ``"Man A"`` for "the woman" / "the man" when unambiguous.

    Substitution rules, based on *labels*:

    - the set of labels is exactly ``{"Woman A", "Man A"}``: both are replaced
    - *labels* equals ``["Woman A"]``: only "Woman A" is replaced
    - *labels* equals ``["Man A"]``: only "Man A" is replaced
    - otherwise: no label substitution

    When *capitalize_sentence_starts* is true, every "the woman" / "the man"
    (matched case-insensitively) at the start of the text, or after ``.``,
    ``!`` or ``?`` plus whitespace, is rewritten as "The woman" / "The man".
    This also applies to phrases already present in *text*, not only to
    substituted ones. Returns "" when the cleaned text is empty.
    """
    cleaned = _clean(text)
    if not cleaned:
        return ""

    substitutions = {
        "Woman A": (r"\bWoman A\b", "the woman"),
        "Man A": (r"\bMan A\b", "the man"),
    }
    if set(labels) == {"Woman A", "Man A"}:
        active = ["Woman A", "Man A"]
    elif labels == ["Woman A"] or labels == ["Man A"]:
        active = list(labels)
    else:
        active = []

    for label in active:
        pattern, noun = substitutions[label]
        cleaned = re.sub(pattern, noun, cleaned)

    if not capitalize_sentence_starts:
        return cleaned

    def _title_case_phrase(found: re.Match) -> str:
        # Keep the sentence boundary, capitalize only the noun phrase.
        return found.group(1) + found.group(2).capitalize()

    return re.sub(
        r"(^|[.!?]\s+)(the woman|the man)\b",
        _title_case_phrase,
        cleaned,
        flags=re.IGNORECASE,
    )
|
|
|
|
|
|
def lowercase_for_inline_join(text: str) -> str:
    """Lowercase a leading subject phrase so the text can be embedded mid-sentence.

    If the cleaned text starts with "The woman", "The man", "The viewer" or
    "The named adults" (case-insensitive, followed by a word boundary), that
    phrase is lowercased. The rest of the text is returned unchanged.
    """
    cleaned = _clean(text)
    head = re.match(
        r"^(The woman|The man|The viewer|The named adults)\b",
        cleaned,
        flags=re.IGNORECASE,
    )
    if head is None:
        return cleaned
    return head.group(1).lower() + cleaned[head.end(1):]
|
|
|
|
|
|
def cast_prose(
    text: str,
    central_label: str = "Woman A",
    omit_labels: list[str] | set[str] | tuple[str, ...] = (),
) -> tuple[str, list[str]]:
    """Turn cast descriptors into prose, plus the labels that prose covers.

    Args:
        text: Semicolon-separated cast descriptors, parsed by ``cast_entries``.
        central_label: Label announced as the central subject in the
            multi-entry fallback. Also used as the subject when *text* has no
            parseable entries.
        omit_labels: Labels to drop before rendering.

    Returns:
        ``(prose, labels)``, where ``labels`` lists the kept entry labels in order:

        - every parsed entry omitted: ``("", [])``
        - no parseable entries: ``("<central_label> is <cleaned text>", [])``,
          or ``("", [])`` when the cleaned text is empty
        - a single ``"Woman A"`` or ``"Man A"``: the descriptor with an
          indefinite article
        - exactly ``"Woman A"`` and ``"Man A"``: ``"<woman> alongside <man>"``,
          each with an indefinite article
        - otherwise: one ``"<label> is <descriptor>."`` sentence per entry,
          followed by ``"<central_label> is the central subject."`` when that
          label is among the kept entries
    """
    raw_entries = cast_entries(text)
    omitted = set(omit_labels or ())
    entries = [(label, descriptor) for label, descriptor in raw_entries if label not in omitted]

    if not entries:
        if raw_entries:
            # Everything that parsed was explicitly omitted, so render nothing.
            return "", []
        # Unstructured text: attribute it to the central subject.
        cleaned = _clean(text)
        return (f"{central_label} is {cleaned}" if cleaned else ""), []

    labels = [label for label, _descriptor in entries]

    if labels in (["Woman A"], ["Man A"]):
        return _with_indefinite_article(entries[0][1]), labels

    if len(labels) == 2 and set(labels) == {"Woman A", "Man A"}:
        by_label = dict(entries)
        woman = _with_indefinite_article(by_label["Woman A"])
        man = _with_indefinite_article(by_label["Man A"])
        return f"{woman} alongside {man}", labels

    sentences = [f"{label} is {descriptor}." for label, descriptor in entries]
    if central_label in labels:
        sentences.append(f"{central_label} is the central subject.")
    return " ".join(sentences), labels
|