Align SDXL soft pair tags
This commit is contained in:
+45
-6
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
@@ -40,6 +41,33 @@ class SDXLTagRouteDependencies:
|
||||
formatter_hint_tags: Callable[..., list[str]]
|
||||
camera_tags: Callable[..., list[str]]
|
||||
explicit_tags: Callable[[str, float], list[str]]
|
||||
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if not isinstance(descriptors, list):
|
||||
return 0, 0
|
||||
women = 0
|
||||
men = 0
|
||||
for descriptor in descriptors:
|
||||
text = str(descriptor).lower()
|
||||
if re.search(r"\bwoman\b", text):
|
||||
women += 1
|
||||
elif re.search(r"\bman\b", text):
|
||||
men += 1
|
||||
return women, men
|
||||
|
||||
|
||||
def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 0)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 0)
|
||||
except (TypeError, ValueError):
|
||||
women, men = 0, 0
|
||||
if women or men:
|
||||
return women, men
|
||||
return _descriptor_counts(root)
|
||||
|
||||
|
||||
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
@@ -93,16 +121,29 @@ def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
|
||||
root = request.root
|
||||
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
|
||||
seen = {deps.tag_key(tag) for tag in tags}
|
||||
women, men = _pair_counts(row, root)
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.formatter_hint_tags(root):
|
||||
deps.add(tags, seen, tag)
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
|
||||
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if isinstance(descriptors, list) and descriptors:
|
||||
for descriptor in descriptors:
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
else:
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
partner = root.get("softcore_partner_styling")
|
||||
if isinstance(partner, dict):
|
||||
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
|
||||
deps.add(tags, seen, partner.get("pose"))
|
||||
for tag in deps.softcore_pair_tags(row, root):
|
||||
deps.add_one(tags, seen, tag)
|
||||
deps.add_one(tags, seen, "sexy")
|
||||
deps.add_one(tags, seen, "looking at viewer")
|
||||
return SDXLTagRoute(tags)
|
||||
@@ -113,10 +154,8 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
|
||||
root = request.root
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
|
||||
except (TypeError, ValueError):
|
||||
women, men = _pair_counts(row, root)
|
||||
if not women and not men:
|
||||
women, men = 1, 1
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
Reference in New Issue
Block a user