Align SDXL soft pair tags

This commit is contained in:
2026-06-27 16:37:31 +02:00
parent 9cd1f03bfe
commit 96ff37a5a0
4 changed files with 69 additions and 7 deletions
+45 -6
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Callable
@@ -40,6 +41,33 @@ class SDXLTagRouteDependencies:
formatter_hint_tags: Callable[..., list[str]]
camera_tags: Callable[..., list[str]]
explicit_tags: Callable[[str, float], list[str]]
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
descriptors = root.get("shared_cast_descriptors")
if not isinstance(descriptors, list):
return 0, 0
women = 0
men = 0
for descriptor in descriptors:
text = str(descriptor).lower()
if re.search(r"\bwoman\b", text):
women += 1
elif re.search(r"\bman\b", text):
men += 1
return women, men
def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
try:
women = int(root.get("hardcore_women_count") or row.get("women_count") or 0)
men = int(root.get("hardcore_men_count") or row.get("men_count") or 0)
except (TypeError, ValueError):
women, men = 0, 0
if women or men:
return women, men
return _descriptor_counts(root)
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
@@ -93,16 +121,29 @@ def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
root = request.root
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
seen = {deps.tag_key(tag) for tag in tags}
women, men = _pair_counts(row, root)
for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag)
for tag in deps.formatter_hint_tags(root):
deps.add(tags, seen, tag)
descriptor = deps.clean(root.get("shared_descriptor"))
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
descriptors = root.get("shared_cast_descriptors")
if isinstance(descriptors, list) and descriptors:
for descriptor in descriptors:
for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag)
else:
descriptor = deps.clean(root.get("shared_descriptor"))
for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag)
partner = root.get("softcore_partner_styling")
if isinstance(partner, dict):
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
deps.add(tags, seen, partner.get("pose"))
for tag in deps.softcore_pair_tags(row, root):
deps.add_one(tags, seen, tag)
deps.add_one(tags, seen, "sexy")
deps.add_one(tags, seen, "looking at viewer")
return SDXLTagRoute(tags)
@@ -113,10 +154,8 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
root = request.root
tags: list[str] = []
seen: set[str] = set()
try:
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
except (TypeError, ValueError):
women, men = _pair_counts(row, root)
if not women and not men:
women, men = 1, 1
for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag)