From c4d5477bf958d166095c4875a6db6b484f718114 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 13:42:06 +0200 Subject: [PATCH] Centralize formatter target policy --- caption_format_route.py | 7 ++- caption_metadata_routes.py | 12 +++-- docs/prompt-architecture-improvement-plan.md | 24 ++++++--- formatter_target.py | 57 ++++++++++++++++++++ krea_format_route.py | 12 +++-- sdxl_format_route.py | 14 +++-- tools/prompt_smoke.py | 27 ++++++++++ 7 files changed, 137 insertions(+), 16 deletions(-) create mode 100644 formatter_target.py diff --git a/caption_format_route.py b/caption_format_route.py index 19f5bbf..9f557ca 100644 --- a/caption_format_route.py +++ b/caption_format_route.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable +try: + from . import formatter_target as target_policy +except ImportError: # pragma: no cover - plain-script smoke tests + import formatter_target as target_policy + @dataclass(frozen=True) class CaptionFormatRequest: @@ -49,7 +54,7 @@ def naturalize_caption_result( deps: CaptionFormatDependencies, ) -> CaptionFormatRoute: input_hint = request.input_hint if request.input_hint in ("auto", "metadata_json", "caption_or_prompt") else "auto" - target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto" + target = target_policy.normalize_target(request.target) detail_level, style_policy, include_trigger = deps.apply_caption_profile( request.caption_profile, request.detail_level, diff --git a/caption_metadata_routes.py b/caption_metadata_routes.py index 6df922e..0c7e896 100644 --- a/caption_metadata_routes.py +++ b/caption_metadata_routes.py @@ -4,6 +4,11 @@ import re from dataclasses import dataclass from typing import Any, Callable +try: + from . import formatter_target as target_policy +except ImportError: # pragma: no cover - plain-script smoke tests + import formatter_target as target_policy + @dataclass(frozen=True) class CaptionMetadataRouteRequest: @@ -301,7 +306,8 @@ def insta_of_pair_from_row_result( row = request.row detail_level = request.detail_level keep_style = request.keep_style - target = request.target if request.target in ("softcore", "hardcore") else "auto" + pair_target = target_policy.pair_policy(request.target) + target = pair_target.pair_target if deps.clean_text(row.get("mode")).lower() != "insta/of": return None soft_row = row.get("softcore_row") @@ -317,8 +323,8 @@ def insta_of_pair_from_row_result( if soft_row.get("composition"): hard_row_for_text["composition"] = soft_row["composition"] - include_soft = target in ("auto", "softcore") - include_hard = target in ("auto", "hardcore") + include_soft = pair_target.include_softcore + include_hard = pair_target.include_hardcore soft_text = "" hard_text = "" if include_soft: diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index baf13aa..bb811c8 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -82,6 +82,15 @@ routes: It must not make formatter-style decisions. Krea prose, SDXL tags, and training caption sentence shape stay in their formatter modules. +Formatter target handling now has one home: + +- `formatter_target.py` + +It owns route-neutral target normalization for `auto`, `single`, `softcore`, +and `hardcore`, including pair-side semantics. Single-output formatters select +the softcore side for pair `auto`/`single` targets, while caption pair routing +can still include both sides for combined training captions. + Shared hardcore phrase cleanup now has one home: - `hardcore_text_cleanup.py` @@ -345,7 +354,8 @@ Already isolated: - `krea_format_route.py` owns top-level Krea dispatch, including option normalization, metadata-vs-text input selection, single-vs-pair branching, - extra positive/negative merging, final prose hygiene, and output shape; + shared target normalization via `formatter_target.py`, extra + positive/negative merging, final prose hygiene, and output shape; `krea_formatter.py` keeps the public wrapper. - `krea_configured_cast_formatter.py` owns normal metadata configured-cast Krea prose assembly behind `KreaConfiguredCastRequest`, @@ -417,9 +427,10 @@ Keep here: Already isolated: - `sdxl_format_route.py` owns top-level SDXL dispatch, including formatter - profile application, target and nude-weight normalization, metadata-vs-text - input selection, single-vs-pair branching, final prompt/negative output - shape, and fallback routing; `sdxl_formatter.py` keeps the public wrapper. + profile application, shared target normalization via `formatter_target.py`, + nude-weight normalization, metadata-vs-text input selection, single-vs-pair + branching, final prompt/negative output shape, and fallback routing; + `sdxl_formatter.py` keeps the public wrapper. - `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`, `SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps @@ -455,8 +466,9 @@ Keep here: Already isolated: - `caption_format_route.py` owns top-level caption dispatch, including input - hint normalization, caption profile application, metadata-vs-text branching, - trigger wrapping, final prose hygiene, and method/output shape; + hint normalization, shared target normalization via `formatter_target.py`, + caption profile application, metadata-vs-text branching, trigger wrapping, + final prose hygiene, and method/output shape; `caption_naturalizer.py` keeps the public wrapper. - `caption_metadata_routes.py` owns metadata row natural-language assembly for single, couple, configured-cast, group/layout, and Insta/OF pair routes behind diff --git a/formatter_target.py b/formatter_target.py new file mode 100644 index 0000000..3705e98 --- /dev/null +++ b/formatter_target.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +FORMATTER_TARGETS = ("auto", "single", "softcore", "hardcore") +PAIR_SIDE_TARGETS = ("softcore", "hardcore") +DEFAULT_FORMATTER_TARGET = "auto" +DEFAULT_PAIR_SELECTED_SIDE = "softcore" + +_TARGET_ALIASES = { + "soft": "softcore", + "soft_core": "softcore", + "hard": "hardcore", + "hard_core": "hardcore", +} + + +@dataclass(frozen=True) +class PairTargetPolicy: + target: str + pair_target: str + selected_side: str + include_softcore: bool + include_hardcore: bool + + +def normalize_target(value: Any) -> str: + target = str(value or "").strip().lower().replace("-", "_").replace(" ", "_") + target = _TARGET_ALIASES.get(target, target) + return target if target in FORMATTER_TARGETS else DEFAULT_FORMATTER_TARGET + + +def pair_target(value: Any) -> str: + target = normalize_target(value) + return target if target in PAIR_SIDE_TARGETS else DEFAULT_FORMATTER_TARGET + + +def pair_selected_side(value: Any, default: str = DEFAULT_PAIR_SELECTED_SIDE) -> str: + side = pair_target(value) + if side in PAIR_SIDE_TARGETS: + return side + return default if default in PAIR_SIDE_TARGETS else DEFAULT_PAIR_SELECTED_SIDE + + +def pair_policy(value: Any, *, selected_default: str = DEFAULT_PAIR_SELECTED_SIDE) -> PairTargetPolicy: + target = normalize_target(value) + side_target = pair_target(target) + selected_side = pair_selected_side(side_target, selected_default) + return PairTargetPolicy( + target=target, + pair_target=side_target, + selected_side=selected_side, + include_softcore=side_target in ("auto", "softcore"), + include_hardcore=side_target in ("auto", "hardcore"), + ) diff --git a/krea_format_route.py b/krea_format_route.py index 0ccf206..d12a4f6 100644 --- a/krea_format_route.py +++ b/krea_format_route.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable +try: + from . import formatter_target as target_policy +except ImportError: # pragma: no cover - plain-script smoke tests + import formatter_target as target_policy + @dataclass(frozen=True) class KreaFormatRequest: @@ -45,10 +50,11 @@ class KreaFormatDependencies: def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDependencies) -> KreaFormatRoute: detail_level = request.detail_level if request.detail_level in ("concise", "balanced", "dense") else "balanced" style_mode = request.style_mode if request.style_mode in ("preserve", "photographic", "minimal") else "preserve" - target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto" + target = target_policy.normalize_target(request.target) row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint) if row and row.get("mode") == "Insta/OF": + pair_target = target_policy.pair_policy(target) soft_prompt, soft_negative, hard_prompt, hard_negative = deps.insta_pair_to_krea( row, detail_level, @@ -63,8 +69,8 @@ def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDepen hard_prompt = f"{hard_prompt.rstrip()} {request.extra_positive.strip()}" soft_prompt = deps.sanitize_prose_text(soft_prompt, triggers=deps.trigger_candidates) hard_prompt = deps.sanitize_prose_text(hard_prompt, triggers=deps.trigger_candidates) - selected = hard_prompt if target == "hardcore" else soft_prompt if target == "softcore" else soft_prompt - selected_negative = hard_negative if target == "hardcore" else soft_negative + selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt + selected_negative = hard_negative if pair_target.selected_side == "hardcore" else soft_negative negative = deps.sanitize_negative_text( deps.combine_negative(selected_negative, request.negative_prompt, request.extra_negative) ) diff --git a/sdxl_format_route.py b/sdxl_format_route.py index b06a165..6ff37c5 100644 --- a/sdxl_format_route.py +++ b/sdxl_format_route.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable +try: + from . import formatter_target as target_policy +except ImportError: # pragma: no cover - plain-script smoke tests + import formatter_target as target_policy + @dataclass(frozen=True) class SDXLFormatRequest: @@ -56,11 +61,12 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend request.style_preset, request.quality_preset, ) - target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto" + target = target_policy.normalize_target(request.target) nude_weight = max(0.1, min(3.0, float(request.nude_weight))) row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint) if row and row.get("mode") == "Insta/OF": + pair_target = target_policy.pair_policy(target) soft_row = row.get("softcore_row") if isinstance(row.get("softcore_row"), dict) else {} hard_row = row.get("hardcore_row") if isinstance(row.get("hardcore_row"), dict) else {} soft_body = deps.soft_tags(soft_row, row, nude_weight) @@ -85,9 +91,11 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend request.custom_quality, request.extra_positive, ) - selected = hard_prompt if target == "hardcore" else soft_prompt + selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt selected_negative = ( - row.get("hardcore_negative_prompt") if target == "hardcore" else row.get("softcore_negative_prompt") + row.get("hardcore_negative_prompt") + if pair_target.selected_side == "hardcore" + else row.get("softcore_negative_prompt") ) output = { "sdxl_prompt": selected, diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 722d2f8..0a62841 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -42,6 +42,7 @@ import category_cast_config # noqa: E402 import category_library # noqa: E402 import filter_config # noqa: E402 import formatter_input # noqa: E402 +import formatter_target # noqa: E402 import hardcore_position_config # noqa: E402 import __init__ as sxcp_nodes # noqa: E402 import generation_profile_config # noqa: E402 @@ -2759,6 +2760,31 @@ def smoke_formatter_input_policy() -> None: _expect("blur" in fallback_sdxl.get("negative_prompt", ""), "SDXL fallback lost Avoid negative text") +def smoke_formatter_target_policy() -> None: + _expect(formatter_target.normalize_target("single") == "single", "Formatter target lost single") + _expect(formatter_target.normalize_target("Hard-Core") == "hardcore", "Formatter target alias lost hardcore") + _expect(formatter_target.normalize_target("soft") == "softcore", "Formatter target alias lost softcore") + _expect(formatter_target.normalize_target("bad target") == "auto", "Formatter target should normalize invalid values") + + auto_pair = formatter_target.pair_policy("auto") + _expect(auto_pair.target == "auto", "Pair target policy lost normalized auto target") + _expect(auto_pair.pair_target == "auto", "Pair target policy lost auto pair target") + _expect(auto_pair.selected_side == "softcore", "Pair auto should select softcore side for single-output formatters") + _expect(auto_pair.include_softcore and auto_pair.include_hardcore, "Pair auto should include both sides for combined captions") + + single_pair = formatter_target.pair_policy("single") + _expect(single_pair.target == "single", "Pair target policy lost normalized single target") + _expect(single_pair.pair_target == "auto", "Pair single should map to auto for pair inclusion") + _expect(single_pair.selected_side == "softcore", "Pair single should select softcore side by default") + _expect(single_pair.include_softcore and single_pair.include_hardcore, "Pair single should include both sides when treated as auto") + + hard_pair = formatter_target.pair_policy("hard") + _expect(hard_pair.target == "hardcore", "Pair target policy lost hard alias") + _expect(hard_pair.pair_target == "hardcore", "Pair hard alias should become hardcore pair target") + _expect(hard_pair.selected_side == "hardcore", "Pair hardcore should select hardcore side") + _expect(not hard_pair.include_softcore and hard_pair.include_hardcore, "Pair hardcore should include only hard side") + + def smoke_krea_format_route_policy() -> None: row = _prompt_row( name="krea_format_route_single", @@ -6557,6 +6583,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("row_role_graph_policy", smoke_row_role_graph_policy), ("row_assembly_policy", smoke_row_assembly_policy), ("formatter_input_policy", smoke_formatter_input_policy), + ("formatter_target_policy", smoke_formatter_target_policy), ("krea_format_route_policy", smoke_krea_format_route_policy), ("formatter_cast_policy", smoke_formatter_cast_policy), ("caption_policy", smoke_caption_policy),