Share formatter cast descriptor policy
This commit is contained in:
+7
-31
@@ -6,10 +6,12 @@ from typing import Any
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
from .hardcore_action_metadata import normalize_hardcore_action_family
|
||||
from . import krea_cast as cast_policy
|
||||
from .prompt_hygiene import sanitize_prose_text
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
from hardcore_action_metadata import normalize_hardcore_action_family
|
||||
import krea_cast as cast_policy
|
||||
from prompt_hygiene import sanitize_prose_text
|
||||
|
||||
|
||||
@@ -132,49 +134,23 @@ def _metadata_action_label(row: dict[str, Any], default: str = "sexual pose") ->
|
||||
|
||||
|
||||
def _prompt_cast_descriptors(text: str) -> str:
|
||||
return _clean_text(text).replace("Woman A / primary creator:", "Woman A:")
|
||||
return cast_policy.prompt_cast_descriptors(text)
|
||||
|
||||
|
||||
def _cast_entries(text: str) -> list[tuple[str, str]]:
|
||||
text = _prompt_cast_descriptors(text)
|
||||
entries: list[tuple[str, str]] = []
|
||||
for part in text.split(";"):
|
||||
part = _clean_text(part)
|
||||
match = re.match(r"^((?:Woman|Man) [A-Z]):\s*(.+)$", part)
|
||||
if match:
|
||||
entries.append((match.group(1), _clean_text(match.group(2))))
|
||||
return entries
|
||||
return cast_policy.cast_entries(text)
|
||||
|
||||
|
||||
def _natural_cast_descriptor_text(text: str) -> str:
|
||||
entries = _cast_entries(text)
|
||||
if not entries:
|
||||
return _clean_text(text)
|
||||
labels = [label for label, _descriptor in entries]
|
||||
if labels == ["Woman A"] or labels == ["Man A"]:
|
||||
return f"A {entries[0][1]}"
|
||||
if set(labels) == {"Woman A", "Man A"} and len(labels) == 2:
|
||||
by_label = {label: descriptor for label, descriptor in entries}
|
||||
return f"A {by_label['Woman A']} alongside a {by_label['Man A']}"
|
||||
return " ".join(f"{label} is {descriptor}." for label, descriptor in entries)
|
||||
return cast_policy.natural_cast_descriptor_text(text)
|
||||
|
||||
|
||||
def _cast_labels(text: str) -> list[str]:
|
||||
return [label for label, _descriptor in _cast_entries(text)]
|
||||
return cast_policy.cast_labels(text)
|
||||
|
||||
|
||||
def _natural_label_text(text: Any, labels: list[str]) -> str:
|
||||
text = _clean_text(text)
|
||||
if not text:
|
||||
return ""
|
||||
if set(labels) == {"Woman A", "Man A"}:
|
||||
text = re.sub(r"\bWoman A\b", "the woman", text)
|
||||
text = re.sub(r"\bMan A\b", "the man", text)
|
||||
elif labels == ["Woman A"]:
|
||||
text = re.sub(r"\bWoman A\b", "the woman", text)
|
||||
elif labels == ["Man A"]:
|
||||
text = re.sub(r"\bMan A\b", "the man", text)
|
||||
return text
|
||||
return cast_policy.natural_label_text(text, labels, capitalize_sentence_starts=False)
|
||||
|
||||
|
||||
def _strip_style_tail(text: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user