267 lines
8.2 KiB
Python
267 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import formatter_input as input_policy
|
|
from . import sdxl_format_route
|
|
from . import sdxl_tag_policy
|
|
from . import sdxl_tag_routes
|
|
from . import sdxl_presets as sdxl_policy
|
|
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
|
except ImportError: # Allows local smoke tests with `python -c`.
|
|
import formatter_input as input_policy
|
|
import sdxl_format_route
|
|
import sdxl_tag_policy
|
|
import sdxl_tag_routes
|
|
import sdxl_presets as sdxl_policy
|
|
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
|
|
|
|
|
# Trigger tokens handed to input_policy.strip_trigger_prefix by _strip_trigger.
# Kept as an immutable tuple; order is preserved exactly as listed.
TRIGGER_CANDIDATES = ("sxcpinup_coloredpencil", "sxcppnl7", "mythp0rt")
|
|
|
|
# Module-level aliases of the preset tables owned by sdxl_presets
# (imported as sdxl_policy). NOTE(review): presumably re-exported so existing
# callers can keep importing them from this module — confirm before removing.
SDXL_STYLE_PRESETS = sdxl_policy.SDXL_STYLE_PRESETS  # used by _style_prefix
SDXL_QUALITY_PRESETS = sdxl_policy.SDXL_QUALITY_PRESETS  # used by _quality_tail
SDXL_FORMATTER_PROFILES = sdxl_policy.SDXL_FORMATTER_PROFILES
SDXL_DEFAULT_NEGATIVE = sdxl_policy.SDXL_DEFAULT_NEGATIVE  # fed to SDXLFormatDependencies
SDXL_ACTION_FAMILY_TAGS = sdxl_policy.SDXL_ACTION_FAMILY_TAGS
SDXL_POSITION_FAMILY_TAGS = sdxl_policy.SDXL_POSITION_FAMILY_TAGS

# Evaluated once at import time; passed to input_policy by
# _strip_prompt_field_labels.
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
|
|
|
|
|
|
def sdxl_style_preset_choices() -> list[str]:
    """Return the style-preset choice list provided by ``sdxl_policy``."""
    choices = sdxl_policy.sdxl_style_preset_choices()
    return choices
|
|
|
|
|
|
def sdxl_quality_preset_choices() -> list[str]:
    """Return the quality-preset choice list provided by ``sdxl_policy``."""
    choices = sdxl_policy.sdxl_quality_preset_choices()
    return choices
|
|
|
|
|
|
def sdxl_formatter_profile_choices() -> list[str]:
    """Return the formatter-profile choice list provided by ``sdxl_policy``."""
    choices = sdxl_policy.sdxl_formatter_profile_choices()
    return choices
|
|
|
|
|
|
def _clean(value: Any) -> str:
    """Return ``input_policy.clean_text(value)``.

    Shared by _style_prefix and _quality_tail, and handed to
    ``SDXLFormatDependencies`` as its ``clean`` callable.
    """
    cleaned = input_policy.clean_text(value)
    return cleaned
|
|
|
|
|
|
def _maybe_json(text: str) -> dict[str, Any] | None:
    """Return ``input_policy.maybe_json(text)`` (a dict, or ``None``).

    NOTE(review): no caller is visible in this module section; presumably kept
    for external callers — verify before removing.
    """
    parse = input_policy.maybe_json
    return parse(text)
|
|
|
|
|
|
def _row_from_inputs(source_text: str, metadata_json: str, input_hint: str) -> tuple[dict[str, Any] | None, str]:
    """Delegate to ``input_policy.row_from_inputs``.

    Returns the ``(row_or_None, str)`` pair produced there. This function is
    handed to ``SDXLFormatDependencies`` as ``row_from_inputs``.
    """
    resolved = input_policy.row_from_inputs(source_text, metadata_json, input_hint)
    return resolved
|
|
|
|
|
|
def _strip_trigger(text: str, preserve_trigger: bool) -> str:
    """Run ``input_policy.strip_trigger_prefix`` with this module's trigger list.

    ``preserve_trigger`` is forwarded by keyword, unchanged.
    """
    candidates = TRIGGER_CANDIDATES
    stripped = input_policy.strip_trigger_prefix(
        text,
        candidates,
        preserve_trigger=preserve_trigger,
    )
    return stripped
|
|
|
|
|
|
def _split_avoid(text: str) -> tuple[str, str]:
    """Return ``input_policy.split_avoid(text)``.

    _fallback_text_to_sdxl unpacks the result as ``(positive, negative)``.
    """
    splitter = input_policy.split_avoid
    return splitter(text)
|
|
|
|
|
|
def _strip_prompt_field_labels(text: str) -> str:
    """Run ``input_policy.strip_prompt_field_labels`` with ``PROMPT_FIELD_LABELS``."""
    labels = PROMPT_FIELD_LABELS
    result = input_policy.strip_prompt_field_labels(text, field_labels=labels)
    return result
|
|
|
|
|
|
def _split_tag_text(text: Any) -> list[str]:
    """Return ``sdxl_tag_policy.split_tag_text(text)``.

    NOTE(review): no caller is visible in this module section.
    """
    pieces = sdxl_tag_policy.split_tag_text(text)
    return pieces
|
|
|
|
|
|
def _tag_key(tag: str) -> str:
    """Return ``sdxl_tag_policy.tag_key(tag)``.

    NOTE(review): no caller is visible in this module section.
    """
    key_of = sdxl_tag_policy.tag_key
    return key_of(tag)
|
|
|
|
|
|
def _add(tags: list[str], seen: set[str], value: Any) -> None:
    """Forward to ``sdxl_tag_policy.add`` and always return ``None``.

    NOTE(review): presumably mutates ``tags``/``seen`` in place — confirm in
    sdxl_tag_policy.
    """
    sdxl_tag_policy.add(tags, seen, value)
    return None
|
|
|
|
|
|
def _add_one(tags: list[str], seen: set[str], tag: str) -> None:
    """Forward to ``sdxl_tag_policy.add_one`` and always return ``None``.

    NOTE(review): presumably mutates ``tags``/``seen`` in place — confirm in
    sdxl_tag_policy.
    """
    sdxl_tag_policy.add_one(tags, seen, tag)
    return None
|
|
|
|
|
|
def _metadata_family_tags(row: dict[str, Any]) -> list[str]:
    """Return ``sdxl_tag_policy.metadata_family_tags(row)``."""
    family = sdxl_tag_policy.metadata_family_tags(row)
    return family
|
|
|
|
|
|
def _formatter_hint_tags(*rows: dict[str, Any]) -> list[str]:
    """Return ``sdxl_tag_policy.formatter_hint_tags`` applied to all ``rows``."""
    hints = sdxl_tag_policy.formatter_hint_tags(*rows)
    return hints
|
|
|
|
|
|
def _combine_tags(*parts: Any) -> str:
    """Return ``sdxl_tag_policy.combine_tags`` applied to all ``parts``.

    Used by _style_prefix, _assemble_prompt and _fallback_text_to_sdxl.
    """
    joiner = sdxl_tag_policy.combine_tags
    return joiner(*parts)
|
|
|
|
|
|
def _combine_negative(*parts: Any) -> str:
    """Return ``sdxl_tag_policy.combine_negative`` applied to all ``parts``.

    Handed to ``SDXLFormatDependencies`` as ``combine_negative``.
    """
    joiner = sdxl_tag_policy.combine_negative
    return joiner(*parts)
|
|
|
|
|
|
def _count_tag(women_count: int = 0, men_count: int = 0) -> list[str]:
    """Return ``sdxl_tag_policy.count_tag(women_count, men_count)``."""
    counted = sdxl_tag_policy.count_tag(women_count, men_count)
    return counted
|
|
|
|
|
|
def _infer_counts(row: dict[str, Any]) -> tuple[int, int]:
    """Return the integer pair from ``sdxl_tag_policy.infer_counts(row)``."""
    infer = sdxl_tag_policy.infer_counts
    return infer(row)
|
|
|
|
|
|
def _character_tags_from_descriptor(descriptor: Any) -> list[str]:
    """Return ``sdxl_tag_policy.character_tags_from_descriptor(descriptor)``."""
    derived = sdxl_tag_policy.character_tags_from_descriptor(descriptor)
    return derived
|
|
|
|
|
|
def _normal_character_tags(row: dict[str, Any]) -> list[str]:
    """Return ``sdxl_tag_policy.normal_character_tags(row)``."""
    derive = sdxl_tag_policy.normal_character_tags
    return derive(row)
|
|
|
|
|
|
def _camera_tags_from_config(config: Any) -> list[str]:
    """Return ``sdxl_tag_policy.camera_tags_from_config(config)``."""
    derived = sdxl_tag_policy.camera_tags_from_config(config)
    return derived
|
|
|
|
|
|
def _camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]:
    """Return ``sdxl_tag_policy.camera_tags(row, directive, config)``.

    The defaults (``""`` and ``None``) are forwarded positionally as-is.
    """
    derive = sdxl_tag_policy.camera_tags
    return derive(row, directive, config)
|
|
|
|
|
|
def _explicit_tags(text: str, nude_weight: float) -> list[str]:
    """Return ``sdxl_tag_policy.explicit_tags(text, nude_weight)``.

    _fallback_text_to_sdxl joins the result with ``", "``.
    """
    derived = sdxl_tag_policy.explicit_tags(text, nude_weight)
    return derived
|
|
|
|
|
|
def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
    """Return the tag-route dependency bundle from ``sdxl_tag_policy``.

    Called afresh on every _row_core_tags / _soft_tags / _hard_tags call;
    nothing is cached here.
    """
    bundle = sdxl_tag_policy.tag_route_dependencies()
    return bundle
|
|
|
|
|
|
def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
    """Compute core tags for a single row via ``sdxl_tag_routes.row_core_tags``.

    Wraps ``row`` and ``nude_weight`` in an ``SDXLRowTagRequest`` and pairs it
    with a fresh dependency bundle.
    """
    request = sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight)
    deps = _sdxl_tag_route_dependencies()
    return sdxl_tag_routes.row_core_tags(request, deps)
|
|
|
|
|
|
def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str:
    """Return the leading style text for a prompt, optionally followed by the trigger.

    Args:
        style_preset: Key into ``SDXL_STYLE_PRESETS``. Unknown keys fall back to
            ``SDXL_STYLE_PRESETS[sdxl_policy.DEFAULT_STYLE_PRESET]``.
        trigger: Trigger token; passed through ``_clean`` before use.
        prepend_trigger: When true and the cleaned trigger is non-empty, the
            result is ``_combine_tags(style, trigger)``.
        custom_style: Override text used instead of the preset when it is
            non-empty after ``_clean``.

    Returns:
        The style text, or the style combined with the cleaned trigger.
    """
    # Use the cleaned override, mirroring _quality_tail. The original returned
    # the raw ``custom_style`` string, so anything _clean would remove leaked
    # into the prompt whenever no trigger was combined with it.
    style = _clean(custom_style)
    if not style:
        # Resolve the default preset only when it is actually needed, so a
        # valid ``style_preset`` never depends on the default key existing.
        if style_preset in SDXL_STYLE_PRESETS:
            style = SDXL_STYLE_PRESETS[style_preset]
        else:
            style = SDXL_STYLE_PRESETS[sdxl_policy.DEFAULT_STYLE_PRESET]
    trigger = _clean(trigger)
    if prepend_trigger and trigger:
        return _combine_tags(style, trigger)
    return style
|
|
|
|
|
|
def _quality_tail(quality_preset: str, custom_quality: str) -> str:
    """Return the trailing quality text for a prompt.

    The cleaned ``custom_quality`` wins when it is truthy. Otherwise the
    ``quality_preset`` entry of ``SDXL_QUALITY_PRESETS`` is returned, falling
    back to the ``sdxl_policy.DEFAULT_QUALITY_PRESET`` entry for unknown keys.
    """
    override = _clean(custom_quality)
    if override:
        return override
    # The default entry is looked up only on the non-override path, matching
    # the short-circuit of the original ``or`` expression.
    fallback = SDXL_QUALITY_PRESETS[sdxl_policy.DEFAULT_QUALITY_PRESET]
    return SDXL_QUALITY_PRESETS.get(quality_preset, fallback)
|
|
|
|
|
|
def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Build the "soft" tag string via ``sdxl_tag_routes.soft_tags``.

    ``row``, ``root`` and ``nude_weight`` are wrapped in an
    ``SDXLPairTagRequest``. Handed to ``SDXLFormatDependencies`` as
    ``soft_tags``.
    """
    request = sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight)
    deps = _sdxl_tag_route_dependencies()
    return sdxl_tag_routes.soft_tags(request, deps)
|
|
|
|
|
|
def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Build the "hard" tag string via ``sdxl_tag_routes.hard_tags``.

    ``row``, ``root`` and ``nude_weight`` are wrapped in an
    ``SDXLPairTagRequest``. Handed to ``SDXLFormatDependencies`` as
    ``hard_tags``.
    """
    request = sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight)
    deps = _sdxl_tag_route_dependencies()
    return sdxl_tag_routes.hard_tags(request, deps)
|
|
|
|
|
|
def _assemble_prompt(
    body_tags: str,
    style_preset: str,
    quality_preset: str,
    trigger: str,
    prepend_trigger: bool,
    custom_style: str,
    custom_quality: str,
    extra_positive: str,
) -> str:
    """Assemble and sanitize the final positive prompt.

    Order of parts: style prefix (from _style_prefix), ``body_tags``, quality
    tail (from _quality_tail), then ``extra_positive``. The parts are joined
    with ``_combine_tags`` and passed through ``sanitize_tag_prompt``.

    Returns:
        The sanitized positive prompt string.
    """
    combined = _combine_tags(
        _style_prefix(style_preset, trigger, prepend_trigger, custom_style),
        body_tags,
        _quality_tail(quality_preset, custom_quality),
        extra_positive,
    )
    # _style_prefix inserts the *cleaned* trigger into the prompt, so hand the
    # sanitizer that same form. The raw value (e.g. with stray whitespace) may
    # not match the token actually present. An empty trigger is omitted rather
    # than passed as "".
    cleaned_trigger = _clean(trigger)
    protected = (cleaned_trigger,) if cleaned_trigger else ()
    return sanitize_tag_prompt(combined, triggers=protected)
|
|
|
|
|
|
def _fallback_text_to_sdxl(
    source_text: str,
    preserve_trigger: bool,
    nude_weight: float,
) -> tuple[str, str, str]:
    """Turn free text into ``(tags, negative, route_label)`` without metadata.

    Pipeline:
    1. Strip a leading trigger (unless ``preserve_trigger``).
    2. Split off the avoid/negative part.
    3. Remove prompt field labels from the positive part.
    4. Append the explicit tags derived from that part.

    The route label is always ``"text(fallback)"``.
    """
    untriggered = _strip_trigger(source_text, preserve_trigger)
    wanted, avoid = _split_avoid(untriggered)
    wanted = _strip_prompt_field_labels(wanted)
    explicit = ", ".join(_explicit_tags(wanted, nude_weight))
    tags = _combine_tags(wanted, explicit)
    return tags, avoid, "text(fallback)"
|
|
|
|
|
|
def _sdxl_format_dependencies() -> sdxl_format_route.SDXLFormatDependencies:
    """Bundle this module's helpers for ``sdxl_format_route.format_sdxl_prompt``.

    A new ``SDXLFormatDependencies`` is built on every call.
    """

    def apply_profile(profile, style, quality):
        # Adapter: the route calls this with (profile, style, quality);
        # sdxl_policy receives style/quality as keyword arguments.
        return sdxl_policy.apply_formatter_profile(
            profile,
            style_preset=style,
            quality_preset=quality,
        )

    return sdxl_format_route.SDXLFormatDependencies(
        default_negative=SDXL_DEFAULT_NEGATIVE,
        apply_formatter_profile=apply_profile,
        clean=_clean,
        row_from_inputs=_row_from_inputs,
        row_core_tags=_row_core_tags,
        soft_tags=_soft_tags,
        hard_tags=_hard_tags,
        fallback_text_to_sdxl=_fallback_text_to_sdxl,
        assemble_prompt=_assemble_prompt,
        combine_negative=_combine_negative,
        sanitize_negative_text=sanitize_negative_text,
    )
|
|
|
|
|
|
def format_sdxl_prompt(
    source_text: str,
    metadata_json: str = "",
    negative_prompt: str = "",
    input_hint: str = "auto",
    target: str = "auto",
    style_preset: str = "flat_vector_pony",
    quality_preset: str = "pony_high",
    trigger: str = "mythp0rt",
    prepend_trigger: bool = True,
    preserve_trigger: bool = False,
    nude_weight: float = 1.29,
    custom_style: str = "",
    custom_quality: str = "",
    extra_positive: str = "",
    extra_negative: str = "",
    formatter_profile: str = "manual_controls",
) -> dict[str, str]:
    """Public entry point: format input text into an SDXL prompt dictionary.

    Every argument is copied unchanged into the identically named field of an
    ``sdxl_format_route.SDXLFormatRequest``. The work is done by
    ``sdxl_format_route.format_sdxl_prompt`` using the helper bundle from
    ``_sdxl_format_dependencies()``.

    Returns:
        The ``dict[str, str]`` produced by the route. NOTE(review): its keys
        are defined in sdxl_format_route, not in this module.
    """
    request = sdxl_format_route.SDXLFormatRequest(
        source_text=source_text,
        metadata_json=metadata_json,
        negative_prompt=negative_prompt,
        input_hint=input_hint,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        trigger=trigger,
        prepend_trigger=prepend_trigger,
        preserve_trigger=preserve_trigger,
        nude_weight=nude_weight,
        custom_style=custom_style,
        custom_quality=custom_quality,
        extra_positive=extra_positive,
        extra_negative=extra_negative,
        formatter_profile=formatter_profile,
    )
    return sdxl_format_route.format_sdxl_prompt(request, _sdxl_format_dependencies())
|