Centralize formatter target policy
This commit is contained in:
@@ -3,6 +3,11 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import formatter_target as target_policy
|
||||||
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||||
|
import formatter_target as target_policy
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CaptionFormatRequest:
|
class CaptionFormatRequest:
|
||||||
@@ -49,7 +54,7 @@ def naturalize_caption_result(
|
|||||||
deps: CaptionFormatDependencies,
|
deps: CaptionFormatDependencies,
|
||||||
) -> CaptionFormatRoute:
|
) -> CaptionFormatRoute:
|
||||||
input_hint = request.input_hint if request.input_hint in ("auto", "metadata_json", "caption_or_prompt") else "auto"
|
input_hint = request.input_hint if request.input_hint in ("auto", "metadata_json", "caption_or_prompt") else "auto"
|
||||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
target = target_policy.normalize_target(request.target)
|
||||||
detail_level, style_policy, include_trigger = deps.apply_caption_profile(
|
detail_level, style_policy, include_trigger = deps.apply_caption_profile(
|
||||||
request.caption_profile,
|
request.caption_profile,
|
||||||
request.detail_level,
|
request.detail_level,
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ import re
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import formatter_target as target_policy
|
||||||
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||||
|
import formatter_target as target_policy
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CaptionMetadataRouteRequest:
|
class CaptionMetadataRouteRequest:
|
||||||
@@ -301,7 +306,8 @@ def insta_of_pair_from_row_result(
|
|||||||
row = request.row
|
row = request.row
|
||||||
detail_level = request.detail_level
|
detail_level = request.detail_level
|
||||||
keep_style = request.keep_style
|
keep_style = request.keep_style
|
||||||
target = request.target if request.target in ("softcore", "hardcore") else "auto"
|
pair_target = target_policy.pair_policy(request.target)
|
||||||
|
target = pair_target.pair_target
|
||||||
if deps.clean_text(row.get("mode")).lower() != "insta/of":
|
if deps.clean_text(row.get("mode")).lower() != "insta/of":
|
||||||
return None
|
return None
|
||||||
soft_row = row.get("softcore_row")
|
soft_row = row.get("softcore_row")
|
||||||
@@ -317,8 +323,8 @@ def insta_of_pair_from_row_result(
|
|||||||
if soft_row.get("composition"):
|
if soft_row.get("composition"):
|
||||||
hard_row_for_text["composition"] = soft_row["composition"]
|
hard_row_for_text["composition"] = soft_row["composition"]
|
||||||
|
|
||||||
include_soft = target in ("auto", "softcore")
|
include_soft = pair_target.include_softcore
|
||||||
include_hard = target in ("auto", "hardcore")
|
include_hard = pair_target.include_hardcore
|
||||||
soft_text = ""
|
soft_text = ""
|
||||||
hard_text = ""
|
hard_text = ""
|
||||||
if include_soft:
|
if include_soft:
|
||||||
|
|||||||
@@ -82,6 +82,15 @@ routes:
|
|||||||
It must not make formatter-style decisions. Krea prose, SDXL tags, and training
|
It must not make formatter-style decisions. Krea prose, SDXL tags, and training
|
||||||
caption sentence shape stay in their formatter modules.
|
caption sentence shape stay in their formatter modules.
|
||||||
|
|
||||||
|
Formatter target handling now has one home:
|
||||||
|
|
||||||
|
- `formatter_target.py`
|
||||||
|
|
||||||
|
It owns route-neutral target normalization for `auto`, `single`, `softcore`,
|
||||||
|
and `hardcore`, including pair-side semantics. Single-output formatters select
|
||||||
|
the softcore side for pair `auto`/`single` targets, while caption pair routing
|
||||||
|
can still include both sides for combined training captions.
|
||||||
|
|
||||||
Shared hardcore phrase cleanup now has one home:
|
Shared hardcore phrase cleanup now has one home:
|
||||||
|
|
||||||
- `hardcore_text_cleanup.py`
|
- `hardcore_text_cleanup.py`
|
||||||
@@ -345,7 +354,8 @@ Already isolated:
|
|||||||
|
|
||||||
- `krea_format_route.py` owns top-level Krea dispatch, including option
|
- `krea_format_route.py` owns top-level Krea dispatch, including option
|
||||||
normalization, metadata-vs-text input selection, single-vs-pair branching,
|
normalization, metadata-vs-text input selection, single-vs-pair branching,
|
||||||
extra positive/negative merging, final prose hygiene, and output shape;
|
shared target normalization via `formatter_target.py`, extra
|
||||||
|
positive/negative merging, final prose hygiene, and output shape;
|
||||||
`krea_formatter.py` keeps the public wrapper.
|
`krea_formatter.py` keeps the public wrapper.
|
||||||
- `krea_configured_cast_formatter.py` owns normal metadata configured-cast
|
- `krea_configured_cast_formatter.py` owns normal metadata configured-cast
|
||||||
Krea prose assembly behind `KreaConfiguredCastRequest`,
|
Krea prose assembly behind `KreaConfiguredCastRequest`,
|
||||||
@@ -417,9 +427,10 @@ Keep here:
|
|||||||
Already isolated:
|
Already isolated:
|
||||||
|
|
||||||
- `sdxl_format_route.py` owns top-level SDXL dispatch, including formatter
|
- `sdxl_format_route.py` owns top-level SDXL dispatch, including formatter
|
||||||
profile application, target and nude-weight normalization, metadata-vs-text
|
profile application, shared target normalization via `formatter_target.py`,
|
||||||
input selection, single-vs-pair branching, final prompt/negative output
|
nude-weight normalization, metadata-vs-text input selection, single-vs-pair
|
||||||
shape, and fallback routing; `sdxl_formatter.py` keeps the public wrapper.
|
branching, final prompt/negative output shape, and fallback routing;
|
||||||
|
`sdxl_formatter.py` keeps the public wrapper.
|
||||||
- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard
|
- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard
|
||||||
tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`,
|
tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`,
|
||||||
`SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps
|
`SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps
|
||||||
@@ -455,8 +466,9 @@ Keep here:
|
|||||||
Already isolated:
|
Already isolated:
|
||||||
|
|
||||||
- `caption_format_route.py` owns top-level caption dispatch, including input
|
- `caption_format_route.py` owns top-level caption dispatch, including input
|
||||||
hint normalization, caption profile application, metadata-vs-text branching,
|
hint normalization, shared target normalization via `formatter_target.py`,
|
||||||
trigger wrapping, final prose hygiene, and method/output shape;
|
caption profile application, metadata-vs-text branching, trigger wrapping,
|
||||||
|
final prose hygiene, and method/output shape;
|
||||||
`caption_naturalizer.py` keeps the public wrapper.
|
`caption_naturalizer.py` keeps the public wrapper.
|
||||||
- `caption_metadata_routes.py` owns metadata row natural-language assembly for
|
- `caption_metadata_routes.py` owns metadata row natural-language assembly for
|
||||||
single, couple, configured-cast, group/layout, and Insta/OF pair routes behind
|
single, couple, configured-cast, group/layout, and Insta/OF pair routes behind
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
FORMATTER_TARGETS = ("auto", "single", "softcore", "hardcore")
|
||||||
|
PAIR_SIDE_TARGETS = ("softcore", "hardcore")
|
||||||
|
DEFAULT_FORMATTER_TARGET = "auto"
|
||||||
|
DEFAULT_PAIR_SELECTED_SIDE = "softcore"
|
||||||
|
|
||||||
|
_TARGET_ALIASES = {
|
||||||
|
"soft": "softcore",
|
||||||
|
"soft_core": "softcore",
|
||||||
|
"hard": "hardcore",
|
||||||
|
"hard_core": "hardcore",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PairTargetPolicy:
|
||||||
|
target: str
|
||||||
|
pair_target: str
|
||||||
|
selected_side: str
|
||||||
|
include_softcore: bool
|
||||||
|
include_hardcore: bool
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_target(value: Any) -> str:
|
||||||
|
target = str(value or "").strip().lower().replace("-", "_").replace(" ", "_")
|
||||||
|
target = _TARGET_ALIASES.get(target, target)
|
||||||
|
return target if target in FORMATTER_TARGETS else DEFAULT_FORMATTER_TARGET
|
||||||
|
|
||||||
|
|
||||||
|
def pair_target(value: Any) -> str:
|
||||||
|
target = normalize_target(value)
|
||||||
|
return target if target in PAIR_SIDE_TARGETS else DEFAULT_FORMATTER_TARGET
|
||||||
|
|
||||||
|
|
||||||
|
def pair_selected_side(value: Any, default: str = DEFAULT_PAIR_SELECTED_SIDE) -> str:
|
||||||
|
side = pair_target(value)
|
||||||
|
if side in PAIR_SIDE_TARGETS:
|
||||||
|
return side
|
||||||
|
return default if default in PAIR_SIDE_TARGETS else DEFAULT_PAIR_SELECTED_SIDE
|
||||||
|
|
||||||
|
|
||||||
|
def pair_policy(value: Any, *, selected_default: str = DEFAULT_PAIR_SELECTED_SIDE) -> PairTargetPolicy:
|
||||||
|
target = normalize_target(value)
|
||||||
|
side_target = pair_target(target)
|
||||||
|
selected_side = pair_selected_side(side_target, selected_default)
|
||||||
|
return PairTargetPolicy(
|
||||||
|
target=target,
|
||||||
|
pair_target=side_target,
|
||||||
|
selected_side=selected_side,
|
||||||
|
include_softcore=side_target in ("auto", "softcore"),
|
||||||
|
include_hardcore=side_target in ("auto", "hardcore"),
|
||||||
|
)
|
||||||
@@ -3,6 +3,11 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import formatter_target as target_policy
|
||||||
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||||
|
import formatter_target as target_policy
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class KreaFormatRequest:
|
class KreaFormatRequest:
|
||||||
@@ -45,10 +50,11 @@ class KreaFormatDependencies:
|
|||||||
def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDependencies) -> KreaFormatRoute:
|
def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDependencies) -> KreaFormatRoute:
|
||||||
detail_level = request.detail_level if request.detail_level in ("concise", "balanced", "dense") else "balanced"
|
detail_level = request.detail_level if request.detail_level in ("concise", "balanced", "dense") else "balanced"
|
||||||
style_mode = request.style_mode if request.style_mode in ("preserve", "photographic", "minimal") else "preserve"
|
style_mode = request.style_mode if request.style_mode in ("preserve", "photographic", "minimal") else "preserve"
|
||||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
target = target_policy.normalize_target(request.target)
|
||||||
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
||||||
|
|
||||||
if row and row.get("mode") == "Insta/OF":
|
if row and row.get("mode") == "Insta/OF":
|
||||||
|
pair_target = target_policy.pair_policy(target)
|
||||||
soft_prompt, soft_negative, hard_prompt, hard_negative = deps.insta_pair_to_krea(
|
soft_prompt, soft_negative, hard_prompt, hard_negative = deps.insta_pair_to_krea(
|
||||||
row,
|
row,
|
||||||
detail_level,
|
detail_level,
|
||||||
@@ -63,8 +69,8 @@ def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDepen
|
|||||||
hard_prompt = f"{hard_prompt.rstrip()} {request.extra_positive.strip()}"
|
hard_prompt = f"{hard_prompt.rstrip()} {request.extra_positive.strip()}"
|
||||||
soft_prompt = deps.sanitize_prose_text(soft_prompt, triggers=deps.trigger_candidates)
|
soft_prompt = deps.sanitize_prose_text(soft_prompt, triggers=deps.trigger_candidates)
|
||||||
hard_prompt = deps.sanitize_prose_text(hard_prompt, triggers=deps.trigger_candidates)
|
hard_prompt = deps.sanitize_prose_text(hard_prompt, triggers=deps.trigger_candidates)
|
||||||
selected = hard_prompt if target == "hardcore" else soft_prompt if target == "softcore" else soft_prompt
|
selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt
|
||||||
selected_negative = hard_negative if target == "hardcore" else soft_negative
|
selected_negative = hard_negative if pair_target.selected_side == "hardcore" else soft_negative
|
||||||
negative = deps.sanitize_negative_text(
|
negative = deps.sanitize_negative_text(
|
||||||
deps.combine_negative(selected_negative, request.negative_prompt, request.extra_negative)
|
deps.combine_negative(selected_negative, request.negative_prompt, request.extra_negative)
|
||||||
)
|
)
|
||||||
|
|||||||
+11
-3
@@ -3,6 +3,11 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import formatter_target as target_policy
|
||||||
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||||
|
import formatter_target as target_policy
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SDXLFormatRequest:
|
class SDXLFormatRequest:
|
||||||
@@ -56,11 +61,12 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend
|
|||||||
request.style_preset,
|
request.style_preset,
|
||||||
request.quality_preset,
|
request.quality_preset,
|
||||||
)
|
)
|
||||||
target = request.target if request.target in ("auto", "single", "softcore", "hardcore") else "auto"
|
target = target_policy.normalize_target(request.target)
|
||||||
nude_weight = max(0.1, min(3.0, float(request.nude_weight)))
|
nude_weight = max(0.1, min(3.0, float(request.nude_weight)))
|
||||||
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)
|
||||||
|
|
||||||
if row and row.get("mode") == "Insta/OF":
|
if row and row.get("mode") == "Insta/OF":
|
||||||
|
pair_target = target_policy.pair_policy(target)
|
||||||
soft_row = row.get("softcore_row") if isinstance(row.get("softcore_row"), dict) else {}
|
soft_row = row.get("softcore_row") if isinstance(row.get("softcore_row"), dict) else {}
|
||||||
hard_row = row.get("hardcore_row") if isinstance(row.get("hardcore_row"), dict) else {}
|
hard_row = row.get("hardcore_row") if isinstance(row.get("hardcore_row"), dict) else {}
|
||||||
soft_body = deps.soft_tags(soft_row, row, nude_weight)
|
soft_body = deps.soft_tags(soft_row, row, nude_weight)
|
||||||
@@ -85,9 +91,11 @@ def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDepend
|
|||||||
request.custom_quality,
|
request.custom_quality,
|
||||||
request.extra_positive,
|
request.extra_positive,
|
||||||
)
|
)
|
||||||
selected = hard_prompt if target == "hardcore" else soft_prompt
|
selected = hard_prompt if pair_target.selected_side == "hardcore" else soft_prompt
|
||||||
selected_negative = (
|
selected_negative = (
|
||||||
row.get("hardcore_negative_prompt") if target == "hardcore" else row.get("softcore_negative_prompt")
|
row.get("hardcore_negative_prompt")
|
||||||
|
if pair_target.selected_side == "hardcore"
|
||||||
|
else row.get("softcore_negative_prompt")
|
||||||
)
|
)
|
||||||
output = {
|
output = {
|
||||||
"sdxl_prompt": selected,
|
"sdxl_prompt": selected,
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ import category_cast_config # noqa: E402
|
|||||||
import category_library # noqa: E402
|
import category_library # noqa: E402
|
||||||
import filter_config # noqa: E402
|
import filter_config # noqa: E402
|
||||||
import formatter_input # noqa: E402
|
import formatter_input # noqa: E402
|
||||||
|
import formatter_target # noqa: E402
|
||||||
import hardcore_position_config # noqa: E402
|
import hardcore_position_config # noqa: E402
|
||||||
import __init__ as sxcp_nodes # noqa: E402
|
import __init__ as sxcp_nodes # noqa: E402
|
||||||
import generation_profile_config # noqa: E402
|
import generation_profile_config # noqa: E402
|
||||||
@@ -2759,6 +2760,31 @@ def smoke_formatter_input_policy() -> None:
|
|||||||
_expect("blur" in fallback_sdxl.get("negative_prompt", ""), "SDXL fallback lost Avoid negative text")
|
_expect("blur" in fallback_sdxl.get("negative_prompt", ""), "SDXL fallback lost Avoid negative text")
|
||||||
|
|
||||||
|
|
||||||
|
def smoke_formatter_target_policy() -> None:
|
||||||
|
_expect(formatter_target.normalize_target("single") == "single", "Formatter target lost single")
|
||||||
|
_expect(formatter_target.normalize_target("Hard-Core") == "hardcore", "Formatter target alias lost hardcore")
|
||||||
|
_expect(formatter_target.normalize_target("soft") == "softcore", "Formatter target alias lost softcore")
|
||||||
|
_expect(formatter_target.normalize_target("bad target") == "auto", "Formatter target should normalize invalid values")
|
||||||
|
|
||||||
|
auto_pair = formatter_target.pair_policy("auto")
|
||||||
|
_expect(auto_pair.target == "auto", "Pair target policy lost normalized auto target")
|
||||||
|
_expect(auto_pair.pair_target == "auto", "Pair target policy lost auto pair target")
|
||||||
|
_expect(auto_pair.selected_side == "softcore", "Pair auto should select softcore side for single-output formatters")
|
||||||
|
_expect(auto_pair.include_softcore and auto_pair.include_hardcore, "Pair auto should include both sides for combined captions")
|
||||||
|
|
||||||
|
single_pair = formatter_target.pair_policy("single")
|
||||||
|
_expect(single_pair.target == "single", "Pair target policy lost normalized single target")
|
||||||
|
_expect(single_pair.pair_target == "auto", "Pair single should map to auto for pair inclusion")
|
||||||
|
_expect(single_pair.selected_side == "softcore", "Pair single should select softcore side by default")
|
||||||
|
_expect(single_pair.include_softcore and single_pair.include_hardcore, "Pair single should include both sides when treated as auto")
|
||||||
|
|
||||||
|
hard_pair = formatter_target.pair_policy("hard")
|
||||||
|
_expect(hard_pair.target == "hardcore", "Pair target policy lost hard alias")
|
||||||
|
_expect(hard_pair.pair_target == "hardcore", "Pair hard alias should become hardcore pair target")
|
||||||
|
_expect(hard_pair.selected_side == "hardcore", "Pair hardcore should select hardcore side")
|
||||||
|
_expect(not hard_pair.include_softcore and hard_pair.include_hardcore, "Pair hardcore should include only hard side")
|
||||||
|
|
||||||
|
|
||||||
def smoke_krea_format_route_policy() -> None:
|
def smoke_krea_format_route_policy() -> None:
|
||||||
row = _prompt_row(
|
row = _prompt_row(
|
||||||
name="krea_format_route_single",
|
name="krea_format_route_single",
|
||||||
@@ -6557,6 +6583,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
|||||||
("row_role_graph_policy", smoke_row_role_graph_policy),
|
("row_role_graph_policy", smoke_row_role_graph_policy),
|
||||||
("row_assembly_policy", smoke_row_assembly_policy),
|
("row_assembly_policy", smoke_row_assembly_policy),
|
||||||
("formatter_input_policy", smoke_formatter_input_policy),
|
("formatter_input_policy", smoke_formatter_input_policy),
|
||||||
|
("formatter_target_policy", smoke_formatter_target_policy),
|
||||||
("krea_format_route_policy", smoke_krea_format_route_policy),
|
("krea_format_route_policy", smoke_krea_format_route_policy),
|
||||||
("formatter_cast_policy", smoke_formatter_cast_policy),
|
("formatter_cast_policy", smoke_formatter_cast_policy),
|
||||||
("caption_policy", smoke_caption_policy),
|
("caption_policy", smoke_caption_policy),
|
||||||
|
|||||||
Reference in New Issue
Block a user