Align SDXL soft pair tags
This commit is contained in:
@@ -8,11 +8,13 @@ try:
|
||||
from . import route_metadata as route_metadata_policy
|
||||
from . import sdxl_presets as sdxl_policy
|
||||
from . import sdxl_tag_routes
|
||||
from . import softcore_text_policy
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
import route_metadata as route_metadata_policy
|
||||
import sdxl_presets as sdxl_policy
|
||||
import sdxl_tag_routes
|
||||
import softcore_text_policy
|
||||
|
||||
|
||||
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
|
||||
@@ -239,6 +241,17 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]:
|
||||
return tags
|
||||
|
||||
|
||||
def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]:
|
||||
tags = ["softcore teaser", softcore_text_policy.softcore_style_tag()]
|
||||
options = root.get("options") if isinstance(root.get("options"), dict) else {}
|
||||
cast_mode = clean(options.get("softcore_cast")).lower()
|
||||
if cast_mode == "same_as_hardcore" or root.get("shared_cast_descriptors"):
|
||||
tags.append("same-cast creator frame")
|
||||
elif "solo" in clean(row.get("subject_type") or row.get("primary_subject")).lower():
|
||||
tags.append("solo creator frame")
|
||||
return tags
|
||||
|
||||
|
||||
def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
|
||||
return sdxl_tag_routes.SDXLTagRouteDependencies(
|
||||
clean=clean,
|
||||
@@ -254,4 +267,5 @@ def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
|
||||
formatter_hint_tags=formatter_hint_tags,
|
||||
camera_tags=camera_tags,
|
||||
explicit_tags=explicit_tags,
|
||||
softcore_pair_tags=softcore_pair_tags,
|
||||
)
|
||||
|
||||
+45
-6
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -40,6 +41,33 @@ class SDXLTagRouteDependencies:
|
||||
formatter_hint_tags: Callable[..., list[str]]
|
||||
camera_tags: Callable[..., list[str]]
|
||||
explicit_tags: Callable[[str, float], list[str]]
|
||||
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if not isinstance(descriptors, list):
|
||||
return 0, 0
|
||||
women = 0
|
||||
men = 0
|
||||
for descriptor in descriptors:
|
||||
text = str(descriptor).lower()
|
||||
if re.search(r"\bwoman\b", text):
|
||||
women += 1
|
||||
elif re.search(r"\bman\b", text):
|
||||
men += 1
|
||||
return women, men
|
||||
|
||||
|
||||
def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 0)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 0)
|
||||
except (TypeError, ValueError):
|
||||
women, men = 0, 0
|
||||
if women or men:
|
||||
return women, men
|
||||
return _descriptor_counts(root)
|
||||
|
||||
|
||||
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
@@ -93,16 +121,29 @@ def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
|
||||
root = request.root
|
||||
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
|
||||
seen = {deps.tag_key(tag) for tag in tags}
|
||||
women, men = _pair_counts(row, root)
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.formatter_hint_tags(root):
|
||||
deps.add(tags, seen, tag)
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
|
||||
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if isinstance(descriptors, list) and descriptors:
|
||||
for descriptor in descriptors:
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
else:
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
partner = root.get("softcore_partner_styling")
|
||||
if isinstance(partner, dict):
|
||||
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
|
||||
deps.add(tags, seen, partner.get("pose"))
|
||||
for tag in deps.softcore_pair_tags(row, root):
|
||||
deps.add_one(tags, seen, tag)
|
||||
deps.add_one(tags, seen, "sexy")
|
||||
deps.add_one(tags, seen, "looking at viewer")
|
||||
return SDXLTagRoute(tags)
|
||||
@@ -113,10 +154,8 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
|
||||
root = request.root
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
|
||||
except (TypeError, ValueError):
|
||||
women, men = _pair_counts(row, root)
|
||||
if not women and not men:
|
||||
women, men = 1, 1
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
@@ -44,4 +44,8 @@ def softcore_caption_setup_phrase(*, same_cast: bool, target_auto: bool = False)
|
||||
|
||||
|
||||
def softcore_style_directive() -> str:
|
||||
return "Use seductive creator-shot teaser styling."
|
||||
return f"Use {softcore_style_tag()}."
|
||||
|
||||
|
||||
def softcore_style_tag() -> str:
|
||||
return "seductive creator-shot teaser styling"
|
||||
|
||||
@@ -4128,6 +4128,7 @@ def smoke_sdxl_tag_policy() -> None:
|
||||
_expect(deps.metadata_family_tags is sdxl_tag_policy.metadata_family_tags, "SDXL route deps lost metadata family policy")
|
||||
_expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy")
|
||||
_expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy")
|
||||
_expect(deps.softcore_pair_tags is sdxl_tag_policy.softcore_pair_tags, "SDXL route deps lost softcore pair tag policy")
|
||||
|
||||
|
||||
def smoke_sdxl_tag_routes() -> None:
|
||||
@@ -6586,6 +6587,10 @@ def smoke_formatter_metadata_fixtures() -> None:
|
||||
_expect_trigger_once("fixture_external_pair.sdxl_soft", sdxl_pair.get("sdxl_softcore_prompt"), SdxlTrigger)
|
||||
_expect_trigger_once("fixture_external_pair.sdxl_hard", sdxl_pair.get("sdxl_hardcore_prompt"), SdxlTrigger)
|
||||
_expect("black buttoned shirt" in sdxl_soft, "External pair SDXL soft route lost embedded partner styling")
|
||||
_expect("1man" in sdxl_soft, "External pair SDXL soft route lost partner count tag")
|
||||
_expect("40-year-old adult man" in sdxl_soft, "External pair SDXL soft route lost partner descriptor")
|
||||
_expect("softcore teaser" in sdxl_soft, "External pair SDXL soft route lost softcore teaser tag")
|
||||
_expect_no_softcore_noise("fixture_external_pair.sdxl_soft", sdxl_pair.get("sdxl_softcore_prompt"))
|
||||
_expect("red satin lingerie set" in sdxl_hard, "External pair SDXL hard route lost embedded clothing state")
|
||||
_expect("row hard right-side view" in sdxl_hard, "External pair SDXL hard route lost embedded camera directive")
|
||||
_expect_no_duplicate_comma_items("fixture_external_pair.sdxl_negative", sdxl_pair.get("negative_prompt"))
|
||||
|
||||
Reference in New Issue
Block a user