# File: ComfyUI-Ethanfel-Prompt-Bui…/sdxl_formatter.py
# (viewer metadata: 267 lines, 8.2 KiB, Python)

from __future__ import annotations
from typing import Any
try:
from . import formatter_input as input_policy
from . import sdxl_format_route
from . import sdxl_tag_policy
from . import sdxl_tag_routes
from . import sdxl_presets as sdxl_policy
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
except ImportError: # Allows local smoke tests with `python -c`.
import formatter_input as input_policy
import sdxl_format_route
import sdxl_tag_policy
import sdxl_tag_routes
import sdxl_presets as sdxl_policy
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
# Trigger tokens recognised by ``_strip_trigger`` (passed to
# ``input_policy.strip_trigger_prefix``), in the order they are offered.
TRIGGER_CANDIDATES: tuple[str, ...] = ("sxcpinup_coloredpencil", "sxcppnl7", "mythp0rt")
# --- Re-exported policy tables -------------------------------------------
# These names are module-level aliases of the objects defined in
# ``sdxl_presets`` (imported as ``sdxl_policy``); they are the same objects,
# not copies, so this module keeps exposing them under its own namespace.
SDXL_STYLE_PRESETS = sdxl_policy.SDXL_STYLE_PRESETS
SDXL_QUALITY_PRESETS = sdxl_policy.SDXL_QUALITY_PRESETS
SDXL_FORMATTER_PROFILES = sdxl_policy.SDXL_FORMATTER_PROFILES
SDXL_DEFAULT_NEGATIVE = sdxl_policy.SDXL_DEFAULT_NEGATIVE
SDXL_ACTION_FAMILY_TAGS = sdxl_policy.SDXL_ACTION_FAMILY_TAGS
SDXL_POSITION_FAMILY_TAGS = sdxl_policy.SDXL_POSITION_FAMILY_TAGS

# Computed once at import time and reused by ``_strip_prompt_field_labels``.
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
def sdxl_style_preset_choices() -> list[str]:
    """Return the selectable style preset names (delegates to ``sdxl_policy``)."""
    choices = sdxl_policy.sdxl_style_preset_choices()
    return choices


def sdxl_quality_preset_choices() -> list[str]:
    """Return the selectable quality preset names (delegates to ``sdxl_policy``)."""
    choices = sdxl_policy.sdxl_quality_preset_choices()
    return choices


def sdxl_formatter_profile_choices() -> list[str]:
    """Return the selectable formatter profile names (delegates to ``sdxl_policy``)."""
    choices = sdxl_policy.sdxl_formatter_profile_choices()
    return choices
# --- Thin delegates to ``formatter_input`` (imported as ``input_policy``) ---


def _clean(value: Any) -> str:
    """Normalize *value* to text via ``input_policy.clean_text``."""
    cleaned = input_policy.clean_text(value)
    return cleaned


def _maybe_json(text: str) -> dict[str, Any] | None:
    """Try to read *text* as a JSON object via ``input_policy.maybe_json``.

    Per the annotation the result is a dict or ``None`` (presumably ``None``
    when *text* is not a JSON object — confirm in ``formatter_input``).
    """
    parsed = input_policy.maybe_json(text)
    return parsed


def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
    """Resolve a metadata row from the raw inputs via ``input_policy.row_from_inputs``.

    Returns ``(row_or_None, label)``; the meaning of the string element is
    defined by ``formatter_input`` (not visible here).
    """
    resolved = input_policy.row_from_inputs(source_text, metadata_json, input_hint)
    return resolved


def _strip_trigger(text: str, preserve_trigger: bool) -> str:
    """Run ``input_policy.strip_trigger_prefix`` with this module's ``TRIGGER_CANDIDATES``."""
    stripped = input_policy.strip_trigger_prefix(
        text,
        TRIGGER_CANDIDATES,
        preserve_trigger=preserve_trigger,
    )
    return stripped


def _split_avoid(text: str) -> tuple[str, str]:
    """Split *text* into two parts via ``input_policy.split_avoid``.

    Callers in this module unpack the result as ``(positive, negative)``.
    """
    halves = input_policy.split_avoid(text)
    return halves


def _strip_prompt_field_labels(text: str) -> str:
    """Remove prompt field labels (``PROMPT_FIELD_LABELS``) from *text* via ``input_policy``."""
    unlabeled = input_policy.strip_prompt_field_labels(text, field_labels=PROMPT_FIELD_LABELS)
    return unlabeled
# --- Thin delegates to ``sdxl_tag_policy``: tag-list primitives -------------


def _split_tag_text(text: Any) -> list[str]:
    """Split *text* into individual tags via ``sdxl_tag_policy.split_tag_text``."""
    pieces = sdxl_tag_policy.split_tag_text(text)
    return pieces


def _tag_key(tag: str) -> str:
    """Return the comparison key for *tag* via ``sdxl_tag_policy.tag_key``."""
    key = sdxl_tag_policy.tag_key(tag)
    return key


def _add(tags: list[str], seen: set[str], value: Any) -> None:
    """Forward *tags*, *seen* and *value* to ``sdxl_tag_policy.add``.

    Returns nothing; presumably *tags*/*seen* are updated in place — confirm
    in ``sdxl_tag_policy``.
    """
    sdxl_tag_policy.add(tags, seen, value)


def _add_one(tags: list[str], seen: set[str], tag: str) -> None:
    """Forward a single *tag* to ``sdxl_tag_policy.add_one`` (returns nothing)."""
    sdxl_tag_policy.add_one(tags, seen, tag)
# --- Thin delegates to ``sdxl_tag_policy``: hints and combination -----------


def _metadata_family_tags(row: dict[str, Any]) -> list[str]:
    """Return family tags derived from a metadata *row* (``sdxl_tag_policy``)."""
    family = sdxl_tag_policy.metadata_family_tags(row)
    return family


def _formatter_hint_tags(*rows: dict[str, Any]) -> list[str]:
    """Return formatter hint tags for all *rows* (``sdxl_tag_policy``)."""
    hints = sdxl_tag_policy.formatter_hint_tags(*rows)
    return hints


def _combine_tags(*parts: Any) -> str:
    """Merge positive tag *parts* into one string via ``sdxl_tag_policy.combine_tags``."""
    merged = sdxl_tag_policy.combine_tags(*parts)
    return merged


def _combine_negative(*parts: Any) -> str:
    """Merge negative prompt *parts* into one string via ``sdxl_tag_policy.combine_negative``."""
    merged = sdxl_tag_policy.combine_negative(*parts)
    return merged
# --- Thin delegates to ``sdxl_tag_policy``: people and characters -----------


def _count_tag(women_count: int = 0, men_count: int = 0) -> list[str]:
    """Return subject-count tags for the given counts (``sdxl_tag_policy.count_tag``)."""
    counted = sdxl_tag_policy.count_tag(women_count, men_count)
    return counted


def _infer_counts(row: dict[str, Any]) -> tuple[int, int]:
    """Infer a pair of counts from *row* via ``sdxl_tag_policy.infer_counts``.

    The pair order is presumably ``(women, men)`` to match ``_count_tag`` —
    confirm in ``sdxl_tag_policy``.
    """
    counts = sdxl_tag_policy.infer_counts(row)
    return counts


def _character_tags_from_descriptor(descriptor: Any) -> list[str]:
    """Return character tags for a single *descriptor* (``sdxl_tag_policy``)."""
    character = sdxl_tag_policy.character_tags_from_descriptor(descriptor)
    return character


def _normal_character_tags(row: dict[str, Any]) -> list[str]:
    """Return character tags for a metadata *row* (``sdxl_tag_policy``)."""
    character = sdxl_tag_policy.normal_character_tags(row)
    return character
# --- Thin delegates to ``sdxl_tag_policy``: camera and explicit tags --------


def _camera_tags_from_config(config: Any) -> list[str]:
    """Return camera tags for a camera *config* (``sdxl_tag_policy``)."""
    camera = sdxl_tag_policy.camera_tags_from_config(config)
    return camera


def _camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]:
    """Return camera tags for *row*, optionally steered by *directive*/*config*.

    All three arguments are forwarded positionally to
    ``sdxl_tag_policy.camera_tags``.
    """
    camera = sdxl_tag_policy.camera_tags(row, directive, config)
    return camera


def _explicit_tags(text: str, nude_weight: float) -> list[str]:
    """Return explicit-content tags for *text* weighted by *nude_weight* (``sdxl_tag_policy``)."""
    explicit = sdxl_tag_policy.explicit_tags(text, nude_weight)
    return explicit
def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
    """Build the dependency bundle consumed by the ``sdxl_tag_routes`` entry points.

    A fresh bundle is requested from ``sdxl_tag_policy`` on every call.
    """
    bundle = sdxl_tag_policy.tag_route_dependencies()
    return bundle
def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
    """Return the core tag list for a single metadata *row*.

    Wraps *row* and *nude_weight* in an ``SDXLRowTagRequest`` and runs it
    through ``sdxl_tag_routes.row_core_tags`` with the tag route dependencies.
    """
    request = sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight)
    dependencies = _sdxl_tag_route_dependencies()
    return sdxl_tag_routes.row_core_tags(request, dependencies)
def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str:
    """Build the leading style segment of an SDXL positive prompt.

    Args:
        style_preset: Key into ``SDXL_STYLE_PRESETS``; an unknown key falls
            back to the ``sdxl_policy.DEFAULT_STYLE_PRESET`` entry.
        trigger: Trigger token; it is cleaned with ``_clean`` before use.
        prepend_trigger: When true and the cleaned trigger is non-empty, the
            trigger is combined with the style via ``_combine_tags``
            (style first, trigger second).
        custom_style: Free-form style text. When it is non-empty after
            cleaning, it replaces the preset lookup entirely.

    Returns:
        The style text, optionally combined with the trigger.
    """
    # Use the *cleaned* custom style, mirroring ``_quality_tail``. Previously
    # the cleaned value was only used as a truthiness test, the raw argument
    # was returned, and ``_clean`` ran twice on the same input.
    style = _clean(custom_style) or SDXL_STYLE_PRESETS.get(
        style_preset,
        SDXL_STYLE_PRESETS[sdxl_policy.DEFAULT_STYLE_PRESET],
    )
    cleaned_trigger = _clean(trigger)
    if prepend_trigger and cleaned_trigger:
        return _combine_tags(style, cleaned_trigger)
    return style
def _quality_tail(quality_preset: str, custom_quality: str) -> str:
    """Return the trailing quality segment of an SDXL positive prompt.

    A non-empty cleaned *custom_quality* wins. Otherwise *quality_preset* is
    looked up in ``SDXL_QUALITY_PRESETS``, and an unknown key falls back to the
    ``sdxl_policy.DEFAULT_QUALITY_PRESET`` entry.
    """
    custom = _clean(custom_quality)
    if custom:
        return custom
    # Only evaluated when no custom text was supplied (same as the original
    # ``or`` short-circuit).
    fallback = SDXL_QUALITY_PRESETS[sdxl_policy.DEFAULT_QUALITY_PRESET]
    return SDXL_QUALITY_PRESETS.get(quality_preset, fallback)
def _run_pair_tag_route(route: Any, row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Wrap (*row*, *root*, *nude_weight*) in an ``SDXLPairTagRequest`` and run *route* on it."""
    request = sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight)
    return route(request, _sdxl_tag_route_dependencies())


def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Return the "soft" tag string for a row/root pair via ``sdxl_tag_routes.soft_tags``."""
    return _run_pair_tag_route(sdxl_tag_routes.soft_tags, row, root, nude_weight)


def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Return the "hard" tag string for a row/root pair via ``sdxl_tag_routes.hard_tags``."""
    return _run_pair_tag_route(sdxl_tag_routes.hard_tags, row, root, nude_weight)
def _assemble_prompt(
    body_tags: str,
    style_preset: str,
    quality_preset: str,
    trigger: str,
    prepend_trigger: bool,
    custom_style: str,
    custom_quality: str,
    extra_positive: str,
) -> str:
    """Assemble and sanitize the final SDXL positive prompt.

    The segments are combined in this order: style prefix (see
    ``_style_prefix``), *body_tags*, quality tail (see ``_quality_tail``),
    *extra_positive*. The result is passed through ``sanitize_tag_prompt``
    with *trigger* supplied as the only entry of ``triggers``.
    """
    prefix = _style_prefix(style_preset, trigger, prepend_trigger, custom_style)
    tail = _quality_tail(quality_preset, custom_quality)
    combined = _combine_tags(prefix, body_tags, tail, extra_positive)
    # NOTE(review): the raw (uncleaned) trigger is forwarded even when
    # prepend_trigger is False or the trigger is blank; confirm that
    # sanitize_tag_prompt tolerates that.
    return sanitize_tag_prompt(combined, triggers=(trigger,))
def _fallback_text_to_sdxl(
    source_text: str,
    preserve_trigger: bool,
    nude_weight: float,
) -> tuple[str, str, str]:
    """Convert plain source text to SDXL tags when no metadata row is used.

    Steps: strip the trigger (subject to *preserve_trigger*), split the text
    into positive/negative halves, remove prompt field labels from the
    positive half, and then append explicit tags derived from that half.

    Returns:
        ``(positive_tags, negative_text, "text(fallback)")``. The last element
        is a fixed label.
    """
    untriggered = _strip_trigger(source_text, preserve_trigger)
    positive_text, negative_text = _split_avoid(untriggered)
    positive_text = _strip_prompt_field_labels(positive_text)
    explicit = ", ".join(_explicit_tags(positive_text, nude_weight))
    return _combine_tags(positive_text, explicit), negative_text, "text(fallback)"
def _sdxl_format_dependencies() -> sdxl_format_route.SDXLFormatDependencies:
    """Bundle this module's helpers into the object consumed by ``sdxl_format_route``."""

    def apply_profile(profile, style, quality):
        # Adapt the positional (profile, style, quality) callback shape to the
        # keyword arguments used by ``sdxl_policy.apply_formatter_profile``.
        return sdxl_policy.apply_formatter_profile(
            profile,
            style_preset=style,
            quality_preset=quality,
        )

    return sdxl_format_route.SDXLFormatDependencies(
        default_negative=SDXL_DEFAULT_NEGATIVE,
        apply_formatter_profile=apply_profile,
        clean=_clean,
        row_from_inputs=_row_from_inputs,
        row_core_tags=_row_core_tags,
        soft_tags=_soft_tags,
        hard_tags=_hard_tags,
        fallback_text_to_sdxl=_fallback_text_to_sdxl,
        assemble_prompt=_assemble_prompt,
        combine_negative=_combine_negative,
        sanitize_negative_text=sanitize_negative_text,
    )
def format_sdxl_prompt(
    source_text: str,
    metadata_json: str = "",
    negative_prompt: str = "",
    input_hint: str = "auto",
    target: str = "auto",
    style_preset: str = "flat_vector_pony",
    quality_preset: str = "pony_high",
    trigger: str = "mythp0rt",
    prepend_trigger: bool = True,
    preserve_trigger: bool = False,
    nude_weight: float = 1.29,
    custom_style: str = "",
    custom_quality: str = "",
    extra_positive: str = "",
    extra_negative: str = "",
    formatter_profile: str = "manual_controls",
) -> dict[str, str]:
    """Format source text (and optional metadata JSON) into SDXL prompts.

    This is the public entry point. It packs every argument into an
    ``SDXLFormatRequest`` and hands it to ``sdxl_format_route.format_sdxl_prompt``
    together with this module's dependency bundle. The keys of the returned
    ``dict[str, str]`` are defined by ``sdxl_format_route`` (not visible here).

    NOTE(review): the preset defaults are hard-coded strings; confirm they
    stay in sync with ``sdxl_policy.DEFAULT_STYLE_PRESET`` /
    ``DEFAULT_QUALITY_PRESET``.
    """
    request = sdxl_format_route.SDXLFormatRequest(
        source_text=source_text,
        metadata_json=metadata_json,
        negative_prompt=negative_prompt,
        input_hint=input_hint,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        trigger=trigger,
        prepend_trigger=prepend_trigger,
        preserve_trigger=preserve_trigger,
        nude_weight=nude_weight,
        custom_style=custom_style,
        custom_quality=custom_quality,
        extra_positive=extra_positive,
        extra_negative=extra_negative,
        formatter_profile=formatter_profile,
    )
    dependencies = _sdxl_format_dependencies()
    return sdxl_format_route.format_sdxl_prompt(request, dependencies)