Share formatter cast descriptor policy
This commit is contained in:
+31
-12
@@ -3,13 +3,14 @@ from __future__ import annotations
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
||||
import formatter_input as input_policy
|
||||
|
||||
|
||||
def _clean(value: Any) -> str:
|
||||
text = "" if value is None else str(value)
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
text = re.sub(r"\s+([,.;:])", r"\1", text)
|
||||
return text
|
||||
return input_policy.clean_text(value)
|
||||
|
||||
|
||||
def _with_indefinite_article(text: str) -> str:
|
||||
@@ -35,6 +36,23 @@ def cast_entries(text: str) -> list[tuple[str, str]]:
|
||||
return entries
|
||||
|
||||
|
||||
def cast_labels(text: str) -> list[str]:
|
||||
return [label for label, _descriptor in cast_entries(text)]
|
||||
|
||||
|
||||
def natural_cast_descriptor_text(text: str) -> str:
|
||||
entries = cast_entries(text)
|
||||
if not entries:
|
||||
return _clean(text)
|
||||
labels = [label for label, _descriptor in entries]
|
||||
if labels == ["Woman A"] or labels == ["Man A"]:
|
||||
return f"A {entries[0][1]}"
|
||||
if set(labels) == {"Woman A", "Man A"} and len(labels) == 2:
|
||||
by_label = {label: descriptor for label, descriptor in entries}
|
||||
return f"A {by_label['Woman A']} alongside a {by_label['Man A']}"
|
||||
return " ".join(f"{label} is {descriptor}." for label, descriptor in entries)
|
||||
|
||||
|
||||
def label_join(labels: list[str]) -> str:
|
||||
labels = [_clean(label) for label in labels if _clean(label)]
|
||||
if not labels:
|
||||
@@ -52,7 +70,7 @@ def label_join(labels: list[str]) -> str:
|
||||
return f"{', '.join(labels[:-1])}, and {labels[-1]}"
|
||||
|
||||
|
||||
def natural_label_text(text: Any, labels: list[str]) -> str:
|
||||
def natural_label_text(text: Any, labels: list[str], *, capitalize_sentence_starts: bool = True) -> str:
|
||||
text = _clean(text)
|
||||
if not text:
|
||||
return ""
|
||||
@@ -63,12 +81,13 @@ def natural_label_text(text: Any, labels: list[str]) -> str:
|
||||
text = re.sub(r"\bWoman A\b", "the woman", text)
|
||||
elif labels == ["Man A"]:
|
||||
text = re.sub(r"\bMan A\b", "the man", text)
|
||||
text = re.sub(
|
||||
r"(^|[.!?]\s+)(the woman|the man)\b",
|
||||
lambda match: match.group(1) + match.group(2).capitalize(),
|
||||
text,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
if capitalize_sentence_starts:
|
||||
text = re.sub(
|
||||
r"(^|[.!?]\s+)(the woman|the man)\b",
|
||||
lambda match: match.group(1) + match.group(2).capitalize(),
|
||||
text,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user