From 4c45d964729b1d26d0f9ae8e99ffff8ffdca2e5f Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 01:22:07 +0200 Subject: [PATCH] Extract formatter input parsing policy --- caption_naturalizer.py | 67 ++-------- docs/prompt-architecture-improvement-plan.md | 22 ++++ docs/prompt-pool-routing-map.md | 1 + formatter_input.py | 132 +++++++++++++++++++ krea_formatter.py | 63 ++------- sdxl_formatter.py | 60 ++------- tools/prompt_smoke.py | 53 ++++++++ 7 files changed, 239 insertions(+), 159 deletions(-) create mode 100644 formatter_input.py diff --git a/caption_naturalizer.py b/caption_naturalizer.py index 50bc163..3bcb1f0 100644 --- a/caption_naturalizer.py +++ b/caption_naturalizer.py @@ -1,13 +1,14 @@ from __future__ import annotations -import json import re from typing import Any try: + from . import formatter_input as input_policy from .hardcore_action_metadata import normalize_hardcore_action_family from .prompt_hygiene import sanitize_prose_text except ImportError: # Allows local smoke tests with `python -c`. + import formatter_input as input_policy from hardcore_action_metadata import normalize_hardcore_action_family from prompt_hygiene import sanitize_prose_text @@ -71,11 +72,7 @@ POSITION_FAMILY_CAPTION_LABELS = { def _clean_text(value: Any) -> str: - text = "" if value is None else str(value) - text = text.replace("\n", " ") - text = re.sub(r"\s+", " ", text).strip() - text = re.sub(r"\s+([,.;:])", r"\1", text) - return text + return input_policy.clean_text(value) def _is_false(value: Any) -> bool: @@ -189,18 +186,11 @@ def _strip_style_tail(text: str) -> str: def _remove_trigger(text: str, trigger: str) -> str: - text = _clean_text(text).strip(" ,") - for candidate in (trigger, OLD_TRIGGER, DEFAULT_TRIGGER): - candidate = candidate.strip() - if not candidate: - continue - if text.lower().startswith(candidate.lower() + ","): - return text[len(candidate) + 1 :].strip(" ,") - if text.lower().startswith(candidate.lower() + "."): - return text[len(candidate) + 1 :].strip(" ,") - if text.lower() == candidate.lower(): - return "" - return text + return input_policy.strip_trigger_prefix( + text, + (trigger, OLD_TRIGGER, DEFAULT_TRIGGER), + remove_exact=True, + ) def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str: @@ -214,55 +204,24 @@ def _with_trigger(text: str, trigger: str, include_trigger: bool) -> str: def _maybe_json(text: str) -> dict[str, Any] | None: - text = _clean_text(text) - if not text or not text.startswith("{"): - return None - try: - value = json.loads(text) - except json.JSONDecodeError: - return None - return value if isinstance(value, dict) else None + return input_policy.maybe_json(text) def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]: - candidates: list[tuple[str, str]] = [] - if input_hint in ("auto", "metadata_json"): - candidates.append((metadata_json, "metadata_json")) - candidates.append((source_text, "source_json")) - for text, method in candidates: - row = _maybe_json(text) - if row is not None: - return row, method - return None, "text" + return input_policy.row_from_inputs(source_text, metadata_json, input_hint) def _prompt_field(text: str, label: str) -> str: - text = _clean_text(text) - if not text: - return "" - labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS) - pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)" - match = re.search(pattern, text) - if not match: - return "" - return _clean_text(match.group(1)).rstrip(".") + return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str: - value = _clean_text(row.get(key, "")) - if value: - return value - prompt = _clean_text(row.get("prompt", "")) - for label in labels: - value = _prompt_field(prompt, label) - if value: - return value - return "" + return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS) def _field_from_any_prompt(text: str, labels: tuple[str, ...]) -> str: for label in labels: - value = _prompt_field(text, label) + value = input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) if value: return value return "" diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 54b6c08..dbeb977 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -62,6 +62,23 @@ route-specific owner. It also preserves ordinary words such as `composition` inside normal sentences; empty field-label cleanup is limited to standalone labels. +Formatter input/fallback parsing now has one home: + +- `formatter_input.py` + +It owns route-neutral parsing shared by Krea2, SDXL, and natural-caption +routes: + +- whitespace and punctuation normalization before formatter parsing; +- JSON row detection from `metadata_json` or source text; +- trigger-prefix stripping with route-specific trigger candidate lists; +- `Avoid:` positive/negative splitting for fallback text; +- prompt field extraction such as `Setting:` or `Composition:`; +- row-value fallback from metadata fields to labeled prompt text. + +It must not make formatter-style decisions. Krea prose, SDXL tags, and training +caption sentence shape stay in their formatter modules. + Shared hardcore phrase cleanup now has one home: - `hardcore_text_cleanup.py` @@ -242,6 +259,9 @@ Already isolated: - `krea_pov_actions.py` owns POV hardcore action sentence rewriting, first-person body geometry, and selected-position-axis priority before loose context fallback. +- `formatter_input.py` owns shared metadata/source JSON detection, trigger + stripping, prompt-field extraction, `Avoid:` splitting, and row-value + fallback for Krea, SDXL, and caption routes. Improve later: @@ -262,6 +282,7 @@ Keep here: - negative-prompt assembly. - metadata-family tag hints from `action_family`, `position_family`, and `position_keys`. +- shared formatter input parsing from `formatter_input.py`. Improve later: @@ -280,6 +301,7 @@ Keep here: - training-caption trigger behavior; - style-tail policy. - metadata-family action labels from `action_family` and `position_family`. +- shared formatter input parsing from `formatter_input.py`. Improve later: diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index cfefb47..6081102 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -94,6 +94,7 @@ Core helper ownership: | `scene_camera_adapters.py` | Location-aware camera/scene prose such as coworking lounge camera layout. | | `prompt_hygiene.py` | Generic prompt, caption, and negative-prompt cleanup. | | `row_normalization.py` | Final prompt-row and pair metadata normalization: trigger prepending, extra-positive append, negative merge/dedupe, caption-part joining, and embedded soft/hard row sanitation. | +| `formatter_input.py` | Shared formatter input parsing: text cleanup, metadata/source JSON detection, trigger-prefix stripping, `Avoid:` splitting, prompt-field extraction, and metadata row-value fallback. | ## Node IO Map diff --git a/formatter_input.py b/formatter_input.py new file mode 100644 index 0000000..f64570c --- /dev/null +++ b/formatter_input.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import json +import re +from typing import Any + + +DEFAULT_PROMPT_FIELD_LABELS = ( + "Ages", + "Body types", + "Cast", + "Cast descriptors", + "Characters", + "Scene", + "Setting", + "Pose", + "Sexual pose", + "Sexual scene", + "Facial expression", + "Facial expressions", + "Clothing", + "Erotic outfit", + "Prop/detail", + "Composition", + "Role graph", + "Camera", + "Camera control", + "Use", + "Avoid", +) + + +def clean_text(value: Any) -> str: + text = "" if value is None else str(value) + text = text.replace("\n", " ") + text = re.sub(r"\s+", " ", text).strip() + text = re.sub(r"\s+([,.;:])", r"\1", text) + return text + + +def maybe_json(text: Any) -> dict[str, Any] | None: + text = clean_text(text) + if not text.startswith("{"): + return None + try: + value = json.loads(text) + except json.JSONDecodeError: + return None + return value if isinstance(value, dict) else None + + +def row_from_inputs( + source_text: str, + metadata_json: str, + input_hint: str, + *, + metadata_methods: tuple[str, ...] = ("auto", "metadata_json"), +) -> tuple[dict[str, Any] | None, str]: + if input_hint in metadata_methods: + for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")): + row = maybe_json(text) + if row is not None: + return row, method + return None, "text" + + +def strip_trigger_prefix( + text: Any, + trigger_candidates: tuple[str, ...] | list[str], + *, + preserve_trigger: bool = False, + remove_exact: bool = False, +) -> str: + text = clean_text(text) + if remove_exact: + text = text.strip(" ,") + if preserve_trigger: + return text + for trigger in trigger_candidates: + trigger = clean_text(trigger) + if not trigger: + continue + if text.lower().startswith(trigger.lower() + ","): + return text[len(trigger) + 1 :].strip(" ,") + if text.lower().startswith(trigger.lower() + "."): + return text[len(trigger) + 1 :].strip(" ,") + if remove_exact and text.lower() == trigger.lower(): + return "" + return text + + +def split_avoid(text: Any) -> tuple[str, str]: + text = clean_text(text) + match = re.search(r"\bAvoid:\s*(.*)$", text) + if not match: + return text, "" + return text[: match.start()].strip(" ."), match.group(1).strip(" .") + + +def prompt_field( + text: Any, + label: str, + *, + field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS, +) -> str: + text = clean_text(text) + if not text: + return "" + labels = "|".join(re.escape(name) for name in field_labels) + pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)" + match = re.search(pattern, text) + if not match: + return "" + return clean_text(match.group(1)).rstrip(".") + + +def row_value( + row: dict[str, Any], + key: str, + labels: tuple[str, ...] = (), + *, + field_labels: tuple[str, ...] | list[str] = DEFAULT_PROMPT_FIELD_LABELS, +) -> str: + value = clean_text(row.get(key, "")) + if value: + return value + prompt = clean_text(row.get("prompt", "")) + for label in labels: + value = prompt_field(prompt, label, field_labels=field_labels) + if value: + return value + return "" diff --git a/krea_formatter.py b/krea_formatter.py index 78e45ed..7852fb9 100644 --- a/krea_formatter.py +++ b/krea_formatter.py @@ -1,10 +1,10 @@ from __future__ import annotations -import json import re from typing import Any try: + from . import formatter_input as input_policy from .krea_action_context import ( is_close_foreplay_text as _is_close_foreplay_text, is_outercourse_text as _is_outercourse_text, @@ -34,6 +34,7 @@ try: from .krea_pov_actions import pov_action_phrase as _pov_action_phrase from .prompt_hygiene import sanitize_negative_text, sanitize_prose_text except ImportError: # Allows local smoke tests with `python -c`. + import formatter_input as input_policy from krea_action_context import ( is_close_foreplay_text as _is_close_foreplay_text, is_outercourse_text as _is_outercourse_text, @@ -91,11 +92,7 @@ PROMPT_FIELD_LABELS = ( def _clean(value: Any) -> str: - text = "" if value is None else str(value) - text = text.replace("\n", " ") - text = re.sub(r"\s+", " ", text).strip() - text = re.sub(r"\s+([,.;:])", r"\1", text) - return text + return input_policy.clean_text(value) def _is_false(value: Any) -> bool: @@ -133,69 +130,27 @@ def _with_indefinite_article(text: str) -> str: def _maybe_json(text: str) -> dict[str, Any] | None: - text = _clean(text) - if not text.startswith("{"): - return None - try: - value = json.loads(text) - except json.JSONDecodeError: - return None - return value if isinstance(value, dict) else None + return input_policy.maybe_json(text) def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]: - candidates: list[tuple[str, str]] = [] - if input_hint in ("auto", "metadata_json"): - candidates.append((metadata_json, "metadata_json")) - candidates.append((source_text, "source_json")) - for text, method in candidates: - row = _maybe_json(text) - if row is not None: - return row, method - return None, "text" + return input_policy.row_from_inputs(source_text, metadata_json, input_hint) def _strip_trigger(text: str, preserve_trigger: bool) -> str: - text = _clean(text) - if preserve_trigger: - return text - for trigger in TRIGGER_CANDIDATES: - if text.lower().startswith(trigger.lower() + ","): - return text[len(trigger) + 1 :].strip(" ,") - if text.lower().startswith(trigger.lower() + "."): - return text[len(trigger) + 1 :].strip(" ,") - return text + return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger) def _split_avoid(text: str) -> tuple[str, str]: - match = re.search(r"\bAvoid:\s*(.*)$", text) - if not match: - return text, "" - return text[: match.start()].strip(" ."), match.group(1).strip(" .") + return input_policy.split_avoid(text) def _prompt_field(text: str, label: str) -> str: - text = _clean(text) - if not text: - return "" - labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS) - pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)" - match = re.search(pattern, text) - if not match: - return "" - return _clean(match.group(1)).rstrip(".") + return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str: - value = _clean(row.get(key, "")) - if value: - return value - prompt = _clean(row.get("prompt", "")) - for label in labels: - value = _prompt_field(prompt, label) - if value: - return value - return "" + return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS) def _body_phrase(body: Any, figure_note: Any = "") -> str: diff --git a/sdxl_formatter.py b/sdxl_formatter.py index e2c8476..99f5d40 100644 --- a/sdxl_formatter.py +++ b/sdxl_formatter.py @@ -1,13 +1,14 @@ from __future__ import annotations -import json import re from typing import Any try: + from . import formatter_input as input_policy from .hardcore_action_metadata import normalize_hardcore_action_family from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt except ImportError: # Allows local smoke tests with `python -c`. + import formatter_input as input_policy from hardcore_action_metadata import normalize_hardcore_action_family from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt @@ -95,74 +96,31 @@ def sdxl_quality_preset_choices() -> list[str]: def _clean(value: Any) -> str: - text = "" if value is None else str(value) - text = text.replace("\n", " ") - text = re.sub(r"\s+", " ", text).strip() - text = re.sub(r"\s+([,.;:])", r"\1", text) - return text + return input_policy.clean_text(value) def _maybe_json(text: str) -> dict[str, Any] | None: - text = _clean(text) - if not text.startswith("{"): - return None - try: - value = json.loads(text) - except json.JSONDecodeError: - return None - return value if isinstance(value, dict) else None + return input_policy.maybe_json(text) def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]: - if input_hint in ("auto", "metadata_json"): - for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")): - row = _maybe_json(text) - if row is not None: - return row, method - return None, "text" + return input_policy.row_from_inputs(source_text, metadata_json, input_hint) def _strip_trigger(text: str, preserve_trigger: bool) -> str: - text = _clean(text) - if preserve_trigger: - return text - for trigger in TRIGGER_CANDIDATES: - if text.lower().startswith(trigger.lower() + ","): - return text[len(trigger) + 1 :].strip(" ,") - if text.lower().startswith(trigger.lower() + "."): - return text[len(trigger) + 1 :].strip(" ,") - return text + return input_policy.strip_trigger_prefix(text, TRIGGER_CANDIDATES, preserve_trigger=preserve_trigger) def _split_avoid(text: str) -> tuple[str, str]: - match = re.search(r"\bAvoid:\s*(.*)$", text) - if not match: - return text, "" - return text[: match.start()].strip(" ."), match.group(1).strip(" .") + return input_policy.split_avoid(text) def _prompt_field(text: str, label: str) -> str: - text = _clean(text) - if not text: - return "" - labels = "|".join(re.escape(name) for name in PROMPT_FIELD_LABELS) - pattern = rf"{re.escape(label)}:\s*(.*?)(?=\. (?:{labels}):|\. Use\b|\. Avoid\b|$)" - match = re.search(pattern, text) - if not match: - return "" - return _clean(match.group(1)).rstrip(".") + return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str: - value = _clean(row.get(key, "")) - if value: - return value - prompt = _clean(row.get("prompt", "")) - for label in labels: - value = _prompt_field(prompt, label) - if value: - return value - return "" + return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS) def _split_tag_text(text: Any) -> list[str]: diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 2c8a861..e60c792 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -29,6 +29,7 @@ import character_profile # noqa: E402 import category_cast_config # noqa: E402 import category_library # noqa: E402 import filter_config # noqa: E402 +import formatter_input # noqa: E402 import hardcore_position_config # noqa: E402 import __init__ as sxcp_nodes # noqa: E402 import generation_profile_config # noqa: E402 @@ -847,6 +848,57 @@ def smoke_row_normalization_policy() -> None: _expect_no_duplicate_comma_items("row_normalization.pair.hard_row_negative", pair["hardcore_row"].get("negative_prompt")) +def smoke_formatter_input_policy() -> None: + source_row = { + "prompt": "A simple adult portrait. Setting: quiet studio. Pose: standing calmly. Avoid: low quality.", + "caption": "adult portrait, quiet studio", + "negative_prompt": "low quality", + "subject_type": "woman", + "primary_subject": "woman", + "age": "25-year-old adult", + "body_phrase": "average figure", + "skin": "warm skin", + "hair": "dark hair", + "eyes": "brown eyes", + "item": "black dress", + "scene_text": "quiet studio", + "pose": "standing calmly", + "composition": "centered portrait", + "trigger": Trigger, + } + source_json = _json(source_row) + + row, method = formatter_input.row_from_inputs(source_json, "", "auto") + _expect(method == "source_json", "Formatter input parser should read source JSON when metadata is empty") + _expect(row == source_row, "Formatter input parser changed parsed JSON row") + _expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed") + _expect( + formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio", + "Prompt field extraction changed", + ) + _expect( + formatter_input.row_value({"prompt": source_row["prompt"]}, "scene_text", ("Setting",)) == "quiet studio", + "Row value prompt fallback changed", + ) + + _expect(krea_formatter._clean("a b , c") == formatter_input.clean_text("a b , c"), "Krea clean helper is not delegated") + _expect(sdxl_formatter._clean("a b , c") == formatter_input.clean_text("a b , c"), "SDXL clean helper is not delegated") + _expect(caption_naturalizer._clean_text("a b , c") == formatter_input.clean_text("a b , c"), "Caption clean helper is not delegated") + _expect(krea_formatter._strip_trigger(f"{Trigger}, prompt text", False) == "prompt text", "Krea trigger stripping changed") + _expect(sdxl_formatter._strip_trigger(f"{SdxlTrigger}, prompt text", False) == "prompt text", "SDXL trigger stripping changed") + _expect(caption_naturalizer._remove_trigger(Trigger, Trigger) == "", "Caption exact-trigger removal changed") + + krea = krea_formatter.format_krea2_prompt(source_json, input_hint="auto") + sdxl = sdxl_formatter.format_sdxl_prompt(source_json, input_hint="auto", trigger=SdxlTrigger, prepend_trigger=True) + caption, caption_method = caption_naturalizer.naturalize_caption(source_json, input_hint="auto", trigger=Trigger) + _expect(krea.get("method", "").startswith("source_json:krea2("), "Krea formatter did not use shared source JSON parsing") + _expect(sdxl.get("method", "").startswith("source_json:sdxl("), "SDXL formatter did not use shared source JSON parsing") + _expect(caption_method.startswith("source_json:metadata("), "Caption naturalizer did not use shared source JSON parsing") + _expect_text("formatter_input.krea_prompt", krea.get("krea_prompt"), 20) + _expect_text("formatter_input.sdxl_prompt", sdxl.get("sdxl_prompt"), 20) + _expect_text("formatter_input.caption", caption, 20) + + def smoke_hardcore_position_config_policy() -> None: _expect( pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES, @@ -2818,6 +2870,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("character_config_policy", smoke_character_config_policy), ("character_profile_policy", smoke_character_profile_policy), ("row_normalization_policy", smoke_row_normalization_policy), + ("formatter_input_policy", smoke_formatter_input_policy), ("hardcore_position_config_policy", smoke_hardcore_position_config_policy), ("category_library_route", smoke_category_library_route), ("hardcore_category_routes", smoke_hardcore_category_routes),