diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 98193d9..3fe06bb 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -399,7 +399,7 @@ Keep here: - trigger behavior; - style and quality presets; - final style/body/quality prompt assembly; -- nude-weight setting and explicit-tag helper policy; +- nude-weight setting; - negative-prompt assembly. Already isolated: @@ -408,8 +408,11 @@ Already isolated: tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`, `SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps compatibility wrappers plus final style/quality/trigger assembly. -- metadata-family tag hints from `action_family`, `position_family`, and - `position_keys`. +- `sdxl_tag_policy.py` owns SDXL tag splitting, tag-key dedupe, count inference, + character descriptor tags, metadata-family hint tags, camera tags, + explicit/nude helper tags, and route dependency assembly. +- metadata-family tag hint data from `action_family`, `position_family`, and + `position_keys` stays in `sdxl_presets.py` and is read by `sdxl_tag_policy.py`. - shared row route metadata reads from `route_metadata.py`. - shared formatter input parsing from `formatter_input.py`. - style presets, quality presets, default negative prompt, and action/position diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 705e633..b3dc005 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -123,6 +123,7 @@ Core helper ownership: | `node_tooltips.py` | Node input tooltip inventory, node-specific overrides, dynamic-input fallback rules, and tooltip injection installer used by `__init__.py`. | | `server_routes.py` | Pure payload handlers for profile-save and accumulator server endpoints, used by ComfyUI routes and smoke tests without importing ComfyUI. | | `sdxl_presets.py` | SDXL formatter profiles, style presets, quality presets, default negative prompt, and metadata-family tag hints used by the SDXL formatter and node choice lists. | +| `sdxl_tag_policy.py` | SDXL tag splitting, tag-key dedupe, count inference, character descriptor tags, metadata-family/camera/explicit helper tags, and route dependency assembly used by `sdxl_formatter.py` and `sdxl_tag_routes.py`. | | `caption_policy.py` | Caption naturalizer policy data and helpers: caption profiles, style tails, item labels, metadata-family caption labels, detail/style-policy normalization, clothing cleanup, and composition cleanup. | ## Node IO Map @@ -158,7 +159,7 @@ These recipes identify the intended road before editing prompt text. | Force a custom frame/composition | `SxCP Composition Pool` or `SxCP Location Theme` -> builder/pair | `combine_mode=replace` to force; `add` to mix | `_composition_pool`, `row_location.apply_composition_config_to_legacy_row`, Krea composition phrase | | Use Qwen/orbit camera geometry | Qwen/orbit node -> camera_config -> builder/pair | For pair, use `softcore_camera_config` and/or `hardcore_camera_config`; set mode from config in options | `_camera_config_with_mode`, `_camera_directive`, `_camera_scene_directive_for_context` | | Use Krea2 for only hard prompt from a pair | Pair `metadata_json` -> Krea2 Formatter | `target=hardcore`, `input_hint=metadata_json` or auto with metadata connected | `_insta_pair_to_krea`, hard row fields | -| Convert builder output to SDXL tags | Builder/pair metadata -> SDXL Formatter | Use metadata input; set `target`; select style and quality preset | `_row_core_tags`, `_soft_tags`, `_hard_tags` | +| Convert builder output to SDXL tags | Builder/pair metadata -> SDXL Formatter | Use metadata input; set `target`; select style and quality preset | `sdxl_tag_routes.py`, `sdxl_tag_policy.py`, compatibility wrappers `_row_core_tags` / `_soft_tags` / `_hard_tags` | | Save/reuse character | Slot/profile nodes -> Profile Save/Load -> slot/builder | Save from the row/profile data you want, not a freshly randomized disconnected route | `character_profile.py`, `web/profile_buttons.js`, profile JSON | ## Seed Axes @@ -705,6 +706,9 @@ not parse metadata. That is a wiring/input-hint issue, not a prompt pool issue. `sdxl_tag_routes.hard_tags_result` through compatibility wrappers. - Normal metadata row: `sdxl_tag_routes.row_core_tags_result` through the `_row_core_tags` compatibility wrapper. +- Tag mechanics: `sdxl_tag_policy.py` supplies splitting, dedupe, count, + character, metadata-family, camera, and explicit helper tags to the route + layer. - Plain text fallback: `_fallback_text_to_sdxl`. Use this route for style triggers, weighted tag style, nude weighting, formatter @@ -714,9 +718,9 @@ SDXL field consumption: | Branch | Reads most from | Key functions | | --- | --- | --- | -| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `sdxl_tag_routes.row_core_tags_result`, `_metadata_family_tags`, `_camera_tags` | +| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `sdxl_tag_routes.row_core_tags_result`, `sdxl_tag_policy.metadata_family_tags`, `sdxl_tag_policy.camera_tags` | | Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `sdxl_tag_routes.soft_tags_result` | -| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `sdxl_tag_routes.hard_tags_result`, `_metadata_family_tags` | +| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `sdxl_tag_routes.hard_tags_result`, `sdxl_tag_policy.metadata_family_tags` | | Text fallback | `source_text`, preserve-trigger setting, shared field-label stripping | `_fallback_text_to_sdxl` | SDXL is the right place for model trigger handling, tag ordering, weight syntax, @@ -883,7 +887,7 @@ pair metadata through the core Python APIs, then verifies: | Man appears described in POV | POV labels, `krea_cast.cast_prose` omit labels, `krea_pov_actions.pov_action_phrase`. | | Camera prompt missing from Krea2 | Row `camera_directive` / `camera_scene_directive`, then Krea `_camera_phrase`. | | Trigger missing in Krea2 fallback | `format_krea2_prompt` preserve-trigger fallback behavior. | -| SDXL tags too weak/wrong style | `sdxl_formatter.py` presets and `_row_core_tags` / `_soft_tags` / `_hard_tags`. | +| SDXL tags too weak/wrong style | `sdxl_presets.py`, `sdxl_tag_policy.py`, then `sdxl_tag_routes.py`; formatter wrappers `_row_core_tags` / `_soft_tags` / `_hard_tags` should stay compatibility-only. | | Duplicate punctuation, empty labels, repeated trigger, repeated tag item | `prompt_hygiene.py`, then the route-specific formatter if the repeated content is semantic. | | Bed/sheet/couch or malformed surface wording leaks into hardcore prompts | `hardcore_text_cleanup.py`, then the relevant category pool/template. | | Saved profile does not match liked character | Profile save/load path and whether the saved input is row metadata or regenerated slot config. | diff --git a/sdxl_formatter.py b/sdxl_formatter.py index 2c94c97..b0e0b85 100644 --- a/sdxl_formatter.py +++ b/sdxl_formatter.py @@ -1,17 +1,16 @@ from __future__ import annotations -import re from typing import Any try: from . import formatter_input as input_policy - from . import route_metadata as route_metadata_policy + from . import sdxl_tag_policy from . import sdxl_tag_routes from . import sdxl_presets as sdxl_policy from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt except ImportError: # Allows local smoke tests with `python -c`. import formatter_input as input_policy - import route_metadata as route_metadata_policy + import sdxl_tag_policy import sdxl_tag_routes import sdxl_presets as sdxl_policy from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt @@ -69,238 +68,68 @@ def _strip_prompt_field_labels(text: str) -> str: return input_policy.strip_prompt_field_labels(text, field_labels=PROMPT_FIELD_LABELS) -def _prompt_field(text: str, label: str) -> str: - return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) - - -def _row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str: - return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS) - - def _split_tag_text(text: Any) -> list[str]: - text = _clean(text) - if not text: - return [] - text = re.sub(r"\bWoman [A-Z]'s\b", "woman's", text) - text = re.sub(r"\bMan [A-Z]'s\b", "man's", text) - text = re.sub(r"\bWoman [A-Z]\b", "woman", text) - text = re.sub(r"\bMan [A-Z]\b", "man", text) - text = re.sub( - r"\b(?:Clothing state|Visual clothing state|visible remaining styling|teaser outfit detail|softcore visual reference|Sexual scene|Role graph):\s*", - "", - text, - flags=re.IGNORECASE, - ) - text = re.sub(r"\b(?:and|with)\b", ",", text, flags=re.IGNORECASE) - parts = re.split(r"\s*[,;]\s*", text) - return [_clean(part).strip(" .") for part in parts if _clean(part).strip(" .")] + return sdxl_tag_policy.split_tag_text(text) def _tag_key(tag: str) -> str: - text = _clean(tag).lower() - text = re.sub(r"^\((.*?):[0-9.]+\)$", r"\1", text) - text = text.strip("() ") - return text + return sdxl_tag_policy.tag_key(tag) def _add(tags: list[str], seen: set[str], value: Any) -> None: - for tag in _split_tag_text(value): - key = _tag_key(tag) - if key and key not in seen: - tags.append(tag) - seen.add(key) + sdxl_tag_policy.add(tags, seen, value) def _add_one(tags: list[str], seen: set[str], tag: str) -> None: - tag = _clean(tag).strip(" ,") - key = _tag_key(tag) - if tag and key and key not in seen: - tags.append(tag) - seen.add(key) + sdxl_tag_policy.add_one(tags, seen, tag) def _metadata_family_tags(row: dict[str, Any]) -> list[str]: - tags: list[str] = [] - action_family = route_metadata_policy.row_action_family(row) - tags.extend(SDXL_ACTION_FAMILY_TAGS.get(action_family, ())) - - position_family = route_metadata_policy.row_position_family(row) - tags.extend(SDXL_POSITION_FAMILY_TAGS.get(position_family, ())) - - for key in route_metadata_policy.row_position_keys(row, include_unknown=True): - key_text = _clean(key) - if key_text: - tags.append(key_text.replace("_", " ")) - return tags + return sdxl_tag_policy.metadata_family_tags(row) def _formatter_hint_tags(*rows: dict[str, Any]) -> list[str]: - tags: list[str] = [] - for row in rows: - if not isinstance(row, dict): - continue - for hint in route_metadata_policy.row_formatter_hints(row, "sdxl"): - hint = _clean(hint).strip(" ,.") - if hint and hint not in tags: - tags.append(hint) - return tags + return sdxl_tag_policy.formatter_hint_tags(*rows) def _combine_tags(*parts: Any) -> str: - tags: list[str] = [] - seen: set[str] = set() - for part in parts: - _add(tags, seen, part) - return ", ".join(tags) + return sdxl_tag_policy.combine_tags(*parts) def _combine_negative(*parts: Any) -> str: - return _combine_tags(*(part for part in parts if _clean(part))) + return sdxl_tag_policy.combine_negative(*parts) def _count_tag(women_count: int = 0, men_count: int = 0) -> list[str]: - tags = [] - if women_count > 0: - tags.append(f"{women_count}woman" if women_count == 1 else f"{women_count}women") - if men_count > 0: - tags.append(f"{men_count}man" if men_count == 1 else f"{men_count}men") - return tags + return sdxl_tag_policy.count_tag(women_count, men_count) def _infer_counts(row: dict[str, Any]) -> tuple[int, int]: - try: - women = int(row.get("women_count") or 0) - men = int(row.get("men_count") or 0) - except (TypeError, ValueError): - women = men = 0 - if women or men: - return women, men - subject = _clean(row.get("subject_type") or row.get("primary_subject")).lower() - phrase = _clean(row.get("subject_phrase")).lower() - text = f"{subject} {phrase}" - if "two women" in text: - return 2, 0 - if "two men" in text: - return 0, 2 - if "woman and" in text or "woman a" in text and "man a" in text: - return 1, 1 - if "group" in text: - return 2, 2 - if "man" in text and "woman" not in text: - return 0, 1 - return 1, 0 + return sdxl_tag_policy.infer_counts(row) def _character_tags_from_descriptor(descriptor: Any) -> list[str]: - text = _clean(descriptor) - text = re.sub(r"\bWoman [A-Z]\s*/\s*primary creator:\s*", "", text) - text = re.sub(r"\b(?:Woman|Man) [A-Z]:\s*", "", text) - text = re.sub(r"\balongside\b", ",", text, flags=re.IGNORECASE) - parts = _split_tag_text(text) - cleaned = [] - for part in parts: - part = re.sub(r"\bfigure\b", "build", part, flags=re.IGNORECASE) - part = part.replace("adult adult", "adult") - cleaned.append(part) - return cleaned + return sdxl_tag_policy.character_tags_from_descriptor(descriptor) def _normal_character_tags(row: dict[str, Any]) -> list[str]: - descriptor = ( - _clean(row.get("cast_descriptor_text")) - or _prompt_field(row.get("prompt", ""), "Characters") - or _prompt_field(row.get("prompt", ""), "Cast descriptors") - ) - if descriptor: - return _character_tags_from_descriptor(descriptor) - - parts = [ - _clean(row.get("age") or row.get("age_band")), - _clean(row.get("subject_phrase") or row.get("subject_type") or row.get("primary_subject")), - _clean(row.get("body_phrase") or row.get("body") or row.get("body_type")), - _clean(row.get("skin")), - _clean(row.get("hair")), - _clean(row.get("eyes")), - ] - return [part for part in parts if part and part not in ("woman", "man", "single_any")] + return sdxl_tag_policy.normal_character_tags(row) def _camera_tags_from_config(config: Any) -> list[str]: - if not isinstance(config, dict): - return [] - if _clean(config.get("camera_detail")) == "off" or _clean(config.get("camera_mode")) == "disabled": - return [] - custom = _clean(config.get("custom_camera_prompt")) - tags = _split_tag_text(custom) - direction = _clean(config.get("orbit_direction")) - elevation = _clean(config.get("orbit_elevation_label")) - distance = _clean(config.get("orbit_distance_label")) - for value in (direction, elevation, distance): - if value and value != "auto": - tags.extend(_split_tag_text(value)) - for key in ("angle", "shot_size", "distance", "lens", "orientation", "subject_focus"): - value = _clean(config.get(key)).replace("_", " ") - if value and value != "auto": - tags.append(value) - return tags + return sdxl_tag_policy.camera_tags_from_config(config) def _camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]: - tags = _split_tag_text(directive) - tags.extend(_camera_tags_from_config(config if config is not None else row.get("camera_config"))) - camera_directive = _clean(row.get("camera_directive")) - if camera_directive: - tags.extend(_split_tag_text(camera_directive)) - out = [] - for tag in tags: - tag = tag.replace("0-degree front view", "(front facing:1.15)") - tag = tag.replace("front view", "(front facing:1.15)") - tag = tag.replace("right side view", "side view") - tag = tag.replace("left side view", "side view") - out.append(tag) - return out + return sdxl_tag_policy.camera_tags(row, directive, config) def _explicit_tags(text: str, nude_weight: float) -> list[str]: - lower = text.lower() - tags: list[str] = [] - if any(token in lower for token in ("fully nude", "fully exposed", "naked", "bare skin unobstructed", "explicit_nude")): - tags.append(f"(naked:{nude_weight:.2f})") - if any(token in lower for token in ("nipples", "breasts exposed", "bare breasts", "nipple")): - tags.append("nipples") - if any(token in lower for token in ("pussy", "vulva", "genitals")): - tags.append("pussy") - if any(token in lower for token in ("penis", "cock")): - tags.append("penis") - if "penetration" in lower or "thrust" in lower: - tags.append("penetration") - if "vaginal" in lower: - tags.append("pussy") - if "oral" in lower or "mouth" in lower: - tags.append("oral sex") - if "anal" in lower: - tags.append("anal sex") - if any(token in lower for token in ("semen", "ejaculates", "cum ")): - tags.append("semen") - return tags + return sdxl_tag_policy.explicit_tags(text, nude_weight) def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies: - return sdxl_tag_routes.SDXLTagRouteDependencies( - clean=_clean, - row_value=_row_value, - tag_key=_tag_key, - add=_add, - add_one=_add_one, - count_tag=_count_tag, - infer_counts=_infer_counts, - normal_character_tags=_normal_character_tags, - character_tags_from_descriptor=_character_tags_from_descriptor, - metadata_family_tags=_metadata_family_tags, - formatter_hint_tags=_formatter_hint_tags, - camera_tags=_camera_tags, - explicit_tags=_explicit_tags, - ) + return sdxl_tag_policy.tag_route_dependencies() def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]: diff --git a/sdxl_tag_policy.py b/sdxl_tag_policy.py new file mode 100644 index 0000000..3e23ec2 --- /dev/null +++ b/sdxl_tag_policy.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import re +from typing import Any + +try: + from . import formatter_input as input_policy + from . import route_metadata as route_metadata_policy + from . import sdxl_presets as sdxl_policy + from . import sdxl_tag_routes +except ImportError: # Allows local smoke tests with `python -c`. + import formatter_input as input_policy + import route_metadata as route_metadata_policy + import sdxl_presets as sdxl_policy + import sdxl_tag_routes + + +PROMPT_FIELD_LABELS = input_policy.prompt_field_labels() + + +def clean(value: Any) -> str: + return input_policy.clean_text(value) + + +def prompt_field(text: str, label: str) -> str: + return input_policy.prompt_field(text, label, field_labels=PROMPT_FIELD_LABELS) + + +def row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str: + return input_policy.row_value(row, key, labels, field_labels=PROMPT_FIELD_LABELS) + + +def split_tag_text(text: Any) -> list[str]: + text = clean(text) + if not text: + return [] + text = re.sub(r"\bWoman [A-Z]'s\b", "woman's", text) + text = re.sub(r"\bMan [A-Z]'s\b", "man's", text) + text = re.sub(r"\bWoman [A-Z]\b", "woman", text) + text = re.sub(r"\bMan [A-Z]\b", "man", text) + text = re.sub( + r"\b(?:Clothing state|Visual clothing state|visible remaining styling|teaser outfit detail|softcore visual reference|Sexual scene|Role graph):\s*", + "", + text, + flags=re.IGNORECASE, + ) + text = re.sub(r"\b(?:and|with)\b", ",", text, flags=re.IGNORECASE) + parts = re.split(r"\s*[,;]\s*", text) + return [clean(part).strip(" .") for part in parts if clean(part).strip(" .")] + + +def tag_key(tag: str) -> str: + text = clean(tag).lower() + text = re.sub(r"^\((.*?):[0-9.]+\)$", r"\1", text) + text = text.strip("() ") + return text + + +def add(tags: list[str], seen: set[str], value: Any) -> None: + for tag in split_tag_text(value): + key = tag_key(tag) + if key and key not in seen: + tags.append(tag) + seen.add(key) + + +def add_one(tags: list[str], seen: set[str], tag: str) -> None: + tag = clean(tag).strip(" ,") + key = tag_key(tag) + if tag and key and key not in seen: + tags.append(tag) + seen.add(key) + + +def metadata_family_tags(row: dict[str, Any]) -> list[str]: + tags: list[str] = [] + action_family = route_metadata_policy.row_action_family(row) + tags.extend(sdxl_policy.SDXL_ACTION_FAMILY_TAGS.get(action_family, ())) + + position_family = route_metadata_policy.row_position_family(row) + tags.extend(sdxl_policy.SDXL_POSITION_FAMILY_TAGS.get(position_family, ())) + + for key in route_metadata_policy.row_position_keys(row, include_unknown=True): + key_text = clean(key) + if key_text: + tags.append(key_text.replace("_", " ")) + return tags + + +def formatter_hint_tags(*rows: dict[str, Any]) -> list[str]: + tags: list[str] = [] + for row in rows: + if not isinstance(row, dict): + continue + for hint in route_metadata_policy.row_formatter_hints(row, "sdxl"): + hint = clean(hint).strip(" ,.") + if hint and hint not in tags: + tags.append(hint) + return tags + + +def combine_tags(*parts: Any) -> str: + tags: list[str] = [] + seen: set[str] = set() + for part in parts: + add(tags, seen, part) + return ", ".join(tags) + + +def combine_negative(*parts: Any) -> str: + return combine_tags(*(part for part in parts if clean(part))) + + +def count_tag(women_count: int = 0, men_count: int = 0) -> list[str]: + tags = [] + if women_count > 0: + tags.append(f"{women_count}woman" if women_count == 1 else f"{women_count}women") + if men_count > 0: + tags.append(f"{men_count}man" if men_count == 1 else f"{men_count}men") + return tags + + +def infer_counts(row: dict[str, Any]) -> tuple[int, int]: + try: + women = int(row.get("women_count") or 0) + men = int(row.get("men_count") or 0) + except (TypeError, ValueError): + women = men = 0 + if women or men: + return women, men + subject = clean(row.get("subject_type") or row.get("primary_subject")).lower() + phrase = clean(row.get("subject_phrase")).lower() + text = f"{subject} {phrase}" + if "two women" in text: + return 2, 0 + if "two men" in text: + return 0, 2 + if "woman and" in text or "woman a" in text and "man a" in text: + return 1, 1 + if "group" in text: + return 2, 2 + if "man" in text and "woman" not in text: + return 0, 1 + return 1, 0 + + +def character_tags_from_descriptor(descriptor: Any) -> list[str]: + text = clean(descriptor) + text = re.sub(r"\bWoman [A-Z]\s*/\s*primary creator:\s*", "", text) + text = re.sub(r"\b(?:Woman|Man) [A-Z]:\s*", "", text) + text = re.sub(r"\balongside\b", ",", text, flags=re.IGNORECASE) + parts = split_tag_text(text) + cleaned = [] + for part in parts: + part = re.sub(r"\bfigure\b", "build", part, flags=re.IGNORECASE) + part = part.replace("adult adult", "adult") + cleaned.append(part) + return cleaned + + +def normal_character_tags(row: dict[str, Any]) -> list[str]: + descriptor = ( + clean(row.get("cast_descriptor_text")) + or prompt_field(row.get("prompt", ""), "Characters") + or prompt_field(row.get("prompt", ""), "Cast descriptors") + ) + if descriptor: + return character_tags_from_descriptor(descriptor) + + parts = [ + clean(row.get("age") or row.get("age_band")), + clean(row.get("subject_phrase") or row.get("subject_type") or row.get("primary_subject")), + clean(row.get("body_phrase") or row.get("body") or row.get("body_type")), + clean(row.get("skin")), + clean(row.get("hair")), + clean(row.get("eyes")), + ] + return [part for part in parts if part and part not in ("woman", "man", "single_any")] + + +def camera_tags_from_config(config: Any) -> list[str]: + if not isinstance(config, dict): + return [] + if clean(config.get("camera_detail")) == "off" or clean(config.get("camera_mode")) == "disabled": + return [] + custom = clean(config.get("custom_camera_prompt")) + tags = split_tag_text(custom) + direction = clean(config.get("orbit_direction")) + elevation = clean(config.get("orbit_elevation_label")) + distance = clean(config.get("orbit_distance_label")) + for value in (direction, elevation, distance): + if value and value != "auto": + tags.extend(split_tag_text(value)) + for key in ("angle", "shot_size", "distance", "lens", "orientation", "subject_focus"): + value = clean(config.get(key)).replace("_", " ") + if value and value != "auto": + tags.append(value) + return tags + + +def camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]: + tags = split_tag_text(directive) + tags.extend(camera_tags_from_config(config if config is not None else row.get("camera_config"))) + camera_directive = clean(row.get("camera_directive")) + if camera_directive: + tags.extend(split_tag_text(camera_directive)) + out = [] + for tag in tags: + tag = tag.replace("0-degree front view", "(front facing:1.15)") + tag = tag.replace("front view", "(front facing:1.15)") + tag = tag.replace("right side view", "side view") + tag = tag.replace("left side view", "side view") + out.append(tag) + return out + + +def explicit_tags(text: str, nude_weight: float) -> list[str]: + lower = text.lower() + tags: list[str] = [] + if any(token in lower for token in ("fully nude", "fully exposed", "naked", "bare skin unobstructed", "explicit_nude")): + tags.append(f"(naked:{nude_weight:.2f})") + if any(token in lower for token in ("nipples", "breasts exposed", "bare breasts", "nipple")): + tags.append("nipples") + if any(token in lower for token in ("pussy", "vulva", "genitals")): + tags.append("pussy") + if any(token in lower for token in ("penis", "cock")): + tags.append("penis") + if "penetration" in lower or "thrust" in lower: + tags.append("penetration") + if "vaginal" in lower: + tags.append("pussy") + if "oral" in lower or "mouth" in lower: + tags.append("oral sex") + if "anal" in lower: + tags.append("anal sex") + if any(token in lower for token in ("semen", "ejaculates", "cum ")): + tags.append("semen") + return tags + + +def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies: + return sdxl_tag_routes.SDXLTagRouteDependencies( + clean=clean, + row_value=row_value, + tag_key=tag_key, + add=add, + add_one=add_one, + count_tag=count_tag, + infer_counts=infer_counts, + normal_character_tags=normal_character_tags, + character_tags_from_descriptor=character_tags_from_descriptor, + metadata_family_tags=metadata_family_tags, + formatter_hint_tags=formatter_hint_tags, + camera_tags=camera_tags, + explicit_tags=explicit_tags, + ) diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 9951c91..6e5a1ee 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -76,6 +76,7 @@ import row_subject_route # noqa: E402 import server_routes # noqa: E402 import sdxl_formatter # noqa: E402 import sdxl_presets # noqa: E402 +import sdxl_tag_policy # noqa: E402 import sdxl_tag_routes # noqa: E402 import seed_config # noqa: E402 import krea_pov # noqa: E402 @@ -2489,6 +2490,41 @@ def smoke_sdxl_presets_policy() -> None: _expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail") +def smoke_sdxl_tag_policy() -> None: + row = _fixture_hardcore_row( + action_family="oral", + position_family="oral", + position_key="kneeling_oral", + position_keys=["kneeling_oral"], + formatter_hints={"sdxl": ["policy route tag"]}, + ) + _expect( + sdxl_formatter._split_tag_text("Woman A with camera, Man A") + == sdxl_tag_policy.split_tag_text("Woman A with camera, Man A"), + "SDXL formatter split helper should delegate to sdxl_tag_policy", + ) + _expect( + sdxl_formatter._metadata_family_tags(row) == sdxl_tag_policy.metadata_family_tags(row), + "SDXL formatter metadata-family helper should delegate to sdxl_tag_policy", + ) + _expect( + sdxl_formatter._camera_tags(row) == sdxl_tag_policy.camera_tags(row), + "SDXL formatter camera helper should delegate to sdxl_tag_policy", + ) + _expect( + sdxl_formatter._combine_tags("a, b", "a", "c") + == sdxl_tag_policy.combine_tags("a, b", "a", "c") + == "a, b, c", + "SDXL tag combining changed", + ) + deps = sdxl_formatter._sdxl_tag_route_dependencies() + _expect(deps.tag_key is sdxl_tag_policy.tag_key, "SDXL route deps lost policy tag_key") + _expect(deps.normal_character_tags is sdxl_tag_policy.normal_character_tags, "SDXL route deps lost character tag policy") + _expect(deps.metadata_family_tags is sdxl_tag_policy.metadata_family_tags, "SDXL route deps lost metadata family policy") + _expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy") + _expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy") + + def smoke_sdxl_tag_routes() -> None: row = _fixture_hardcore_row( formatter_hints={ @@ -5315,6 +5351,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("caption_policy", smoke_caption_policy), ("caption_metadata_routes", smoke_caption_metadata_routes), ("sdxl_presets_policy", smoke_sdxl_presets_policy), + ("sdxl_tag_policy", smoke_sdxl_tag_policy), ("sdxl_tag_routes", smoke_sdxl_tag_routes), ("hardcore_position_config_policy", smoke_hardcore_position_config_policy), ("row_route_metadata_policy", smoke_row_route_metadata_policy),