# Source: ComfyUI-Ethanfel-Prompt-Bui…/sdxl_tag_routes.py (Python, 242 lines, 8.0 KiB)
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Callable
@dataclass(frozen=True)
class SDXLRowTagRequest:
    """Immutable input for building SDXL tags from one row.

    ``row`` is the mapping every dependency helper reads from, and
    ``nude_weight`` is handed unchanged to ``explicit_tags``.
    """

    row: dict[str, Any]  # row mapping consumed by the dependency helpers
    nude_weight: float  # forwarded as the second argument of explicit_tags
@dataclass(frozen=True)
class SDXLPairTagRequest:
    """Immutable input for building SDXL tags from a row plus its root record.

    ``row`` and ``root`` are both plain mappings. The pair builders read cast
    counts, descriptors and styling from ``root``, and per-row fields from
    ``row``. ``nude_weight`` is passed through to ``explicit_tags``.
    """

    row: dict[str, Any]  # per-row mapping
    root: dict[str, Any]  # enclosing record shared by the pair
    nude_weight: float  # forwarded as the second argument of explicit_tags
@dataclass(frozen=True)
class SDXLTagRoute:
    """Ordered tag list produced by one of the tag-route builders.

    The dataclass is frozen, but ``tags`` is an ordinary list, so its
    contents can still be changed in place.
    """

    tags: list[str]

    def as_text(self) -> str:
        """Return every tag, in order, separated by a comma and a space."""
        separator = ", "
        return separator.join(self.tags)
@dataclass(frozen=True)
class SDXLTagRouteDependencies:
    """Helper callables injected into the SDXL tag-route builders.

    The builders in this module hold no tag vocabulary of their own. They call
    these helpers in a fixed order and collect the results. What each helper
    does is defined by whoever constructs this object.
    """

    # --- value normalisation / lookup ---
    clean: Callable[[Any], str]  # maps an arbitrary value to a string
    # also called as row_value(row, key) with no label tuple; presumably that argument is optional — confirm
    row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str]
    tag_key: Callable[[str], str]  # used to rebuild the de-duplication set from existing tags
    # --- accumulation into a (tags, seen) pair ---
    add: Callable[[list[str], set[str], Any], None]  # receives raw values, including "" and None
    add_one: Callable[[list[str], set[str], str], None]  # receives one tag at a time
    # --- tag producers ---
    count_tag: Callable[[int, int], list[str]]  # (women, men) -> tags
    infer_counts: Callable[[dict[str, Any]], tuple[int, int]]  # row -> (women, men)
    normal_character_tags: Callable[[dict[str, Any]], list[str]]
    character_tags_from_descriptor: Callable[[Any], list[str]]
    metadata_family_tags: Callable[[dict[str, Any]], list[str]]
    formatter_hint_tags: Callable[..., list[str]]  # called with (row), (root) or (row, root)
    camera_tags: Callable[..., list[str]]  # called with (row) or (row, directive, config)
    explicit_tags: Callable[[str, float], list[str]]  # (signal text, nude_weight) -> tags
    softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]  # (row, root) -> tags
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
    """Count women and men mentioned in ``root["shared_cast_descriptors"]``.

    Each descriptor is stringified and lower-cased, then checked for the whole
    word "woman". Only descriptors without "woman" are checked for the whole
    word "man", so each descriptor adds to at most one count.

    Returns:
        ``(women, men)``. ``(0, 0)`` when the key is absent or not a list.
    """
    cast = root.get("shared_cast_descriptors")
    if not isinstance(cast, list):
        return 0, 0
    tally_women = tally_men = 0
    for lowered in (str(entry).lower() for entry in cast):
        if re.search(r"\bwoman\b", lowered):
            tally_women += 1
            continue
        if re.search(r"\bman\b", lowered):
            tally_men += 1
    return tally_women, tally_men
def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
    """Resolve ``(women, men)`` for a row/root pair.

    Each count prefers the root's ``hardcore_*_count`` and falls back to the
    row's ``*_count`` when the root value is falsy. If either value fails
    ``int()`` conversion, both counts are discarded. When both counts end up
    zero, the cast descriptors are counted instead via ``_descriptor_counts``.
    """
    raw_women = root.get("hardcore_women_count") or row.get("women_count") or 0
    raw_men = root.get("hardcore_men_count") or row.get("men_count") or 0
    try:
        counts = (int(raw_women), int(raw_men))
    except (TypeError, ValueError):
        counts = (0, 0)
    if any(counts):
        return counts
    return _descriptor_counts(root)
def _row_explicit_signal_text(
    row: dict[str, Any],
    *,
    item: str,
    pose: str,
    role_graph: str,
    expression: str,
    composition: str,
    deps: SDXLTagRouteDependencies,
) -> str:
    """Build the space-joined text that ``explicit_tags`` scans for one row.

    Combines the already-resolved ``item``, ``pose``, ``role_graph``,
    ``expression`` and ``composition`` strings with the row's
    ``hardcore_clothing_state``, ``clothing_state``, ``clothing`` and
    ``scene_kind`` fields, in that fixed order.

    Each value is passed through ``deps.clean``, and values that clean to an
    empty string are dropped. The row fields are cleaned once when read and
    once more in the final pass, as before.
    """
    values = (
        item,
        pose,
        role_graph,
        deps.clean(row.get("hardcore_clothing_state")),
        deps.clean(row.get("clothing_state")),
        deps.clean(row.get("clothing")),
        deps.clean(row.get("scene_kind")),
        expression,
        composition,
    )
    # Clean each value exactly once. The previous form called clean() twice
    # per kept value: once in the filter and once for the output.
    cleaned = (deps.clean(value) for value in values)
    return " ".join(text for text in cleaned if text)
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the core SDXL tag list for a single row.

    Tags are appended in this order:

    1. count tags for ``deps.infer_counts(row)``;
    2. character tags;
    3. metadata family tags;
    4. formatter hint tags;
    5. the row's item, pose, role graph, scene (as "in <scene>"), expression
       and composition;
    6. camera tags;
    7. whatever ``deps.explicit_tags`` returns for the row's explicit-signal
       text and ``request.nude_weight``.

    All additions share one ``(tags, seen)`` pair, so de-duplication is left
    to ``deps.add`` / ``deps.add_one``.
    """
    source = request.row
    collected: list[str] = []
    seen_keys: set[str] = set()

    def take_each(candidates: Any) -> None:
        # Feed individual tags through add_one.
        for candidate in candidates:
            deps.add_one(collected, seen_keys, candidate)

    def take_raw(value: Any) -> None:
        # Feed a raw value (possibly "" or None) through add.
        deps.add(collected, seen_keys, value)

    women, men = deps.infer_counts(source)
    take_each(deps.count_tag(women, men))
    take_each(deps.normal_character_tags(source))
    take_each(deps.metadata_family_tags(source))
    for hint in deps.formatter_hint_tags(source):
        take_raw(hint)

    item = deps.row_value(source, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing"))
    if not item:
        item = deps.clean(source.get("custom_item"))
    pose = deps.row_value(source, "pose", ("Sexual pose", "Pose"))
    role_graph = deps.clean(source.get("source_role_graph") or source.get("role_graph"))
    scene = deps.row_value(source, "scene_text", ("Setting", "Scene"))
    if not scene:
        scene = deps.clean(source.get("scene"))
    expression = deps.row_value(source, "character_expression_text")
    if not expression:
        expression = deps.row_value(source, "expression", ("Facial expressions", "Facial expression"))
    composition = deps.row_value(source, "composition", ("Composition",))

    # An empty scene is passed through as-is rather than becoming "in ".
    scene_phrase = f"in {scene}" if scene else scene
    for value in (item, pose, role_graph, scene_phrase, expression, composition):
        take_raw(value)
    take_each(deps.camera_tags(source))

    signal_text = _row_explicit_signal_text(
        source,
        item=item,
        pose=pose,
        role_graph=role_graph,
        expression=expression,
        composition=composition,
        deps=deps,
    )
    take_each(deps.explicit_tags(signal_text, request.nude_weight))
    return SDXLTagRoute(collected)
def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the softcore SDXL tag list for a row paired with its root record.

    Starts from ``row_core_tags_result`` for the row, then appends:

    - count tags from ``_pair_counts``;
    - formatter hints for ``root``;
    - character tags from each of the root's ``shared_cast_descriptors``, or
      from its single ``shared_descriptor`` when that list is missing or
      empty;
    - the ``softcore_partner_styling`` outfits (joined with "; ") and pose;
    - ``softcore_pair_tags``;
    - the fixed tags "sexy" and "looking at viewer".

    ``softcore_partner_styling["outfits"]`` may be a list/tuple (or any other
    iterable), a single value such as a bare string, or ``None``/absent.
    """
    row = request.row
    root = request.root
    tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
    # Rebuild the de-duplication set from the core tags. Presumably tag_key
    # matches the keying add/add_one use internally — confirm.
    seen = {deps.tag_key(tag) for tag in tags}
    women, men = _pair_counts(row, root)
    for tag in deps.count_tag(women, men):
        deps.add_one(tags, seen, tag)
    for tag in deps.formatter_hint_tags(root):
        deps.add(tags, seen, tag)
    descriptors = root.get("shared_cast_descriptors")
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            for tag in deps.character_tags_from_descriptor(descriptor):
                deps.add_one(tags, seen, tag)
    else:
        descriptor = deps.clean(root.get("shared_descriptor"))
        for tag in deps.character_tags_from_descriptor(descriptor):
            deps.add_one(tags, seen, tag)
    partner = root.get("softcore_partner_styling")
    if isinstance(partner, dict):
        outfits = partner.get("outfits")
        if outfits is None:
            outfits = []
        elif isinstance(outfits, str) or not hasattr(outfits, "__iter__"):
            # A bare string or scalar is one outfit. Iterating a string would
            # otherwise join its individual characters ("r; e; d").
            outfits = [outfits]
        cleaned_outfits = [deps.clean(outfit) for outfit in outfits]
        deps.add(tags, seen, "; ".join(text for text in cleaned_outfits if text))
        deps.add(tags, seen, partner.get("pose"))
    for tag in deps.softcore_pair_tags(row, root):
        deps.add_one(tags, seen, tag)
    deps.add_one(tags, seen, "sexy")
    deps.add_one(tags, seen, "looking at viewer")
    return SDXLTagRoute(tags)
def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the hardcore SDXL tag list for a row paired with its root record.

    Tags are appended in this order:

    1. count tags from ``_pair_counts``, defaulting to one woman and one man
       when no counts are found;
    2. character tags from each of the root's ``shared_cast_descriptors``, or
       the row's own character tags when that list is missing or empty;
    3. metadata family tags and formatter hints for ``(row, root)``;
    4. the row's role graph, item, the root's ``hardcore_clothing_state``,
       the scene (as "in <scene>"), expression and composition;
    5. camera tags using the root's hardcore camera directive and config;
    6. ``deps.explicit_tags`` for the joined non-empty role, item, clothing,
       expression and composition text.
    """
    row = request.row
    root = request.root
    tags: list[str] = []
    seen: set[str] = set()
    women, men = _pair_counts(row, root)
    if not women and not men:
        # No usable counts anywhere: fall back to a one-woman, one-man pair.
        women, men = 1, 1
    for tag in deps.count_tag(women, men):
        deps.add_one(tags, seen, tag)
    descriptors = root.get("shared_cast_descriptors")
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            for tag in deps.character_tags_from_descriptor(descriptor):
                deps.add_one(tags, seen, tag)
    else:
        # A missing or empty cast list falls back to the row's character tags,
        # matching soft_tags_result. An empty list must not drop them.
        for tag in deps.normal_character_tags(row):
            deps.add_one(tags, seen, tag)
    for tag in deps.metadata_family_tags(row):
        deps.add_one(tags, seen, tag)
    for tag in deps.formatter_hint_tags(row, root):
        deps.add(tags, seen, tag)
    hard_scene = deps.clean(row.get("scene_text"))
    hard_item = deps.clean(row.get("item"))
    hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
    hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
    expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
    composition = deps.clean(row.get("composition"))
    for value in (
        hard_role,
        hard_item,
        hard_clothing,
        hard_scene and f"in {hard_scene}",
        expression,
        composition,
    ):
        deps.add(tags, seen, value)
    for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
        deps.add_one(tags, seen, tag)
    # Skip empty parts so the signal text has no runs of blank separators,
    # the same way _row_explicit_signal_text builds it.
    combined = " ".join(
        part for part in (hard_role, hard_item, hard_clothing, expression, composition) if part
    )
    for tag in deps.explicit_tags(combined, request.nude_weight):
        deps.add_one(tags, seen, tag)
    return SDXLTagRoute(tags)
def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
    """Return just the tag list built by ``row_core_tags_result``."""
    route = row_core_tags_result(request, deps)
    return route.tags
def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return the softcore tags from ``soft_tags_result`` as one comma-separated string."""
    route = soft_tags_result(request, deps)
    return route.as_text()
def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return the hardcore tags from ``hard_tags_result`` as one comma-separated string."""
    route = hard_tags_result(request, deps)
    return route.as_text()