Align SDXL soft pair tags

This commit is contained in:
2026-06-27 16:37:31 +02:00
parent 9cd1f03bfe
commit 96ff37a5a0
4 changed files with 69 additions and 7 deletions
+14
View File
@@ -8,11 +8,13 @@ try:
from . import route_metadata as route_metadata_policy from . import route_metadata as route_metadata_policy
from . import sdxl_presets as sdxl_policy from . import sdxl_presets as sdxl_policy
from . import sdxl_tag_routes from . import sdxl_tag_routes
from . import softcore_text_policy
except ImportError: # Allows local smoke tests with `python -c`. except ImportError: # Allows local smoke tests with `python -c`.
import formatter_input as input_policy import formatter_input as input_policy
import route_metadata as route_metadata_policy import route_metadata as route_metadata_policy
import sdxl_presets as sdxl_policy import sdxl_presets as sdxl_policy
import sdxl_tag_routes import sdxl_tag_routes
import softcore_text_policy
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels() PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
@@ -239,6 +241,17 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]:
return tags return tags
def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]:
tags = ["softcore teaser", softcore_text_policy.softcore_style_tag()]
options = root.get("options") if isinstance(root.get("options"), dict) else {}
cast_mode = clean(options.get("softcore_cast")).lower()
if cast_mode == "same_as_hardcore" or root.get("shared_cast_descriptors"):
tags.append("same-cast creator frame")
elif "solo" in clean(row.get("subject_type") or row.get("primary_subject")).lower():
tags.append("solo creator frame")
return tags
def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies: def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
return sdxl_tag_routes.SDXLTagRouteDependencies( return sdxl_tag_routes.SDXLTagRouteDependencies(
clean=clean, clean=clean,
@@ -254,4 +267,5 @@ def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
formatter_hint_tags=formatter_hint_tags, formatter_hint_tags=formatter_hint_tags,
camera_tags=camera_tags, camera_tags=camera_tags,
explicit_tags=explicit_tags, explicit_tags=explicit_tags,
softcore_pair_tags=softcore_pair_tags,
) )
+45 -6
View File
@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable from typing import Any, Callable
@@ -40,6 +41,33 @@ class SDXLTagRouteDependencies:
formatter_hint_tags: Callable[..., list[str]] formatter_hint_tags: Callable[..., list[str]]
camera_tags: Callable[..., list[str]] camera_tags: Callable[..., list[str]]
explicit_tags: Callable[[str, float], list[str]] explicit_tags: Callable[[str, float], list[str]]
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
descriptors = root.get("shared_cast_descriptors")
if not isinstance(descriptors, list):
return 0, 0
women = 0
men = 0
for descriptor in descriptors:
text = str(descriptor).lower()
if re.search(r"\bwoman\b", text):
women += 1
elif re.search(r"\bman\b", text):
men += 1
return women, men
def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
try:
women = int(root.get("hardcore_women_count") or row.get("women_count") or 0)
men = int(root.get("hardcore_men_count") or row.get("men_count") or 0)
except (TypeError, ValueError):
women, men = 0, 0
if women or men:
return women, men
return _descriptor_counts(root)
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
@@ -93,16 +121,29 @@ def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
root = request.root root = request.root
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
seen = {deps.tag_key(tag) for tag in tags} seen = {deps.tag_key(tag) for tag in tags}
women, men = _pair_counts(row, root)
for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag)
for tag in deps.formatter_hint_tags(root): for tag in deps.formatter_hint_tags(root):
deps.add(tags, seen, tag) deps.add(tags, seen, tag)
descriptor = deps.clean(root.get("shared_descriptor"))
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags): descriptors = root.get("shared_cast_descriptors")
if isinstance(descriptors, list) and descriptors:
for descriptor in descriptors:
for tag in deps.character_tags_from_descriptor(descriptor): for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag) deps.add_one(tags, seen, tag)
else:
descriptor = deps.clean(root.get("shared_descriptor"))
for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag)
partner = root.get("softcore_partner_styling") partner = root.get("softcore_partner_styling")
if isinstance(partner, dict): if isinstance(partner, dict):
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item))) deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
deps.add(tags, seen, partner.get("pose")) deps.add(tags, seen, partner.get("pose"))
for tag in deps.softcore_pair_tags(row, root):
deps.add_one(tags, seen, tag)
deps.add_one(tags, seen, "sexy") deps.add_one(tags, seen, "sexy")
deps.add_one(tags, seen, "looking at viewer") deps.add_one(tags, seen, "looking at viewer")
return SDXLTagRoute(tags) return SDXLTagRoute(tags)
@@ -113,10 +154,8 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
root = request.root root = request.root
tags: list[str] = [] tags: list[str] = []
seen: set[str] = set() seen: set[str] = set()
try: women, men = _pair_counts(row, root)
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1) if not women and not men:
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
except (TypeError, ValueError):
women, men = 1, 1 women, men = 1, 1
for tag in deps.count_tag(women, men): for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag) deps.add_one(tags, seen, tag)
+5 -1
View File
@@ -44,4 +44,8 @@ def softcore_caption_setup_phrase(*, same_cast: bool, target_auto: bool = False)
def softcore_style_directive() -> str: def softcore_style_directive() -> str:
return "Use seductive creator-shot teaser styling." return f"Use {softcore_style_tag()}."
def softcore_style_tag() -> str:
return "seductive creator-shot teaser styling"
+5
View File
@@ -4128,6 +4128,7 @@ def smoke_sdxl_tag_policy() -> None:
_expect(deps.metadata_family_tags is sdxl_tag_policy.metadata_family_tags, "SDXL route deps lost metadata family policy") _expect(deps.metadata_family_tags is sdxl_tag_policy.metadata_family_tags, "SDXL route deps lost metadata family policy")
_expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy") _expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy")
_expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy") _expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy")
_expect(deps.softcore_pair_tags is sdxl_tag_policy.softcore_pair_tags, "SDXL route deps lost softcore pair tag policy")
def smoke_sdxl_tag_routes() -> None: def smoke_sdxl_tag_routes() -> None:
@@ -6586,6 +6587,10 @@ def smoke_formatter_metadata_fixtures() -> None:
_expect_trigger_once("fixture_external_pair.sdxl_soft", sdxl_pair.get("sdxl_softcore_prompt"), SdxlTrigger) _expect_trigger_once("fixture_external_pair.sdxl_soft", sdxl_pair.get("sdxl_softcore_prompt"), SdxlTrigger)
_expect_trigger_once("fixture_external_pair.sdxl_hard", sdxl_pair.get("sdxl_hardcore_prompt"), SdxlTrigger) _expect_trigger_once("fixture_external_pair.sdxl_hard", sdxl_pair.get("sdxl_hardcore_prompt"), SdxlTrigger)
_expect("black buttoned shirt" in sdxl_soft, "External pair SDXL soft route lost embedded partner styling") _expect("black buttoned shirt" in sdxl_soft, "External pair SDXL soft route lost embedded partner styling")
_expect("1man" in sdxl_soft, "External pair SDXL soft route lost partner count tag")
_expect("40-year-old adult man" in sdxl_soft, "External pair SDXL soft route lost partner descriptor")
_expect("softcore teaser" in sdxl_soft, "External pair SDXL soft route lost softcore teaser tag")
_expect_no_softcore_noise("fixture_external_pair.sdxl_soft", sdxl_pair.get("sdxl_softcore_prompt"))
_expect("red satin lingerie set" in sdxl_hard, "External pair SDXL hard route lost embedded clothing state") _expect("red satin lingerie set" in sdxl_hard, "External pair SDXL hard route lost embedded clothing state")
_expect("row hard right-side view" in sdxl_hard, "External pair SDXL hard route lost embedded camera directive") _expect("row hard right-side view" in sdxl_hard, "External pair SDXL hard route lost embedded camera directive")
_expect_no_duplicate_comma_items("fixture_external_pair.sdxl_negative", sdxl_pair.get("negative_prompt")) _expect_no_duplicate_comma_items("fixture_external_pair.sdxl_negative", sdxl_pair.get("negative_prompt"))