Files
ComfyUI-Ethanfel-Prompt-Bui…/sdxl_format_route.py
T

204 lines
7.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
try:
from . import formatter_input as input_policy
from . import formatter_route_trace as trace_policy
from . import formatter_target as target_policy
except ImportError: # pragma: no cover - plain-script smoke tests
import formatter_input as input_policy
import formatter_route_trace as trace_policy
import formatter_target as target_policy
@dataclass(frozen=True)
class SDXLFormatRequest:
    """Immutable bundle of caller options for format_sdxl_prompt_result.

    Only ``source_text`` is required; every other field has a default.
    """

    # Raw prompt text; passed to deps.row_from_inputs and, when no row is found,
    # to deps.fallback_text_to_sdxl.
    source_text: str
    # Passed to deps.row_from_inputs alongside source_text.
    metadata_json: str = ""
    # Caller-supplied negative text; combined only into the top-level "negative_prompt" output.
    negative_prompt: str = ""
    # Normalized value is recorded in the route trace; the raw value goes to deps.row_from_inputs.
    input_hint: str = "auto"
    # Normalized via target_policy.normalize_target; for pair metadata it selects the side.
    target: str = "auto"
    # Both presets may be replaced by deps.apply_formatter_profile before use.
    style_preset: str = "flat_vector_pony"
    quality_preset: str = "pony_high"
    # Trigger word and whether to prepend it; both forwarded to deps.assemble_prompt.
    trigger: str = "mythp0rt"
    prepend_trigger: bool = True
    # Forwarded to deps.fallback_text_to_sdxl (fallback branch only).
    preserve_trigger: bool = False
    # Clamped to [0.1, 3.0] before being handed to any tag builder.
    nude_weight: float = 1.29
    # Forwarded to deps.assemble_prompt for every positive prompt built.
    custom_style: str = ""
    custom_quality: str = ""
    extra_positive: str = ""
    # Passed as the last fragment to every deps.combine_negative call.
    extra_negative: str = ""
    # First argument to deps.apply_formatter_profile.
    formatter_profile: str = "manual_controls"
@dataclass(frozen=True)
class SDXLFormatRoute:
    """Result of format_sdxl_prompt_result: the output mapping plus the values that produced it.

    ``frozen=True`` only prevents rebinding the attributes; the ``output`` dict itself stays mutable.
    """

    # Mapping returned by format_sdxl_prompt; always includes "route_trace_json".
    output: dict[str, str]
    # One of "insta_of_pair", "metadata" or "fallback".
    branch: str
    # Same string as output["method"].
    method: str
    # Target after target_policy.normalize_target.
    target: str
    # Presets after deps.apply_formatter_profile.
    style_preset: str
    quality_preset: str
    # Clamped weight actually passed to the tag builders.
    nude_weight: float
@dataclass(frozen=True)
class SDXLFormatDependencies:
    """Injected defaults and callables used by format_sdxl_prompt_result.

    The routing function only decides which of these to call and in what order; the
    tag/prompt building itself lives in the supplied callables.
    """

    # Baseline negative text; first argument to every combine_negative call.
    default_negative: str
    # (formatter_profile, style_preset, quality_preset) -> (style_preset, quality_preset)
    apply_formatter_profile: Callable[[str, str, str], tuple[str, str]]
    # Converts a raw row value to text; used on row.get("negative_prompt") in the metadata branch.
    clean: Callable[[Any], str]
    # (source_text, metadata_json, input_hint) -> (row dict or None, method label)
    row_from_inputs: Callable[[str, str, str], tuple[dict[str, Any] | None, str]]
    # (row, nude_weight) -> tag list; joined with ", " to form the metadata-branch body.
    row_core_tags: Callable[[dict[str, Any], float], list[str]]
    # (softcore sub-row, full pair row, nude_weight) -> softcore tag body.
    soft_tags: Callable[[dict[str, Any], dict[str, Any], float], str]
    # (hardcore sub-row, full pair row, nude_weight) -> hardcore tag body.
    hard_tags: Callable[[dict[str, Any], dict[str, Any], float], str]
    # (source_text, preserve_trigger, nude_weight) -> (body, extracted negative, method)
    fallback_text_to_sdxl: Callable[[str, bool, float], tuple[str, str, str]]
    # (body, style_preset, quality_preset, trigger, prepend_trigger,
    #  custom_style, custom_quality, extra_positive) -> positive prompt
    assemble_prompt: Callable[[str, str, str, str, bool, str, str, str], str]
    # Variadic: default_negative followed by optional fragments (row.get values may be None).
    combine_negative: Callable[..., str]
    # Final pass applied to every combined negative prompt.
    sanitize_negative_text: Callable[[str], str]
_PAIR_BRANCH = "insta_of_pair"


def _dict_field(row: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``row[key]`` when it is a dict, otherwise a fresh empty dict."""
    value = row.get(key)
    return value if isinstance(value, dict) else {}


def _assemble_positive(
    body: str,
    request: SDXLFormatRequest,
    deps: SDXLFormatDependencies,
    style_preset: str,
    quality_preset: str,
) -> str:
    """Wrap a tag body with the resolved presets and the request's trigger/custom/extra settings."""
    return deps.assemble_prompt(
        body,
        style_preset,
        quality_preset,
        request.trigger,
        request.prepend_trigger,
        request.custom_style,
        request.custom_quality,
        request.extra_positive,
    )


def _clean_negative(deps: SDXLFormatDependencies, *parts: Any) -> str:
    """Combine ``deps.default_negative`` with *parts* (in order), then sanitize the result."""
    return deps.sanitize_negative_text(deps.combine_negative(deps.default_negative, *parts))


def _finish_route(
    output: dict[str, str],
    *,
    branch: str,
    input_hint: str,
    target: str,
    style_preset: str,
    quality_preset: str,
    nude_weight: float,
    trace_fields: dict[str, Any],
) -> SDXLFormatRoute:
    """Add ``route_trace_json`` to *output* in place and wrap everything in an SDXLFormatRoute.

    ``trace_fields`` are the extra keyword fields from trace_policy.metadata_trace_fields.
    """
    method = output["method"]
    output["route_trace_json"] = trace_policy.route_trace_json(
        formatter="sdxl",
        branch=branch,
        method=method,
        input_hint=input_hint,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        nude_weight=nude_weight,
        **trace_fields,
    )
    return SDXLFormatRoute(
        output=output,
        branch=branch,
        method=method,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        nude_weight=nude_weight,
    )


def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDependencies) -> SDXLFormatRoute:
    """Route a formatting request to one of three branches and build the SDXL prompts.

    Branches:
    * ``"insta_of_pair"``: the parsed row is pair metadata (input_policy.is_pair_metadata).
      Softcore and hardcore prompts are both built, and the normalized ``target`` picks
      which side fills ``sdxl_prompt`` / ``negative_prompt``.
    * ``"metadata"``: a non-empty, non-pair row; deps.row_core_tags supplies the body.
    * ``"fallback"``: no row (None or empty); deps.fallback_text_to_sdxl derives the body,
      negative and method from ``source_text``.

    Every branch returns the same output keys, plus ``"route_trace_json"``.

    Args:
        request: Caller options.
        deps: Injected helpers that do the actual tag and prompt building.

    Returns:
        SDXLFormatRoute holding the output mapping and the resolved branch, method,
        target, presets and clamped weight.

    Raises:
        ValueError / TypeError: if ``request.nude_weight`` cannot be converted with ``float()``.
    """
    style_preset, quality_preset = deps.apply_formatter_profile(
        request.formatter_profile,
        request.style_preset,
        request.quality_preset,
    )
    target = target_policy.normalize_target(request.target)
    input_hint = input_policy.normalize_input_hint(request.input_hint, text_hint=input_policy.INPUT_HINT_PROMPT)
    # Clamp into [0.1, 3.0] before the weight reaches any tag builder.
    nude_weight = max(0.1, min(3.0, float(request.nude_weight)))
    # NOTE(review): row_from_inputs receives the raw request.input_hint, not the normalized
    # value recorded in the trace. Kept as-is to preserve behavior; confirm this is intended.
    row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)

    if row and input_policy.is_pair_metadata(row):
        selected_side = target_policy.pair_policy(target).selected_side
        hardcore_selected = selected_side == "hardcore"

        soft_row = _dict_field(row, "softcore_row")
        hard_row = _dict_field(row, "hardcore_row")
        soft_body = deps.soft_tags(soft_row, row, nude_weight)
        hard_body = deps.hard_tags(hard_row, row, nude_weight)
        soft_prompt = _assemble_positive(soft_body, request, deps, style_preset, quality_preset)
        hard_prompt = _assemble_positive(hard_body, request, deps, style_preset, quality_preset)

        soft_negative = row.get("softcore_negative_prompt")
        hard_negative = row.get("hardcore_negative_prompt")
        output = {
            "sdxl_prompt": hard_prompt if hardcore_selected else soft_prompt,
            "negative_prompt": _clean_negative(
                deps,
                hard_negative if hardcore_selected else soft_negative,
                request.negative_prompt,
                request.extra_negative,
            ),
            "sdxl_softcore_prompt": soft_prompt,
            "sdxl_hardcore_prompt": hard_prompt,
            # NOTE(review): the per-side negatives omit request.negative_prompt, unlike the
            # selected "negative_prompt" above. Preserved as-is; confirm this is intended.
            "softcore_negative_prompt": _clean_negative(deps, soft_negative, request.extra_negative),
            "hardcore_negative_prompt": _clean_negative(deps, hard_negative, request.extra_negative),
            "method": f"{method}:sdxl({_PAIR_BRANCH})",
        }
        return _finish_route(
            output,
            branch=_PAIR_BRANCH,
            input_hint=input_hint,
            target=target,
            style_preset=style_preset,
            quality_preset=quality_preset,
            nude_weight=nude_weight,
            trace_fields=trace_policy.metadata_trace_fields(row, target=target, selected_side=selected_side),
        )

    if row:
        body = ", ".join(deps.row_core_tags(row, nude_weight))
        extracted_negative = deps.clean(row.get("negative_prompt"))
        method = f"{method}:sdxl(metadata)"
        branch = "metadata"
    else:
        body, extracted_negative, method = deps.fallback_text_to_sdxl(
            request.source_text,
            request.preserve_trigger,
            nude_weight,
        )
        branch = "fallback"

    output = {
        "sdxl_prompt": _assemble_positive(body, request, deps, style_preset, quality_preset),
        "negative_prompt": _clean_negative(deps, extracted_negative, request.negative_prompt, request.extra_negative),
        # Pair-only fields stay empty so every branch returns the same keys.
        "sdxl_softcore_prompt": "",
        "sdxl_hardcore_prompt": "",
        "softcore_negative_prompt": "",
        "hardcore_negative_prompt": "",
        "method": method,
    }
    return _finish_route(
        output,
        branch=branch,
        input_hint=input_hint,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        nude_weight=nude_weight,
        trace_fields=trace_policy.metadata_trace_fields(row, target=target),
    )
def format_sdxl_prompt(request: SDXLFormatRequest, deps: SDXLFormatDependencies) -> dict[str, str]:
    """Return only the output mapping produced by format_sdxl_prompt_result.

    Convenience entry point for callers that do not need the routing details carried by
    SDXLFormatRoute (branch, method, resolved presets, clamped weight).
    """
    route = format_sdxl_prompt_result(request, deps)
    return route.output