179 lines
6.2 KiB
Python
179 lines
6.2 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable
|
|
|
|
try:
|
|
from . import formatter_input as input_policy
|
|
from . import formatter_target as target_policy
|
|
except ImportError: # pragma: no cover - plain-script smoke tests
|
|
import formatter_input as input_policy
|
|
import formatter_target as target_policy
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLFormatRequest:
    """Immutable bundle of user inputs for one SDXL prompt-formatting call.

    Consumed by ``format_sdxl_prompt_result`` in this module; the comments on
    each field describe how that function uses it.
    """

    # Raw input text; passed to ``row_from_inputs`` and, when no row is parsed,
    # to ``fallback_text_to_sdxl``.
    source_text: str
    # Forwarded to ``row_from_inputs``; presumably a JSON string per the name
    # (NOTE(review): confirm the expected format with that helper).
    metadata_json: str = ""
    # User-supplied negative prompt; combined into the main ``negative_prompt``
    # output on every branch.
    negative_prompt: str = ""
    # Forwarded to ``row_from_inputs`` alongside the text and metadata.
    input_hint: str = "auto"
    # Normalized via ``target_policy.normalize_target``; for paired inputs the
    # normalized value is handed to ``target_policy.pair_policy`` to pick a side.
    target: str = "auto"
    # Preset names; both pass through ``apply_formatter_profile`` together with
    # ``formatter_profile`` before reaching ``assemble_prompt``.
    style_preset: str = "flat_vector_pony"
    quality_preset: str = "pony_high"
    # Trigger token and whether to prepend it; forwarded to ``assemble_prompt``.
    trigger: str = "mythp0rt"
    prepend_trigger: bool = True
    # Only read on the free-text fallback path (``fallback_text_to_sdxl``).
    preserve_trigger: bool = False
    # Weight handed to the tag builders; clamped to [0.1, 3.0] before use.
    nude_weight: float = 1.29
    # Extra styling inputs forwarded unchanged to ``assemble_prompt``.
    custom_style: str = ""
    custom_quality: str = ""
    extra_positive: str = ""
    # Passed to ``combine_negative`` for every negative prompt the formatter builds.
    extra_negative: str = ""
    # Profile name given to ``apply_formatter_profile`` with the two presets.
    formatter_profile: str = "manual_controls"
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLFormatRoute:
    """Result of ``format_sdxl_prompt_result``: the output mapping plus routing details."""

    # Mapping returned by ``format_sdxl_prompt``. Keys: sdxl_prompt,
    # negative_prompt, sdxl_softcore_prompt, sdxl_hardcore_prompt,
    # softcore_negative_prompt, hardcore_negative_prompt, method.
    output: dict[str, str]
    # Path that produced the output: "insta_of_pair", "metadata" or "fallback".
    branch: str
    # Same string as output["method"].
    method: str
    # Target after ``target_policy.normalize_target``.
    target: str
    # Presets after ``apply_formatter_profile`` has been applied.
    style_preset: str
    quality_preset: str
    # Request weight after clamping to [0.1, 3.0].
    nude_weight: float
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLFormatDependencies:
    """Helpers injected into ``format_sdxl_prompt_result``.

    This module only calls these; it does not define them. The comments below
    describe how each one is invoked here, not its internal behavior.
    """

    # Baseline negative prompt; always the first argument to ``combine_negative``.
    default_negative: str
    # (formatter_profile, style_preset, quality_preset) -> (style_preset, quality_preset).
    apply_formatter_profile: Callable[[str, str, str], tuple[str, str]]
    # Applied to a metadata row's "negative_prompt" value (which may be absent/None).
    clean: Callable[[Any], str]
    # (source_text, metadata_json, input_hint) -> (parsed row or None, method label).
    row_from_inputs: Callable[[str, str, str], tuple[dict[str, Any] | None, str]]
    # (row, nude_weight) -> tags, joined with ", " to form the metadata prompt body.
    row_core_tags: Callable[[dict[str, Any], float], list[str]]
    # (side_row, pair_row, nude_weight) -> prompt body for the softcore / hardcore side.
    soft_tags: Callable[[dict[str, Any], dict[str, Any], float], str]
    hard_tags: Callable[[dict[str, Any], dict[str, Any], float], str]
    # (source_text, preserve_trigger, nude_weight) -> (body, negative, method label);
    # used when no row could be parsed.
    fallback_text_to_sdxl: Callable[[str, bool, float], tuple[str, str, str]]
    # (body, style_preset, quality_preset, trigger, prepend_trigger,
    #  custom_style, custom_quality, extra_positive) -> final positive prompt.
    assemble_prompt: Callable[[str, str, str, str, bool, str, str, str], str]
    # Variadic: default_negative followed by further fragments; fragments read
    # with ``row.get`` may be None when the key is missing.
    combine_negative: Callable[..., str]
    # Applied to the result of every ``combine_negative`` call.
    sanitize_negative_text: Callable[[str], str]
|
|
|
|
|
|
_NUDE_WEIGHT_MIN = 0.1
_NUDE_WEIGHT_MAX = 3.0
# Used when the caller passes NaN; mirrors the SDXLFormatRequest.nude_weight default.
_NUDE_WEIGHT_NAN_FALLBACK = 1.29


def _clamp_nude_weight(value: Any) -> float:
    """Convert *value* with ``float()`` and clamp it to [0.1, 3.0].

    NaN compares False against every number, so ``max(lo, min(hi, nan))``
    would silently return the maximum; NaN maps to the default weight instead.
    Raises ValueError/TypeError when ``float()`` rejects *value*.
    """
    import math  # local import keeps this helper self-contained

    weight = float(value)
    if math.isnan(weight):
        return _NUDE_WEIGHT_NAN_FALLBACK
    return max(_NUDE_WEIGHT_MIN, min(_NUDE_WEIGHT_MAX, weight))


def _dict_field(row: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``row[key]`` when it is a dict, otherwise an empty dict."""
    value = row.get(key)
    return value if isinstance(value, dict) else {}


def format_sdxl_prompt_result(request: SDXLFormatRequest, deps: SDXLFormatDependencies) -> SDXLFormatRoute:
    """Route one formatting request and build the SDXL prompt outputs.

    Steps:
    1. Resolve style/quality presets through ``deps.apply_formatter_profile``.
    2. Normalize ``request.target`` and clamp ``request.nude_weight`` to
       [0.1, 3.0] (NaN falls back to 1.29).
    3. Parse the input with ``deps.row_from_inputs`` and pick a branch:
       - ``"insta_of_pair"`` when ``input_policy.is_pair_metadata(row)``:
         both softcore and hardcore prompts are built, and the side chosen by
         ``target_policy.pair_policy`` becomes ``sdxl_prompt``;
       - ``"metadata"`` when a (truthy) row was parsed: the body comes from
         ``deps.row_core_tags``;
       - ``"fallback"`` otherwise: the body comes from
         ``deps.fallback_text_to_sdxl``.

    Args:
        request: User inputs for this call.
        deps: Injected tag/prompt helpers.

    Returns:
        An ``SDXLFormatRoute`` whose ``output`` always has the keys
        ``sdxl_prompt``, ``negative_prompt``, ``sdxl_softcore_prompt``,
        ``sdxl_hardcore_prompt``, ``softcore_negative_prompt``,
        ``hardcore_negative_prompt`` and ``method``. The four per-side keys are
        empty strings outside the pair branch.

    Raises:
        ValueError, TypeError: if ``request.nude_weight`` is rejected by ``float()``.
    """
    style_preset, quality_preset = deps.apply_formatter_profile(
        request.formatter_profile,
        request.style_preset,
        request.quality_preset,
    )
    target = target_policy.normalize_target(request.target)
    nude_weight = _clamp_nude_weight(request.nude_weight)
    row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)

    def assemble(body: str) -> str:
        # Every prompt variant shares the same preset/trigger/extra arguments.
        return deps.assemble_prompt(
            body,
            style_preset,
            quality_preset,
            request.trigger,
            request.prepend_trigger,
            request.custom_style,
            request.custom_quality,
            request.extra_positive,
        )

    def combined_negative(*parts: Any) -> str:
        # Every negative prompt starts from the default and is sanitized last.
        return deps.sanitize_negative_text(deps.combine_negative(deps.default_negative, *parts))

    def route(output: dict[str, str], branch: str) -> SDXLFormatRoute:
        return SDXLFormatRoute(
            output=output,
            branch=branch,
            method=output["method"],
            target=target,
            style_preset=style_preset,
            quality_preset=quality_preset,
            nude_weight=nude_weight,
        )

    if row and input_policy.is_pair_metadata(row):
        pair_target = target_policy.pair_policy(target)
        # Build both bodies before assembling, matching the original call order.
        soft_body = deps.soft_tags(_dict_field(row, "softcore_row"), row, nude_weight)
        hard_body = deps.hard_tags(_dict_field(row, "hardcore_row"), row, nude_weight)
        soft_prompt = assemble(soft_body)
        hard_prompt = assemble(hard_body)
        soft_negative = row.get("softcore_negative_prompt")
        hard_negative = row.get("hardcore_negative_prompt")
        hardcore_selected = pair_target.selected_side == "hardcore"
        pair_output = {
            "sdxl_prompt": hard_prompt if hardcore_selected else soft_prompt,
            "negative_prompt": combined_negative(
                hard_negative if hardcore_selected else soft_negative,
                request.negative_prompt,
                request.extra_negative,
            ),
            "sdxl_softcore_prompt": soft_prompt,
            "sdxl_hardcore_prompt": hard_prompt,
            # NOTE(review): per-side negatives omit request.negative_prompt while
            # the selected "negative_prompt" includes it - confirm this is intended.
            "softcore_negative_prompt": combined_negative(soft_negative, request.extra_negative),
            "hardcore_negative_prompt": combined_negative(hard_negative, request.extra_negative),
            "method": f"{method}:sdxl(insta_of_pair)",
        }
        return route(pair_output, "insta_of_pair")

    if row:
        body = ", ".join(deps.row_core_tags(row, nude_weight))
        extracted_negative = deps.clean(row.get("negative_prompt"))
        method = f"{method}:sdxl(metadata)"
        branch = "metadata"
    else:
        body, extracted_negative, method = deps.fallback_text_to_sdxl(
            request.source_text,
            request.preserve_trigger,
            nude_weight,
        )
        branch = "fallback"

    single_output = {
        "sdxl_prompt": assemble(body),
        "negative_prompt": combined_negative(extracted_negative, request.negative_prompt, request.extra_negative),
        "sdxl_softcore_prompt": "",
        "sdxl_hardcore_prompt": "",
        "softcore_negative_prompt": "",
        "hardcore_negative_prompt": "",
        "method": method,
    }
    return route(single_output, branch)
|
|
|
|
|
|
def format_sdxl_prompt(request: SDXLFormatRequest, deps: SDXLFormatDependencies) -> dict[str, str]:
    """Format *request* and return only the output mapping.

    Convenience wrapper around ``format_sdxl_prompt_result``. Use it when the
    routing details (branch, resolved presets, clamped weight) are not needed.
    """
    route = format_sdxl_prompt_result(request, deps)
    return route.output
|