Normalize formatter metadata inputs
This commit is contained in:
+14
-1
@@ -4,6 +4,11 @@ import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import row_normalization as row_normalization_policy
|
||||
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
||||
import row_normalization as row_normalization_policy
|
||||
|
||||
|
||||
DEFAULT_PROMPT_FIELD_LABELS = (
|
||||
"Ages",
|
||||
@@ -73,6 +78,14 @@ def maybe_json(text: Any) -> dict[str, Any] | None:
|
||||
return value if isinstance(value, dict) else None
|
||||
|
||||
|
||||
def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
|
||||
row = dict(row)
|
||||
trigger = str(row.get("trigger") or "").strip()
|
||||
if row.get("mode") == "Insta/OF":
|
||||
return row_normalization_policy.normalize_pair_metadata(row, active_trigger=trigger)
|
||||
return row_normalization_policy.sanitize_metadata_row_text(row, active_trigger=trigger)
|
||||
|
||||
|
||||
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
|
||||
hint = clean_text(value).lower().replace("-", "_")
|
||||
hint = _INPUT_HINT_ALIASES.get(hint, hint)
|
||||
@@ -101,7 +114,7 @@ def row_from_inputs(
|
||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
||||
row = maybe_json(text)
|
||||
if row is not None:
|
||||
return row, method
|
||||
return normalize_input_metadata(row), method
|
||||
return None, "text"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user