"""SDXL prompt formatter facade.

Re-exports the preset tables from ``sdxl_presets`` and wires this module's
small helper wrappers into ``sdxl_format_route``, which performs the actual
formatting in ``format_sdxl_prompt``.
"""

from __future__ import annotations

from typing import Any

try:
    from . import formatter_input as input_policy
    from . import sdxl_format_route
    from . import sdxl_tag_policy
    from . import sdxl_tag_routes
    from . import sdxl_presets as sdxl_policy
    from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
except ImportError:
    # Fall back to flat imports so the module also loads outside its package,
    # e.g. for local smoke tests with `python -c`.
    import formatter_input as input_policy
    import sdxl_format_route
    import sdxl_tag_policy
    import sdxl_tag_routes
    import sdxl_presets as sdxl_policy
    from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt


# Trigger tokens handed to `formatter_input.strip_trigger_prefix` by
# `_strip_trigger`.
TRIGGER_CANDIDATES = (
    "sxcpinup_coloredpencil",
    "sxcppnl7",
    "mythp0rt",
)

# Preset tables re-exported from `sdxl_presets` under this module's names.
SDXL_STYLE_PRESETS = sdxl_policy.SDXL_STYLE_PRESETS
SDXL_QUALITY_PRESETS = sdxl_policy.SDXL_QUALITY_PRESETS
SDXL_FORMATTER_PROFILES = sdxl_policy.SDXL_FORMATTER_PROFILES
SDXL_DEFAULT_NEGATIVE = sdxl_policy.SDXL_DEFAULT_NEGATIVE
SDXL_ACTION_FAMILY_TAGS = sdxl_policy.SDXL_ACTION_FAMILY_TAGS
SDXL_POSITION_FAMILY_TAGS = sdxl_policy.SDXL_POSITION_FAMILY_TAGS

# Field labels computed once at import time and reused by
# `_strip_prompt_field_labels`.
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()


def sdxl_style_preset_choices() -> list[str]:
    """Return the style preset choices provided by `sdxl_presets`."""
    return sdxl_policy.sdxl_style_preset_choices()


def sdxl_quality_preset_choices() -> list[str]:
    """Return the quality preset choices provided by `sdxl_presets`."""
    return sdxl_policy.sdxl_quality_preset_choices()


def sdxl_formatter_profile_choices() -> list[str]:
    """Return the formatter profile choices provided by `sdxl_presets`."""
    return sdxl_policy.sdxl_formatter_profile_choices()


# --- Input helpers (delegate to `formatter_input`) -------------------------


def _clean(value: Any) -> str:
    """Return ``value`` as cleaned text via `formatter_input.clean_text`."""
    return input_policy.clean_text(value)


def _maybe_json(text: str) -> dict[str, Any] | None:
    """Return ``text`` parsed as a JSON object, or None (`formatter_input.maybe_json`)."""
    return input_policy.maybe_json(text)


def _row_from_inputs(
    source_text: str,
    metadata_json: str,
    input_hint: str,
) -> tuple[dict[str, Any] | None, str]:
    """Return the ``(row, label)`` pair from `formatter_input.row_from_inputs`.

    NOTE(review): the meaning of the string element is defined by
    `formatter_input`; confirm before relying on it.
    """
    return input_policy.row_from_inputs(source_text, metadata_json, input_hint)


def _strip_trigger(text: str, preserve_trigger: bool) -> str:
    """Strip a leading trigger from ``TRIGGER_CANDIDATES`` unless preserving it."""
    return input_policy.strip_trigger_prefix(
        text,
        TRIGGER_CANDIDATES,
        preserve_trigger=preserve_trigger,
    )


def _split_avoid(text: str) -> tuple[str, str]:
    """Split ``text`` into ``(positive, negative)`` via `formatter_input.split_avoid`."""
    return input_policy.split_avoid(text)


def _strip_prompt_field_labels(text: str) -> str:
    """Remove ``PROMPT_FIELD_LABELS`` from ``text`` via `formatter_input`."""
    return input_policy.strip_prompt_field_labels(text, field_labels=PROMPT_FIELD_LABELS)


# --- Tag helpers (delegate to `sdxl_tag_policy`) ---------------------------


def _split_tag_text(text: Any) -> list[str]:
    """Split tag text into a list of tags (`sdxl_tag_policy.split_tag_text`)."""
    return sdxl_tag_policy.split_tag_text(text)


def _tag_key(tag: str) -> str:
    """Return the key used to compare tags (`sdxl_tag_policy.tag_key`)."""
    return sdxl_tag_policy.tag_key(tag)


def _add(tags: list[str], seen: set[str], value: Any) -> None:
    """Add ``value`` to ``tags``, tracking ``seen`` (`sdxl_tag_policy.add`)."""
    sdxl_tag_policy.add(tags, seen, value)


def _add_one(tags: list[str], seen: set[str], tag: str) -> None:
    """Add a single ``tag`` to ``tags``, tracking ``seen`` (`sdxl_tag_policy.add_one`)."""
    sdxl_tag_policy.add_one(tags, seen, tag)


def _metadata_family_tags(row: dict[str, Any]) -> list[str]:
    """Return family tags for ``row`` (`sdxl_tag_policy.metadata_family_tags`)."""
    return sdxl_tag_policy.metadata_family_tags(row)


def _formatter_hint_tags(*rows: dict[str, Any]) -> list[str]:
    """Return formatter hint tags for ``rows`` (`sdxl_tag_policy.formatter_hint_tags`)."""
    return sdxl_tag_policy.formatter_hint_tags(*rows)


def _combine_tags(*parts: Any) -> str:
    """Combine tag fragments into one tag string (`sdxl_tag_policy.combine_tags`)."""
    return sdxl_tag_policy.combine_tags(*parts)


def _combine_negative(*parts: Any) -> str:
    """Combine negative-prompt fragments (`sdxl_tag_policy.combine_negative`)."""
    return sdxl_tag_policy.combine_negative(*parts)


def _count_tag(women_count: int = 0, men_count: int = 0) -> list[str]:
    """Return subject-count tags (`sdxl_tag_policy.count_tag`)."""
    return sdxl_tag_policy.count_tag(women_count, men_count)


def _infer_counts(row: dict[str, Any]) -> tuple[int, int]:
    """Return the counts inferred from ``row`` (`sdxl_tag_policy.infer_counts`)."""
    return sdxl_tag_policy.infer_counts(row)


def _character_tags_from_descriptor(descriptor: Any) -> list[str]:
    """Return character tags for ``descriptor`` (`sdxl_tag_policy`)."""
    return sdxl_tag_policy.character_tags_from_descriptor(descriptor)


def _normal_character_tags(row: dict[str, Any]) -> list[str]:
    """Return character tags for ``row`` (`sdxl_tag_policy.normal_character_tags`)."""
    return sdxl_tag_policy.normal_character_tags(row)


def _camera_tags_from_config(config: Any) -> list[str]:
    """Return camera tags for ``config`` (`sdxl_tag_policy.camera_tags_from_config`)."""
    return sdxl_tag_policy.camera_tags_from_config(config)


def _camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]:
    """Return camera tags for ``row`` (`sdxl_tag_policy.camera_tags`)."""
    return sdxl_tag_policy.camera_tags(row, directive, config)


def _explicit_tags(text: str, nude_weight: float) -> list[str]:
    """Return explicit tags derived from ``text`` (`sdxl_tag_policy.explicit_tags`)."""
    return sdxl_tag_policy.explicit_tags(text, nude_weight)


# --- Tag routes (delegate to `sdxl_tag_routes`) ----------------------------


def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
    """Return the dependency bundle that `sdxl_tag_routes` functions expect."""
    return sdxl_tag_policy.tag_route_dependencies()


def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
    """Return the core tags for a single ``row`` via `sdxl_tag_routes.row_core_tags`."""
    return sdxl_tag_routes.row_core_tags(
        sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight),
        _sdxl_tag_route_dependencies(),
    )


def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Return the soft-route tag string for a ``(row, root)`` pair."""
    return sdxl_tag_routes.soft_tags(
        sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
        _sdxl_tag_route_dependencies(),
    )


def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
    """Return the hard-route tag string for a ``(row, root)`` pair."""
    return sdxl_tag_routes.hard_tags(
        sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
        _sdxl_tag_route_dependencies(),
    )


# --- Prompt assembly -------------------------------------------------------


def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str:
    """Return the leading style tags, optionally followed by the trigger.

    A non-blank ``custom_style`` (after cleaning) overrides ``style_preset``.
    An unknown preset name falls back to ``sdxl_presets.DEFAULT_STYLE_PRESET``.
    The cleaned trigger is combined after the style only when
    ``prepend_trigger`` is true and the trigger is non-blank.
    """
    # Return the *cleaned* custom style, mirroring `_quality_tail`. The old
    # code tested the cleaned value but then returned the raw input.
    style = _clean(custom_style) or SDXL_STYLE_PRESETS.get(
        style_preset,
        SDXL_STYLE_PRESETS[sdxl_policy.DEFAULT_STYLE_PRESET],
    )
    trigger = _clean(trigger)
    if prepend_trigger and trigger:
        return _combine_tags(style, trigger)
    return style


def _quality_tail(quality_preset: str, custom_quality: str) -> str:
    """Return the trailing quality tags.

    A non-blank ``custom_quality`` (after cleaning) overrides ``quality_preset``.
    An unknown preset name falls back to ``sdxl_presets.DEFAULT_QUALITY_PRESET``.
    """
    return _clean(custom_quality) or SDXL_QUALITY_PRESETS.get(
        quality_preset,
        SDXL_QUALITY_PRESETS[sdxl_policy.DEFAULT_QUALITY_PRESET],
    )


def _assemble_prompt(
    body_tags: str,
    style_preset: str,
    quality_preset: str,
    trigger: str,
    prepend_trigger: bool,
    custom_style: str,
    custom_quality: str,
    extra_positive: str,
) -> str:
    """Build the final positive prompt.

    Combines, in order, the style prefix, ``body_tags``, the quality tail and
    ``extra_positive``. The result is passed through `sanitize_tag_prompt`,
    with the trigger forwarded via ``triggers=``.
    """
    combined = _combine_tags(
        _style_prefix(style_preset, trigger, prepend_trigger, custom_style),
        body_tags,
        _quality_tail(quality_preset, custom_quality),
        extra_positive,
    )
    return sanitize_tag_prompt(combined, triggers=(trigger,))


def _fallback_text_to_sdxl(
    source_text: str,
    preserve_trigger: bool,
    nude_weight: float,
) -> tuple[str, str, str]:
    """Convert plain source text into ``(tags, negative, "text(fallback)")``.

    Steps: strip a leading trigger (unless ``preserve_trigger``), split off
    the negative ("avoid") part, remove prompt field labels from the positive
    part, then append explicit tags derived from that positive text.
    """
    positive, negative = _split_avoid(_strip_trigger(source_text, preserve_trigger))
    positive = _strip_prompt_field_labels(positive)
    tags = _combine_tags(positive, ", ".join(_explicit_tags(positive, nude_weight)))
    return tags, negative, "text(fallback)"


def _sdxl_format_dependencies() -> sdxl_format_route.SDXLFormatDependencies:
    """Bundle this module's helpers into the object `sdxl_format_route` consumes.

    A new bundle is built on every call.
    """
    return sdxl_format_route.SDXLFormatDependencies(
        default_negative=SDXL_DEFAULT_NEGATIVE,
        # Adapt the route's positional (profile, style, quality) call to the
        # keyword-argument signature of `apply_formatter_profile`.
        apply_formatter_profile=lambda profile, style, quality: sdxl_policy.apply_formatter_profile(
            profile,
            style_preset=style,
            quality_preset=quality,
        ),
        clean=_clean,
        row_from_inputs=_row_from_inputs,
        row_core_tags=_row_core_tags,
        soft_tags=_soft_tags,
        hard_tags=_hard_tags,
        fallback_text_to_sdxl=_fallback_text_to_sdxl,
        assemble_prompt=_assemble_prompt,
        combine_negative=_combine_negative,
        sanitize_negative_text=sanitize_negative_text,
    )


def format_sdxl_prompt(
    source_text: str,
    metadata_json: str = "",
    negative_prompt: str = "",
    input_hint: str = "auto",
    target: str = "auto",
    style_preset: str = "flat_vector_pony",
    quality_preset: str = "pony_high",
    trigger: str = "mythp0rt",
    prepend_trigger: bool = True,
    preserve_trigger: bool = False,
    nude_weight: float = 1.29,
    custom_style: str = "",
    custom_quality: str = "",
    extra_positive: str = "",
    extra_negative: str = "",
    formatter_profile: str = "manual_controls",
) -> dict[str, str]:
    """Format source text and/or metadata into SDXL prompt fields.

    Every argument is packed unchanged into an ``SDXLFormatRequest``. The
    request, together with this module's helper bundle, is handed to
    `sdxl_format_route.format_sdxl_prompt`, whose ``dict[str, str]`` result is
    returned as-is. The keys of that dict are defined by `sdxl_format_route`.

    NOTE(review): how ``input_hint``, ``target`` and ``formatter_profile``
    interact is decided inside `sdxl_format_route`; confirm there before
    documenting further.
    """
    request = sdxl_format_route.SDXLFormatRequest(
        source_text=source_text,
        metadata_json=metadata_json,
        negative_prompt=negative_prompt,
        input_hint=input_hint,
        target=target,
        style_preset=style_preset,
        quality_preset=quality_preset,
        trigger=trigger,
        prepend_trigger=prepend_trigger,
        preserve_trigger=preserve_trigger,
        nude_weight=nude_weight,
        custom_style=custom_style,
        custom_quality=custom_quality,
        extra_positive=extra_positive,
        extra_negative=extra_negative,
        formatter_profile=formatter_profile,
    )
    return sdxl_format_route.format_sdxl_prompt(request, _sdxl_format_dependencies())