Normalize formatter metadata inputs
This commit is contained in:
+14
-1
@@ -4,6 +4,11 @@ import json
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import row_normalization as row_normalization_policy
|
||||||
|
except ImportError: # Allows local smoke tests with `python tools/prompt_smoke.py`.
|
||||||
|
import row_normalization as row_normalization_policy
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PROMPT_FIELD_LABELS = (
|
DEFAULT_PROMPT_FIELD_LABELS = (
|
||||||
"Ages",
|
"Ages",
|
||||||
@@ -73,6 +78,14 @@ def maybe_json(text: Any) -> dict[str, Any] | None:
|
|||||||
return value if isinstance(value, dict) else None
|
return value if isinstance(value, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_input_metadata(row: dict[str, Any]) -> dict[str, Any]:
    """Run a parsed metadata row through the row-normalization policy.

    A shallow copy of ``row`` is made first and that copy is what gets
    handed to the policy helper. The active trigger is the row's
    ``"trigger"`` value coerced to ``str`` and stripped; a missing or
    falsy value becomes the empty string.

    Rows whose ``"mode"`` equals ``"Insta/OF"`` are routed to
    ``normalize_pair_metadata``; every other row goes to
    ``sanitize_metadata_row_text``. Whatever the selected helper returns
    is returned unchanged.
    """
    metadata = dict(row)
    active_trigger = str(metadata.get("trigger") or "").strip()
    # Only the selected attribute is looked up, mirroring the original
    # branch-then-call flow.
    normalizer = (
        row_normalization_policy.normalize_pair_metadata
        if metadata.get("mode") == "Insta/OF"
        else row_normalization_policy.sanitize_metadata_row_text
    )
    return normalizer(metadata, active_trigger=active_trigger)
|
||||||
|
|
||||||
|
|
||||||
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
|
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
|
||||||
hint = clean_text(value).lower().replace("-", "_")
|
hint = clean_text(value).lower().replace("-", "_")
|
||||||
hint = _INPUT_HINT_ALIASES.get(hint, hint)
|
hint = _INPUT_HINT_ALIASES.get(hint, hint)
|
||||||
@@ -101,7 +114,7 @@ def row_from_inputs(
|
|||||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
||||||
row = maybe_json(text)
|
row = maybe_json(text)
|
||||||
if row is not None:
|
if row is not None:
|
||||||
return row, method
|
return normalize_input_metadata(row), method
|
||||||
return None, "text"
|
return None, "text"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2811,6 +2811,43 @@ def smoke_formatter_input_policy() -> None:
|
|||||||
_expect(method == "source_json" and row == source_row, "Formatter input parser should treat invalid hints as auto")
|
_expect(method == "source_json" and row == source_row, "Formatter input parser should treat invalid hints as auto")
|
||||||
row, method = formatter_input.row_from_inputs(source_json, "", "prompt")
|
row, method = formatter_input.row_from_inputs(source_json, "", "prompt")
|
||||||
_expect(row is None and method == "text", "Formatter input parser should not parse source JSON in explicit prompt mode")
|
_expect(row is None and method == "text", "Formatter input parser should not parse source JSON in explicit prompt mode")
|
||||||
|
pair_metadata = {
|
||||||
|
"mode": "Insta/OF",
|
||||||
|
"trigger": Trigger,
|
||||||
|
"softcore_row": {
|
||||||
|
"prompt": f"{Trigger}, {Trigger}, embedded-only soft.",
|
||||||
|
"caption": f"{Trigger}, {Trigger}, embedded-only soft caption.",
|
||||||
|
"negative_prompt": "bad anatomy, bad anatomy",
|
||||||
|
"softcore_partner_styling": {"outfits": ["row partner outfit"], "pose": "row partner pose"},
|
||||||
|
"camera_config": {"camera_mode": "standard"},
|
||||||
|
"camera_directive": "Camera: row soft front view.",
|
||||||
|
"camera_scene_directive": "Row soft scene camera layout.",
|
||||||
|
},
|
||||||
|
"hardcore_row": {
|
||||||
|
"prompt": f"{Trigger}, {Trigger}, embedded-only hard.",
|
||||||
|
"caption": f"{Trigger}, {Trigger}, embedded-only hard caption.",
|
||||||
|
"negative_prompt": "low quality, low quality",
|
||||||
|
"hardcore_clothing_state": "row hard clothing state",
|
||||||
|
"camera_config": {"camera_mode": "pov"},
|
||||||
|
"camera_scene_directive": "Row hard scene camera layout.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed_pair, pair_method = formatter_input.row_from_inputs("", _json(pair_metadata), "metadata_json")
|
||||||
|
_expect(pair_method == "metadata_json", "Formatter input parser should read pair metadata JSON")
|
||||||
|
_expect_trigger_once("formatter_input.pair.soft_prompt", parsed_pair.get("softcore_prompt"), Trigger)
|
||||||
|
_expect(
|
||||||
|
parsed_pair.get("softcore_partner_styling") == parsed_pair["softcore_row"].get("softcore_partner_styling"),
|
||||||
|
"Formatter input parser did not normalize pair soft side metadata",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
parsed_pair.get("hardcore_clothing_state") == parsed_pair["hardcore_row"].get("hardcore_clothing_state"),
|
||||||
|
"Formatter input parser did not normalize pair hard side metadata",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
parsed_pair.get("softcore_camera_config") == parsed_pair["softcore_row"].get("camera_config"),
|
||||||
|
"Formatter input parser did not normalize pair camera metadata",
|
||||||
|
)
|
||||||
|
_expect_no_duplicate_comma_items("formatter_input.pair.hard_negative", parsed_pair.get("hardcore_negative_prompt"))
|
||||||
_expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed")
|
_expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed")
|
||||||
_expect(
|
_expect(
|
||||||
formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio",
|
formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio",
|
||||||
|
|||||||
Reference in New Issue
Block a user