Centralize formatter target policy
This commit is contained in:
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import formatter_target as target_policy
|
||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||
import formatter_target as target_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaptionFormatRequest:
|
||||
@@ -49,7 +54,7 @@ def naturalize_caption_result(
|
||||
deps: CaptionFormatDependencies,
|
||||
) -> CaptionFormatRoute:
|
||||
input_hint = request.input_hint if request.input_hint in ("auto", "metadata_json", "caption_or_prompt") else "auto"
|
||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
||||
target = target_policy.normalize_target(request.target)
|
||||
detail_level, style_policy, include_trigger = deps.apply_caption_profile(
|
||||
request.caption_profile,
|
||||
request.detail_level,
|
||||
|
||||
@@ -4,6 +4,11 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import formatter_target as target_policy
|
||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||
import formatter_target as target_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaptionMetadataRouteRequest:
|
||||
@@ -301,7 +306,8 @@ def insta_of_pair_from_row_result(
|
||||
row = request.row
|
||||
detail_level = request.detail_level
|
||||
keep_style = request.keep_style
|
||||
target = request.target if request.target in ("softcore", "hardcore") else "auto"
|
||||
pair_target = target_policy.pair_policy(request.target)
|
||||
target = pair_target.pair_target
|
||||
if deps.clean_text(row.get("mode")).lower() != "insta/of":
|
||||
return None
|
||||
soft_row = row.get("softcore_row")
|
||||
@@ -317,8 +323,8 @@ def insta_of_pair_from_row_result(
|
||||
if soft_row.get("composition"):
|
||||
hard_row_for_text["composition"] = soft_row["composition"]
|
||||
|
||||
include_soft = target in ("auto", "softcore")
|
||||
include_hard = target in ("auto", "hardcore")
|
||||
include_soft = pair_target.include_softcore
|
||||
include_hard = pair_target.include_hardcore
|
||||
soft_text = ""
|
||||
hard_text = ""
|
||||
if include_soft:
|
||||
|
||||
@@ -82,6 +82,15 @@ routes:
|
||||
It must not make formatter-style decisions. Krea prose, SDXL tags, and training
|
||||
caption sentence shape stay in their formatter modules.
|
||||
|
||||
Formatter target handling now has one home:
|
||||
|
||||
- `formatter_target.py`
|
||||
|
||||
It owns route-neutral target normalization for `auto`, `single`, `softcore`,
|
||||
and `hardcore`, including pair-side semantics. Single-output formatters select
|
||||
the softcore side for pair `auto`/`single` targets, while caption pair routing
|
||||
can still include both sides for combined training captions.
|
||||
|
||||
Shared hardcore phrase cleanup now has one home:
|
||||
|
||||
- `hardcore_text_cleanup.py`
|
||||
@@ -345,7 +354,8 @@ Already isolated:
|
||||
|
||||
- `krea_format_route.py` owns top-level Krea dispatch, including option
|
||||
normalization, metadata-vs-text input selection, single-vs-pair branching,
|
||||
extra positive/negative merging, final prose hygiene, and output shape;
|
||||
shared target normalization via `formatter_target.py`, extra
|
||||
positive/negative merging, final prose hygiene, and output shape;
|
||||
`krea_formatter.py` keeps the public wrapper.
|
||||
- `krea_configured_cast_formatter.py` owns normal metadata configured-cast
|
||||
Krea prose assembly behind `KreaConfiguredCastRequest`,
|
||||
@@ -417,9 +427,10 @@ Keep here:
|
||||
Already isolated:
|
||||
|
||||
- `sdxl_format_route.py` owns top-level SDXL dispatch, including formatter
|
||||
profile application, target and nude-weight normalization, metadata-vs-text
|
||||
input selection, single-vs-pair branching, final prompt/negative output
|
||||
shape, and fallback routing; `sdxl_formatter.py` keeps the public wrapper.
|
||||
profile application, shared target normalization via `formatter_target.py`,
|
||||
nude-weight normalization, metadata-vs-text input selection, single-vs-pair
|
||||
branching, final prompt/negative output shape, and fallback routing;
|
||||
`sdxl_formatter.py` keeps the public wrapper.
|
||||
- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard
|
||||
tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`,
|
||||
`SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps
|
||||
@@ -455,8 +466,9 @@ Keep here:
|
||||
Already isolated:
|
||||
|
||||
- `caption_format_route.py` owns top-level caption dispatch, including input
|
||||
hint normalization, caption profile application, metadata-vs-text branching,
|
||||
trigger wrapping, final prose hygiene, and method/output shape;
|
||||
hint normalization, shared target normalization via `formatter_target.py`,
|
||||
caption profile application, metadata-vs-text branching, trigger wrapping,
|
||||
final prose hygiene, and method/output shape;
|
||||
`caption_naturalizer.py` keeps the public wrapper.
|
||||
- `caption_metadata_routes.py` owns metadata row natural-language assembly for
|
||||
single, couple, configured-cast, group/layout, and Insta/OF pair routes behind
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
FORMATTER_TARGETS = ("auto", "single", "softcore", "hardcore")
|
||||
PAIR_SIDE_TARGETS = ("softcore", "hardcore")
|
||||
DEFAULT_FORMATTER_TARGET = "auto"
|
||||
DEFAULT_PAIR_SELECTED_SIDE = "softcore"
|
||||
|
||||
_TARGET_ALIASES = {
|
||||
"soft": "softcore",
|
||||
"soft_core": "softcore",
|
||||
"hard": "hardcore",
|
||||
"hard_core": "hardcore",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PairTargetPolicy:
|
||||
target: str
|
||||
pair_target: str
|
||||
selected_side: str
|
||||
include_softcore: bool
|
||||
include_hardcore: bool
|
||||
|
||||
|
||||
def normalize_target(value: Any) -> str:
|
||||
target = str(value or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||
target = _TARGET_ALIASES.get(target, target)
|
||||
return target if target in FORMATTER_TARGETS else DEFAULT_FORMATTER_TARGET
|
||||
|
||||
|
||||
def pair_target(value: Any) -> str:
|
||||
target = normalize_target(value)
|
||||
return target if target in PAIR_SIDE_TARGETS else DEFAULT_FORMATTER_TARGET
|
||||
|
||||
|
||||
def pair_selected_side(value: Any, default: str = DEFAULT_PAIR_SELECTED_SIDE) -> str:
|
||||
side = pair_target(value)
|
||||
if side in PAIR_SIDE_TARGETS:
|
||||
return side
|
||||
return default if default in PAIR_SIDE_TARGETS else DEFAULT_PAIR_SELECTED_SIDE
|
||||
|
||||
|
||||
def pair_policy(value: Any, *, selected_default: str = DEFAULT_PAIR_SELECTED_SIDE) -> PairTargetPolicy:
|
||||
target = normalize_target(value)
|
||||
side_target = pair_target(target)
|
||||
selected_side = pair_selected_side(side_target, selected_default)
|
||||
return PairTargetPolicy(
|
||||
target=target,
|
||||
pair_target=side_target,
|
||||
selected_side=selected_side,
|
||||
include_softcore=side_target in ("auto", "softcore"),
|
||||
include_hardcore=side_target in ("auto", "hardcore"),
|
||||
)
|
||||
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import formatter_target as target_policy
|
||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||
import formatter_target as target_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KreaFormatRequest:
|
||||
@@ -45,10 +50,11 @@ class KreaFormatDependencies:
|
||||
def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDependencies) -> KreaFormatRoute:
|
||||
detail_level = request.detail_level if request.detail_level in ("concise", "balanced", "dense") else "balanced"
|
||||
style_mode = request.style_mode if request.style_mode in ("preserve", "photographic", "minimal") else "preserve"
|
||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
||||
target = target_policy.normalize_target(request.target)
|
||||
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
||||
|
||||
if row and row.get("mode") == "Insta/OF":
|
||||
pair_target = target_policy.pair_policy(target)
|
||||
soft_prompt, soft_negative, hard_prompt, hard_negative = deps.insta_pair_to_krea(
|
||||
row,
|
||||
detail_level,
|
||||
@@ -63,8 +69,8 @@ def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDepen
|
||||
hard_prompt = f"{hard_prompt.rstrip()} {request.extra_positive.strip()}"
|
||||
soft_prompt = deps.sanitize_prose_text(soft_prompt, triggers=deps.trigger_candidates)
|
||||
hard_prompt = deps.sanitize_prose_text(hard_prompt, triggers=deps.trigger_candidates)
|
||||
selected = hard_prompt if target == "hardcore" else soft_prompt if target == "softcore" else soft_prompt
|
||||
selected_negative = hard_negative if target == "hardcore" else soft_negative
|
||||
selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt
|
||||
selected_negative = hard_negative if pair_target.selected_side == "hardcore" else soft_negative
|
||||
negative = deps.sanitize_negative_text(
|
||||
deps.combine_negative(selected_negative, request.negative_prompt, request.extra_negative)
|
||||
)
|
||||
|
||||
+11
-3
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import formatter_target as target_policy
|
||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||
import formatter_target as target_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLFormatRequest:
|
||||
@@ -56,11 +61,12 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend
|
||||
request.style_preset,
|
||||
request.quality_preset,
|
||||
)
|
||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
||||
target = target_policy.normalize_target(request.target)
|
||||
nude_weight = max(0.1, min(3.0, float(request.nude_weight)))
|
||||
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
||||
|
||||
if row and row.get("mode") == "Insta/OF":
|
||||
pair_target = target_policy.pair_policy(target)
|
||||
soft_row = row.get("softcore_row") if isinstance(row.get("softcore_row"), dict) else {}
|
||||
hard_row = row.get("hardcore_row") if isinstance(row.get("hardcore_row"), dict) else {}
|
||||
soft_body = deps.soft_tags(soft_row, row, nude_weight)
|
||||
@@ -85,9 +91,11 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend
|
||||
request.custom_quality,
|
||||
request.extra_positive,
|
||||
)
|
||||
selected = hard_prompt if target == "hardcore" else soft_prompt
|
||||
selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt
|
||||
selected_negative = (
|
||||
row.get("hardcore_negative_prompt") if target == "hardcore" else row.get("softcore_negative_prompt")
|
||||
row.get("hardcore_negative_prompt")
|
||||
if pair_target.selected_side == "hardcore"
|
||||
else row.get("softcore_negative_prompt")
|
||||
)
|
||||
output = {
|
||||
"sdxl_prompt": selected,
|
||||
|
||||
@@ -42,6 +42,7 @@ import category_cast_config # noqa: E402
|
||||
import category_library # noqa: E402
|
||||
import filter_config # noqa: E402
|
||||
import formatter_input # noqa: E402
|
||||
import formatter_target # noqa: E402
|
||||
import hardcore_position_config # noqa: E402
|
||||
import __init__ as sxcp_nodes # noqa: E402
|
||||
import generation_profile_config # noqa: E402
|
||||
@@ -2759,6 +2760,31 @@ def smoke_formatter_input_policy() -> None:
|
||||
_expect("blur" in fallback_sdxl.get("negative_prompt", ""), "SDXL fallback lost Avoid negative text")
|
||||
|
||||
|
||||
def smoke_formatter_target_policy() -> None:
|
||||
_expect(formatter_target.normalize_target("single") == "single", "Formatter target lost single")
|
||||
_expect(formatter_target.normalize_target("Hard-Core") == "hardcore", "Formatter target alias lost hardcore")
|
||||
_expect(formatter_target.normalize_target("soft") == "softcore", "Formatter target alias lost softcore")
|
||||
_expect(formatter_target.normalize_target("bad target") == "auto", "Formatter target should normalize invalid values")
|
||||
|
||||
auto_pair = formatter_target.pair_policy("auto")
|
||||
_expect(auto_pair.target == "auto", "Pair target policy lost normalized auto target")
|
||||
_expect(auto_pair.pair_target == "auto", "Pair target policy lost auto pair target")
|
||||
_expect(auto_pair.selected_side == "softcore", "Pair auto should select softcore side for single-output formatters")
|
||||
_expect(auto_pair.include_softcore and auto_pair.include_hardcore, "Pair auto should include both sides for combined captions")
|
||||
|
||||
single_pair = formatter_target.pair_policy("single")
|
||||
_expect(single_pair.target == "single", "Pair target policy lost normalized single target")
|
||||
_expect(single_pair.pair_target == "auto", "Pair single should map to auto for pair inclusion")
|
||||
_expect(single_pair.selected_side == "softcore", "Pair single should select softcore side by default")
|
||||
_expect(single_pair.include_softcore and single_pair.include_hardcore, "Pair single should include both sides when treated as auto")
|
||||
|
||||
hard_pair = formatter_target.pair_policy("hard")
|
||||
_expect(hard_pair.target == "hardcore", "Pair target policy lost hard alias")
|
||||
_expect(hard_pair.pair_target == "hardcore", "Pair hard alias should become hardcore pair target")
|
||||
_expect(hard_pair.selected_side == "hardcore", "Pair hardcore should select hardcore side")
|
||||
_expect(not hard_pair.include_softcore and hard_pair.include_hardcore, "Pair hardcore should include only hard side")
|
||||
|
||||
|
||||
def smoke_krea_format_route_policy() -> None:
|
||||
row = _prompt_row(
|
||||
name="krea_format_route_single",
|
||||
@@ -6557,6 +6583,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("row_role_graph_policy", smoke_row_role_graph_policy),
|
||||
("row_assembly_policy", smoke_row_assembly_policy),
|
||||
("formatter_input_policy", smoke_formatter_input_policy),
|
||||
("formatter_target_policy", smoke_formatter_target_policy),
|
||||
("krea_format_route_policy", smoke_krea_format_route_policy),
|
||||
("formatter_cast_policy", smoke_formatter_cast_policy),
|
||||
("caption_policy", smoke_caption_policy),
|
||||
|
||||
Reference in New Issue
Block a user