Files
ComfyUI-Ethanfel-Prompt-Bui…/krea_format_route.py
T

180 lines
6.8 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
try:
from . import formatter_detail as detail_policy
from . import formatter_input as input_policy
from . import formatter_route_trace as trace_policy
from . import formatter_target as target_policy
except ImportError: # pragma: no cover - plain-script smoke tests
import formatter_detail as detail_policy
import formatter_input as input_policy
import formatter_route_trace as trace_policy
import formatter_target as target_policy
# Style modes understood by the Krea formatter. Anything else normalizes to
# DEFAULT_STYLE_MODE.
STYLE_MODES = ("preserve", "photographic", "minimal")
DEFAULT_STYLE_MODE = "preserve"


def style_mode_choices() -> list[str]:
    """Return a new list holding every supported style mode, in declaration order."""
    return [*STYLE_MODES]


def normalize_style_mode(value: Any) -> str:
    """Map an arbitrary user value onto one of ``STYLE_MODES``.

    ``None``/falsy values become the empty string. The text is trimmed and
    lower-cased, and dashes and spaces are turned into underscores. If the
    result is not a known mode, ``DEFAULT_STYLE_MODE`` is returned.
    """
    candidate = str(value or "").strip().lower()
    for separator in ("-", " "):
        candidate = candidate.replace(separator, "_")
    if candidate in STYLE_MODES:
        return candidate
    return DEFAULT_STYLE_MODE
@dataclass(frozen=True)
class KreaFormatRequest:
    """Immutable bundle of user options for one Krea2 formatting call."""

    source_text: str  # sent to deps.row_from_inputs; reused by the text fallback when no row is parsed
    metadata_json: str = ""  # sent to deps.row_from_inputs alongside source_text
    negative_prompt: str = ""  # merged only into the selected "negative_prompt" output
    input_hint: str = "auto"  # passed raw to row_from_inputs; normalized copy goes into the route trace
    target: str = "auto"  # normalized via formatter_target; chooses the pair side for pair metadata
    detail_level: str = "balanced"  # normalized via formatter_detail before use
    style_mode: str = "preserve"  # normalized via normalize_style_mode before use
    preserve_trigger: bool = False  # forwarded only to deps.fallback_text_to_krea
    extra_positive: str = ""  # appended to each generated prompt before prose sanitizing
    extra_negative: str = ""  # merged into each generated negative prompt
@dataclass(frozen=True)
class KreaFormatRoute:
    """Result of a Krea2 formatting call: the output mapping plus how it was produced."""

    output: dict[str, str]  # the mapping that format_krea2_prompt hands back to callers
    branch: str  # "insta_of_pair", "fallback", or the kind reported by normal_row_to_krea
    method: str  # identical to output["method"]
    target: str  # target after formatter_target normalization
    detail_level: str  # detail level after formatter_detail normalization
    style_mode: str  # style mode after normalize_style_mode
@dataclass(frozen=True)
class KreaFormatDependencies:
    """Injected helpers the Krea2 router calls; call shapes below mirror the router's usage."""

    trigger_candidates: tuple[str, ...]  # forwarded as triggers= to sanitize_prose_text
    clean: Callable[[Any], str]  # applied to row.get("negative_prompt") for normal rows
    row_from_inputs: Callable[[str, str, str], tuple[dict[str, Any] | None, str]]  # (source_text, metadata_json, input_hint) -> (row or None, method)
    normal_row_to_krea: Callable[[dict[str, Any], str, str], tuple[str, str]]  # (row, detail_level, style_mode) -> (prompt, kind)
    insta_pair_to_krea: Callable[[dict[str, Any], str, str], tuple[str, str, str, str]]  # (row, detail_level, style_mode) -> (soft prompt, soft negative, hard prompt, hard negative)
    fallback_text_to_krea: Callable[[str, bool, str, str], tuple[str, str, str]]  # (source_text, preserve_trigger, detail_level, style_mode) -> (prompt, negative, method)
    append_formatter_hints: Callable[..., str]  # called as (prompt, row) or (prompt, row, side_row)
    combine_negative: Callable[..., str]  # called with 2 or 3 negative fragments; presumably joins them
    sanitize_prose_text: Callable[..., str]  # called as (text, triggers=trigger_candidates)
    sanitize_negative_text: Callable[[str], str]  # final pass over every combined negative
def _append_extra_positive(prompt: str, extra_positive: str) -> str:
    """Return *prompt* with the stripped *extra_positive* text appended.

    A blank or whitespace-only *extra_positive* returns *prompt* unchanged.
    A blank *prompt* yields the extra text alone, rather than the extra text
    prefixed by a stray space.
    """
    extra = extra_positive.strip()
    if not extra:
        return prompt
    base = prompt.rstrip()
    return f"{base} {extra}" if base else extra


def _finish_prose(prompt: str, extra_positive: str, deps: KreaFormatDependencies) -> str:
    """Append the user's extra positive text, then sanitize with the trigger candidates."""
    return deps.sanitize_prose_text(
        _append_extra_positive(prompt, extra_positive),
        triggers=deps.trigger_candidates,
    )


def _build_route(
    output: dict[str, str],
    branch: str,
    target: str,
    detail_level: str,
    style_mode: str,
) -> KreaFormatRoute:
    """Wrap a finished *output* mapping in a KreaFormatRoute (``method`` mirrors ``output["method"]``)."""
    return KreaFormatRoute(
        output=output,
        branch=branch,
        method=output["method"],
        target=target,
        detail_level=detail_level,
        style_mode=style_mode,
    )


def format_krea2_prompt_result(request: KreaFormatRequest, deps: KreaFormatDependencies) -> KreaFormatRoute:
    """Route *request* to the matching Krea2 formatter and return the result plus routing info.

    Three branches are possible:

    * ``"insta_of_pair"`` -- the parsed row is pair metadata
      (``input_policy.is_pair_metadata``). Both softcore and hardcore prompts
      are built, and ``target_policy.pair_policy(target).selected_side``
      decides which one is exposed as ``krea_prompt`` / ``negative_prompt``.
    * the kind reported by ``deps.normal_row_to_krea`` -- any other non-empty row.
    * ``"fallback"`` -- no usable row; the raw source text goes through
      ``deps.fallback_text_to_krea``.

    Args:
        request: User options. Detail level, style mode and target are
            normalized here before use. The input hint is normalized only
            for the route trace.
        deps: Injected formatting helpers.

    Returns:
        A KreaFormatRoute. Its ``output`` always has the keys ``krea_prompt``,
        ``negative_prompt``, ``krea_softcore_prompt``, ``krea_hardcore_prompt``,
        ``softcore_negative_prompt``, ``hardcore_negative_prompt``, ``method``
        and ``route_trace_json``. The per-side keys are empty strings outside
        the pair branch.
    """
    detail_level = detail_policy.normalize_detail_level(request.detail_level)
    style_mode = normalize_style_mode(request.style_mode)
    target = target_policy.normalize_target(request.target)
    input_hint = input_policy.normalize_input_hint(request.input_hint, text_hint=input_policy.INPUT_HINT_PROMPT)
    # NOTE(review): row_from_inputs receives the *raw* request.input_hint; the
    # normalized input_hint above is only recorded in the route trace. Confirm
    # that row_from_inputs normalizes the hint itself.
    row, method = deps.row_from_inputs(request.source_text, request.metadata_json, request.input_hint)

    if row and input_policy.is_pair_metadata(row):
        pair_target = target_policy.pair_policy(target)
        soft_prompt, soft_negative, hard_prompt, hard_negative = deps.insta_pair_to_krea(
            row,
            detail_level,
            style_mode,
        )
        # Side rows are optional; any non-dict value is treated as an empty row.
        soft_row = row.get("softcore_row")
        hard_row = row.get("hardcore_row")
        soft_prompt = deps.append_formatter_hints(soft_prompt, row, soft_row if isinstance(soft_row, dict) else {})
        hard_prompt = deps.append_formatter_hints(hard_prompt, row, hard_row if isinstance(hard_row, dict) else {})
        soft_prompt = _finish_prose(soft_prompt, request.extra_positive, deps)
        hard_prompt = _finish_prose(hard_prompt, request.extra_positive, deps)

        hardcore_selected = pair_target.selected_side == "hardcore"
        selected = hard_prompt if hardcore_selected else soft_prompt
        selected_negative = hard_negative if hardcore_selected else soft_negative
        # NOTE(review): request.negative_prompt is merged into the selected
        # negative only, not into the per-side negatives below. Confirm intended.
        negative = deps.sanitize_negative_text(
            deps.combine_negative(selected_negative, request.negative_prompt, request.extra_negative)
        )
        output = {
            "krea_prompt": selected,
            "negative_prompt": negative,
            "krea_softcore_prompt": soft_prompt,
            "krea_hardcore_prompt": hard_prompt,
            "softcore_negative_prompt": deps.sanitize_negative_text(
                deps.combine_negative(soft_negative, request.extra_negative)
            ),
            "hardcore_negative_prompt": deps.sanitize_negative_text(
                deps.combine_negative(hard_negative, request.extra_negative)
            ),
            "method": f"{method}:krea2(insta_of_pair)",
        }
        output["route_trace_json"] = trace_policy.route_trace_json(
            formatter="krea2",
            branch="insta_of_pair",
            method=output["method"],
            input_hint=input_hint,
            target=target,
            selected_side=pair_target.selected_side,
            detail_level=detail_level,
            style_mode=style_mode,
        )
        return _build_route(output, "insta_of_pair", target, detail_level, style_mode)

    if row:
        prompt, kind = deps.normal_row_to_krea(row, detail_level, style_mode)
        prompt = deps.append_formatter_hints(prompt, row)
        extracted_negative = deps.clean(row.get("negative_prompt"))
        method = f"{method}:krea2({kind})"
        branch = kind
    else:
        # No parsed row: the fallback formatter supplies its own method string.
        prompt, extracted_negative, method = deps.fallback_text_to_krea(
            request.source_text,
            request.preserve_trigger,
            detail_level,
            style_mode,
        )
        branch = "fallback"

    prompt = _finish_prose(prompt, request.extra_positive, deps)
    negative = deps.sanitize_negative_text(
        deps.combine_negative(extracted_negative, request.negative_prompt, request.extra_negative)
    )
    output = {
        "krea_prompt": prompt,
        "negative_prompt": negative,
        "krea_softcore_prompt": "",
        "krea_hardcore_prompt": "",
        "softcore_negative_prompt": "",
        "hardcore_negative_prompt": "",
        "method": method,
    }
    output["route_trace_json"] = trace_policy.route_trace_json(
        formatter="krea2",
        branch=branch,
        method=method,
        input_hint=input_hint,
        target=target,
        detail_level=detail_level,
        style_mode=style_mode,
    )
    return _build_route(output, branch, target, detail_level, style_mode)
def format_krea2_prompt(request: KreaFormatRequest, deps: KreaFormatDependencies) -> dict[str, str]:
    """Format *request* for Krea2 and return only the output mapping.

    Convenience wrapper around ``format_krea2_prompt_result`` for callers that
    do not need the routing metadata (branch and normalized target/detail/style).
    """
    route = format_krea2_prompt_result(request, deps)
    return route.output