Centralize formatter input hints
This commit is contained in:
@@ -4,8 +4,10 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from . import formatter_input as input_policy
|
||||||
from . import formatter_target as target_policy
|
from . import formatter_target as target_policy
|
||||||
except ImportError: # pragma: no cover - plain-script smoke tests
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
||||||
|
import formatter_input as input_policy
|
||||||
import formatter_target as target_policy
|
import formatter_target as target_policy
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +55,7 @@ def naturalize_caption_result(
|
|||||||
request: CaptionFormatRequest,
|
request: CaptionFormatRequest,
|
||||||
deps: CaptionFormatDependencies,
|
deps: CaptionFormatDependencies,
|
||||||
) -> CaptionFormatRoute:
|
) -> CaptionFormatRoute:
|
||||||
input_hint = request.input_hint if request.input_hint in ("auto", "metadata_json", "caption_or_prompt") else "auto"
|
input_hint = input_policy.normalize_input_hint(request.input_hint, text_hint=input_policy.INPUT_HINT_CAPTION_OR_PROMPT)
|
||||||
target = target_policy.normalize_target(request.target)
|
target = target_policy.normalize_target(request.target)
|
||||||
detail_level, style_policy, include_trigger = deps.apply_caption_profile(
|
detail_level, style_policy, include_trigger = deps.apply_caption_profile(
|
||||||
request.caption_profile,
|
request.caption_profile,
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ Formatter input/fallback parsing now has one home:
|
|||||||
It owns route-neutral parsing shared by Krea2, SDXL, and natural-caption
|
It owns route-neutral parsing shared by Krea2, SDXL, and natural-caption
|
||||||
routes:
|
routes:
|
||||||
|
|
||||||
|
- input-hint choice lists and normalization for `auto`, `metadata_json`, and
|
||||||
|
route-specific text modes;
|
||||||
- whitespace and punctuation normalization before formatter parsing;
|
- whitespace and punctuation normalization before formatter parsing;
|
||||||
- JSON row detection from `metadata_json` or source text;
|
- JSON row detection from `metadata_json` or source text;
|
||||||
- trigger-prefix stripping with route-specific trigger candidate lists;
|
- trigger-prefix stripping with route-specific trigger candidate lists;
|
||||||
|
|||||||
+38
-1
@@ -29,6 +29,26 @@ DEFAULT_PROMPT_FIELD_LABELS = (
|
|||||||
"Avoid",
|
"Avoid",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
INPUT_HINT_AUTO = "auto"
|
||||||
|
INPUT_HINT_METADATA = "metadata_json"
|
||||||
|
INPUT_HINT_PROMPT = "prompt"
|
||||||
|
INPUT_HINT_CAPTION_OR_PROMPT = "caption_or_prompt"
|
||||||
|
TEXT_INPUT_HINTS = (INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
|
||||||
|
FORMATTER_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA, INPUT_HINT_PROMPT, INPUT_HINT_CAPTION_OR_PROMPT)
|
||||||
|
METADATA_INPUT_HINTS = (INPUT_HINT_AUTO, INPUT_HINT_METADATA)
|
||||||
|
|
||||||
|
_INPUT_HINT_ALIASES = {
|
||||||
|
"caption": INPUT_HINT_CAPTION_OR_PROMPT,
|
||||||
|
"caption_prompt": INPUT_HINT_CAPTION_OR_PROMPT,
|
||||||
|
"caption_or_text": INPUT_HINT_CAPTION_OR_PROMPT,
|
||||||
|
"metadata": INPUT_HINT_METADATA,
|
||||||
|
"metadata json": INPUT_HINT_METADATA,
|
||||||
|
"source_json": INPUT_HINT_AUTO,
|
||||||
|
"source text": INPUT_HINT_PROMPT,
|
||||||
|
"source_text": INPUT_HINT_PROMPT,
|
||||||
|
"text": INPUT_HINT_PROMPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def prompt_field_labels() -> tuple[str, ...]:
|
def prompt_field_labels() -> tuple[str, ...]:
|
||||||
return DEFAULT_PROMPT_FIELD_LABELS
|
return DEFAULT_PROMPT_FIELD_LABELS
|
||||||
@@ -53,13 +73,30 @@ def maybe_json(text: Any) -> dict[str, Any] | None:
|
|||||||
return value if isinstance(value, dict) else None
|
return value if isinstance(value, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_input_hint(value: Any, *, text_hint: str = INPUT_HINT_PROMPT) -> str:
|
||||||
|
hint = clean_text(value).lower().replace("-", "_")
|
||||||
|
hint = _INPUT_HINT_ALIASES.get(hint, hint)
|
||||||
|
if hint in (INPUT_HINT_AUTO, INPUT_HINT_METADATA):
|
||||||
|
return hint
|
||||||
|
if hint in TEXT_INPUT_HINTS:
|
||||||
|
return text_hint if text_hint in TEXT_INPUT_HINTS else hint
|
||||||
|
return INPUT_HINT_AUTO
|
||||||
|
|
||||||
|
|
||||||
|
def input_hint_choices(*, text_hint: str = INPUT_HINT_PROMPT) -> list[str]:
|
||||||
|
text_hint = text_hint if text_hint in TEXT_INPUT_HINTS else INPUT_HINT_PROMPT
|
||||||
|
return [INPUT_HINT_AUTO, INPUT_HINT_METADATA, text_hint]
|
||||||
|
|
||||||
|
|
||||||
def row_from_inputs(
|
def row_from_inputs(
|
||||||
source_text: str,
|
source_text: str,
|
||||||
metadata_json: str,
|
metadata_json: str,
|
||||||
input_hint: str,
|
input_hint: str,
|
||||||
*,
|
*,
|
||||||
metadata_methods: tuple[str, ...] = ("auto", "metadata_json"),
|
metadata_methods: tuple[str, ...] = METADATA_INPUT_HINTS,
|
||||||
|
text_hint: str = INPUT_HINT_PROMPT,
|
||||||
) -> tuple[dict[str, Any] | None, str]:
|
) -> tuple[dict[str, Any] | None, str]:
|
||||||
|
input_hint = normalize_input_hint(input_hint, text_hint=text_hint)
|
||||||
if input_hint in metadata_methods:
|
if input_hint in metadata_methods:
|
||||||
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
for text, method in ((metadata_json, "metadata_json"), (source_text, "source_json")):
|
||||||
row = maybe_json(text)
|
row = maybe_json(text)
|
||||||
|
|||||||
+5
-3
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
try:
|
try:
|
||||||
from .caption_naturalizer import naturalize_caption
|
from .caption_naturalizer import naturalize_caption
|
||||||
from .caption_policy import caption_profile_choices
|
from .caption_policy import caption_profile_choices
|
||||||
|
from .formatter_input import INPUT_HINT_CAPTION_OR_PROMPT, INPUT_HINT_PROMPT, input_hint_choices
|
||||||
from .krea_formatter import format_krea2_prompt
|
from .krea_formatter import format_krea2_prompt
|
||||||
from .sdxl_formatter import (
|
from .sdxl_formatter import (
|
||||||
format_sdxl_prompt,
|
format_sdxl_prompt,
|
||||||
@@ -13,6 +14,7 @@ try:
|
|||||||
except ImportError: # Allows local smoke tests from the repository root.
|
except ImportError: # Allows local smoke tests from the repository root.
|
||||||
from caption_naturalizer import naturalize_caption
|
from caption_naturalizer import naturalize_caption
|
||||||
from caption_policy import caption_profile_choices
|
from caption_policy import caption_profile_choices
|
||||||
|
from formatter_input import INPUT_HINT_CAPTION_OR_PROMPT, INPUT_HINT_PROMPT, input_hint_choices
|
||||||
from krea_formatter import format_krea2_prompt
|
from krea_formatter import format_krea2_prompt
|
||||||
from sdxl_formatter import (
|
from sdxl_formatter import (
|
||||||
format_sdxl_prompt,
|
format_sdxl_prompt,
|
||||||
@@ -28,7 +30,7 @@ class SxCPCaptionNaturalizer:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"source_text": ("STRING", {"default": "", "multiline": True}),
|
"source_text": ("STRING", {"default": "", "multiline": True}),
|
||||||
"input_hint": (["auto", "metadata_json", "caption_or_prompt"], {"default": "auto"}),
|
"input_hint": (input_hint_choices(text_hint=INPUT_HINT_CAPTION_OR_PROMPT), {"default": "auto"}),
|
||||||
"caption_profile": (caption_profile_choices(), {"default": "manual_controls"}),
|
"caption_profile": (caption_profile_choices(), {"default": "manual_controls"}),
|
||||||
"detail_level": (["balanced", "concise", "dense"], {"default": "balanced"}),
|
"detail_level": (["balanced", "concise", "dense"], {"default": "balanced"}),
|
||||||
"style_policy": (["drop_style_tail", "keep_style_terms"], {"default": "drop_style_tail"}),
|
"style_policy": (["drop_style_tail", "keep_style_terms"], {"default": "drop_style_tail"}),
|
||||||
@@ -80,7 +82,7 @@ class SxCPKrea2Formatter:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"source_text": ("STRING", {"default": "", "multiline": True}),
|
"source_text": ("STRING", {"default": "", "multiline": True}),
|
||||||
"input_hint": (["auto", "metadata_json", "prompt"], {"default": "auto"}),
|
"input_hint": (input_hint_choices(text_hint=INPUT_HINT_PROMPT), {"default": "auto"}),
|
||||||
"target": (["auto", "single", "softcore", "hardcore"], {"default": "auto"}),
|
"target": (["auto", "single", "softcore", "hardcore"], {"default": "auto"}),
|
||||||
"detail_level": (["balanced", "concise", "dense"], {"default": "balanced"}),
|
"detail_level": (["balanced", "concise", "dense"], {"default": "balanced"}),
|
||||||
"style_mode": (["preserve", "photographic", "minimal"], {"default": "preserve"}),
|
"style_mode": (["preserve", "photographic", "minimal"], {"default": "preserve"}),
|
||||||
@@ -152,7 +154,7 @@ class SxCPSDXLFormatter:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"source_text": ("STRING", {"default": "", "multiline": True}),
|
"source_text": ("STRING", {"default": "", "multiline": True}),
|
||||||
"input_hint": (["auto", "metadata_json", "prompt"], {"default": "auto"}),
|
"input_hint": (input_hint_choices(text_hint=INPUT_HINT_PROMPT), {"default": "auto"}),
|
||||||
"target": (["auto", "single", "softcore", "hardcore"], {"default": "auto"}),
|
"target": (["auto", "single", "softcore", "hardcore"], {"default": "auto"}),
|
||||||
"formatter_profile": (sdxl_formatter_profile_choices(), {"default": "manual_controls"}),
|
"formatter_profile": (sdxl_formatter_profile_choices(), {"default": "manual_controls"}),
|
||||||
"style_preset": (sdxl_style_preset_choices(), {"default": "flat_vector_pony"}),
|
"style_preset": (sdxl_style_preset_choices(), {"default": "flat_vector_pony"}),
|
||||||
|
|||||||
@@ -2685,9 +2685,36 @@ def smoke_formatter_input_policy() -> None:
|
|||||||
}
|
}
|
||||||
source_json = _json(source_row)
|
source_json = _json(source_row)
|
||||||
|
|
||||||
|
_expect(
|
||||||
|
formatter_input.input_hint_choices(text_hint=formatter_input.INPUT_HINT_PROMPT) == ["auto", "metadata_json", "prompt"],
|
||||||
|
"Formatter prompt input-hint choices changed",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
formatter_input.input_hint_choices(text_hint=formatter_input.INPUT_HINT_CAPTION_OR_PROMPT)
|
||||||
|
== ["auto", "metadata_json", "caption_or_prompt"],
|
||||||
|
"Formatter caption input-hint choices changed",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
formatter_input.normalize_input_hint("bad_hint") == "auto",
|
||||||
|
"Formatter input-hint policy should normalize invalid values to auto",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
formatter_input.normalize_input_hint("caption", text_hint=formatter_input.INPUT_HINT_CAPTION_OR_PROMPT)
|
||||||
|
== "caption_or_prompt",
|
||||||
|
"Formatter input-hint policy lost caption alias",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
formatter_input.normalize_input_hint("caption_or_prompt", text_hint=formatter_input.INPUT_HINT_PROMPT) == "prompt",
|
||||||
|
"Formatter input-hint policy should map text hints to the route's text mode",
|
||||||
|
)
|
||||||
|
|
||||||
row, method = formatter_input.row_from_inputs(source_json, "", "auto")
|
row, method = formatter_input.row_from_inputs(source_json, "", "auto")
|
||||||
_expect(method == "source_json", "Formatter input parser should read source JSON when metadata is empty")
|
_expect(method == "source_json", "Formatter input parser should read source JSON when metadata is empty")
|
||||||
_expect(row == source_row, "Formatter input parser changed parsed JSON row")
|
_expect(row == source_row, "Formatter input parser changed parsed JSON row")
|
||||||
|
row, method = formatter_input.row_from_inputs(source_json, "", "bad_hint")
|
||||||
|
_expect(method == "source_json" and row == source_row, "Formatter input parser should treat invalid hints as auto")
|
||||||
|
row, method = formatter_input.row_from_inputs(source_json, "", "prompt")
|
||||||
|
_expect(row is None and method == "text", "Formatter input parser should not parse source JSON in explicit prompt mode")
|
||||||
_expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed")
|
_expect(formatter_input.split_avoid("Prompt body. Avoid: blur, watermark") == ("Prompt body", "blur, watermark"), "Avoid split changed")
|
||||||
_expect(
|
_expect(
|
||||||
formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio",
|
formatter_input.prompt_field(source_row["prompt"], "Setting") == "quiet studio",
|
||||||
@@ -2745,6 +2772,32 @@ def smoke_formatter_input_policy() -> None:
|
|||||||
_expect_text("formatter_input.krea_prompt", krea.get("krea_prompt"), 20)
|
_expect_text("formatter_input.krea_prompt", krea.get("krea_prompt"), 20)
|
||||||
_expect_text("formatter_input.sdxl_prompt", sdxl.get("sdxl_prompt"), 20)
|
_expect_text("formatter_input.sdxl_prompt", sdxl.get("sdxl_prompt"), 20)
|
||||||
_expect_text("formatter_input.caption", caption, 20)
|
_expect_text("formatter_input.caption", caption, 20)
|
||||||
|
|
||||||
|
bad_hint_krea = krea_formatter.format_krea2_prompt(source_json, input_hint="bad_hint")
|
||||||
|
bad_hint_sdxl = sdxl_formatter.format_sdxl_prompt(
|
||||||
|
source_json,
|
||||||
|
input_hint="bad_hint",
|
||||||
|
trigger=SdxlTrigger,
|
||||||
|
prepend_trigger=True,
|
||||||
|
)
|
||||||
|
bad_hint_caption, bad_hint_caption_method = caption_naturalizer.naturalize_caption(
|
||||||
|
source_json,
|
||||||
|
input_hint="bad_hint",
|
||||||
|
trigger=Trigger,
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
bad_hint_krea.get("method", "").startswith("source_json:krea2("),
|
||||||
|
"Krea formatter did not normalize bad input hint to auto",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
bad_hint_sdxl.get("method", "").startswith("source_json:sdxl("),
|
||||||
|
"SDXL formatter did not normalize bad input hint to auto",
|
||||||
|
)
|
||||||
|
_expect(
|
||||||
|
bad_hint_caption_method.startswith("source_json:metadata("),
|
||||||
|
"Caption formatter did not normalize bad input hint to auto",
|
||||||
|
)
|
||||||
|
|
||||||
fallback_sdxl = sdxl_formatter.format_sdxl_prompt(
|
fallback_sdxl = sdxl_formatter.format_sdxl_prompt(
|
||||||
"Characters: woman. Erotic outfit: sheer dress. Camera: side view. Avoid: blur",
|
"Characters: woman. Erotic outfit: sheer dress. Camera: side view. Avoid: blur",
|
||||||
input_hint="prompt",
|
input_hint="prompt",
|
||||||
|
|||||||
Reference in New Issue
Block a user