265 lines
9.1 KiB
Python
265 lines
9.1 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLRowTagRequest:
    """Immutable input for ``row_core_tags_result``: one prompt row plus its nude weight."""

    # Prompt row mapping; the routing functions read it with ``.get()``.
    row: dict[str, Any]
    # Forwarded unchanged to ``deps.explicit_tags``.
    nude_weight: float
|
|
|
|
|
@dataclass(frozen=True)
class SDXLPairTagRequest:
    """Immutable input for the paired soft/hard routes: a prompt row, its root record and nude weight."""

    # Prompt row mapping for this side of the pair.
    row: dict[str, Any]
    # Shared root record (cast descriptors, hardcore counts, camera settings, partner styling, ...).
    root: dict[str, Any]
    # Forwarded to ``deps.explicit_tags`` (directly, or via the core row route).
    nude_weight: float
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLTagRoute:
    """Ordered tag list produced by one of the SDXL tag routing functions."""

    # Final tags, in the order they were emitted.
    tags: list[str]

    def as_text(self) -> str:
        """Render the tags as a single comma-separated prompt string."""
        separator = ", "
        return separator.join(self.tags)
|
|
|
|
|
|
@dataclass(frozen=True)
class SDXLTagRouteDependencies:
    """Helper callables injected into the SDXL tag routing functions.

    Every routing function in this module receives one of these and calls only
    the helpers stored here, in addition to the module's own private helpers.
    """

    # Coerces a raw value to text; callers treat an empty result as "absent".
    clean: Callable[[Any], str]
    # (row, key[, labels]) -> text. Called both with and without the label
    # tuple, hence ``...`` rather than a fixed three-argument signature.
    row_value: Callable[..., str]
    # Dedup key for a tag; used to rebuild the ``seen`` set from existing tags.
    tag_key: Callable[[str], str]
    # (tags, seen, value) appender for free-form values (may receive "" or None).
    add: Callable[[list[str], set[str], Any], None]
    # (tags, seen, tag) appender for a single ready-made tag.
    add_one: Callable[[list[str], set[str], str], None]
    # (women, men) -> count tags.
    count_tag: Callable[[int, int], list[str]]
    # row -> (women, men).
    infer_counts: Callable[[dict[str, Any]], tuple[int, int]]
    # row -> character tags.
    normal_character_tags: Callable[[dict[str, Any]], list[str]]
    # One cast descriptor -> character tags.
    character_tags_from_descriptor: Callable[[Any], list[str]]
    # row -> metadata-family tags.
    metadata_family_tags: Callable[[dict[str, Any]], list[str]]
    # Called as (row), (root) or (row, root).
    formatter_hint_tags: Callable[..., list[str]]
    # row -> axis-value tags.
    axis_value_tags: Callable[[dict[str, Any]], list[str]]
    # Called as (row) or (row, camera_directive, camera_config).
    camera_tags: Callable[..., list[str]]
    # (signal text, nude weight) -> explicit tags.
    explicit_tags: Callable[[str, float], list[str]]
    # (tags, row) -> tags with route-incompatible entries removed.
    filter_incompatible_route_tags: Callable[[list[str], dict[str, Any]], list[str]]
    # (row, root) -> extra tags for the softcore side of a pair.
    softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
|
|
|
|
|
|
def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
    """Estimate (women, men) from ``root["shared_cast_descriptors"]``.

    Each descriptor is lower-cased and counted at most once. It counts as a
    woman if it contains the whole word "woman"; otherwise it counts as a man
    if it contains the whole word "man". A missing or non-list value yields
    ``(0, 0)``.
    """
    cast = root.get("shared_cast_descriptors")
    if not isinstance(cast, list):
        return 0, 0

    lowered = [str(entry).lower() for entry in cast]
    is_woman = [bool(re.search(r"\bwoman\b", text)) for text in lowered]
    women = sum(is_woman)
    # "woman" takes precedence, so only descriptors not already counted as a
    # woman are checked for "man".
    men = sum(
        1
        for text, woman in zip(lowered, is_woman)
        if not woman and re.search(r"\bman\b", text)
    )
    return women, men
|
|
|
|
|
|
def _non_negative_int(value: Any) -> int:
    """Coerce a count-like value to a non-negative int; unusable values yield 0."""
    try:
        return max(0, int(value or 0))
    except (TypeError, ValueError, OverflowError):
        return 0


def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
    """Return the (women, men) counts for a soft/hard prompt pair.

    For each count, the root's ``hardcore_*_count`` is preferred and the row's
    ``*_count`` is the fallback (a falsy root value falls through to the row).
    Each count is converted independently, so one malformed field no longer
    discards the other valid count; negative values are clamped to 0. If both
    counts end up 0, the counts are inferred from the root's shared cast
    descriptors via ``_descriptor_counts``.
    """
    women = _non_negative_int(root.get("hardcore_women_count") or row.get("women_count"))
    men = _non_negative_int(root.get("hardcore_men_count") or row.get("men_count"))
    if women or men:
        return women, men
    return _descriptor_counts(root)
|
|
|
|
|
|
def _composition_tags_text(text: str) -> str:
    """Turn a free-text composition description into tag-friendly text.

    Steps, all case-insensitive: strip a leading "vertical ", drop a trailing
    " composition", then rename any remaining whole-word "composition" to
    "frame". Surrounding spaces and commas are trimmed. ``None`` or an empty
    value gives "".
    """
    cleaned = str(text or "").strip()
    rewrites = (
        (r"^vertical\s+", ""),
        (r"\s+composition$", ""),
        (r"\bcomposition\b", "frame"),
    )
    for pattern, replacement in rewrites:
        cleaned = re.sub(pattern, replacement, cleaned, flags=re.IGNORECASE)
    return cleaned.strip(" ,")
|
|
|
|
|
|
def _row_explicit_signal_text(
    row: dict[str, Any],
    *,
    item: str,
    pose: str,
    role_graph: str,
    expression: str,
    composition: str,
    deps: SDXLTagRouteDependencies,
) -> str:
    """Join a row's explicit-signal fields into one space-separated string.

    The caller-supplied texts (item, pose, role graph, expression,
    composition) are combined with the row's ``hardcore_clothing_state``,
    ``clothing_state``, ``clothing`` and ``scene_kind`` fields. Every value is
    passed through ``deps.clean`` exactly once, and values that clean to an
    empty string are dropped. The result is the text passed to
    ``deps.explicit_tags`` by ``row_core_tags_result``.
    """
    values = (
        item,
        pose,
        role_graph,
        deps.clean(row.get("hardcore_clothing_state")),
        deps.clean(row.get("clothing_state")),
        deps.clean(row.get("clothing")),
        deps.clean(row.get("scene_kind")),
        expression,
        composition,
    )
    # Clean once per value. The previous filter/map form called deps.clean twice.
    cleaned = (deps.clean(value) for value in values)
    return " ".join(text for text in cleaned if text)
|
|
|
|
|
|
def _uses_hardcore_action_route(row: dict[str, Any]) -> bool:
    """Return True when *row* belongs to the hardcore action route.

    A row qualifies if its stripped ``category_slug`` is exactly
    ``"hardcore_sexual_poses"``, or if it has a non-blank ``action_family`` or
    ``position_family``.
    """
    slug = str(row.get("category_slug") or "").strip()
    if slug == "hardcore_sexual_poses":
        return True
    return any(str(row.get(field) or "").strip() for field in ("action_family", "position_family"))
|
|
|
|
|
|
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the core SDXL tag list for a single prompt row.

    Tags are emitted in a fixed order:

    1. count tags (from ``deps.infer_counts``);
    2. character, metadata-family, formatter-hint and axis-value tags;
    3. the row's free-text fields (item, pose, role graph, "in <scene>",
       expression, composition);
    4. camera tags;
    5. explicit tags derived from the combined row text.

    The collected list is then passed through
    ``deps.filter_incompatible_route_tags``. Appending, and any
    deduplication, is delegated to ``deps.add`` / ``deps.add_one``, which all
    share one ``seen`` set.
    """
    row = request.row
    collected: list[str] = []
    keys: set[str] = set()

    def push_each(adder: Callable[[list[str], set[str], Any], None], values: Any) -> None:
        # Feed every value through the given appender, sharing the dedup state.
        for value in values:
            adder(collected, keys, value)

    women, men = deps.infer_counts(row)
    push_each(deps.add_one, deps.count_tag(women, men))
    push_each(deps.add_one, deps.normal_character_tags(row))
    push_each(deps.add_one, deps.metadata_family_tags(row))
    push_each(deps.add, deps.formatter_hint_tags(row))
    push_each(deps.add, deps.axis_value_tags(row))

    item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing"))
    if not item:
        item = deps.clean(row.get("custom_item"))

    if _uses_hardcore_action_route(row):
        # Rows on the hardcore action route skip the generic pose text.
        pose = ""
    else:
        pose = deps.row_value(row, "pose", ("Sexual pose", "Pose"))

    role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))

    scene = deps.row_value(row, "scene_text", ("Setting", "Scene"))
    if not scene:
        scene = deps.clean(row.get("scene"))

    expression = deps.row_value(row, "character_expression_text")
    if not expression:
        expression = deps.row_value(row, "expression", ("Facial expressions", "Facial expression"))

    composition = _composition_tags_text(deps.row_value(row, "composition", ("Composition",)))

    scene_phrase = f"in {scene}" if scene else scene
    push_each(deps.add, (item, pose, role_graph, scene_phrase, expression, composition))
    push_each(deps.add_one, deps.camera_tags(row))

    signal_text = _row_explicit_signal_text(
        row,
        item=item,
        pose=pose,
        role_graph=role_graph,
        expression=expression,
        composition=composition,
        deps=deps,
    )
    push_each(deps.add_one, deps.explicit_tags(signal_text, request.nude_weight))
    return SDXLTagRoute(deps.filter_incompatible_route_tags(collected, row))
|
|
|
|
|
|
def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the softcore-side tag list for a paired (soft/hard) prompt.

    Starts from ``row_core_tags_result`` for the row, then appends, in order:

    - pair-level count tags (from ``_pair_counts``);
    - formatter hints computed from ``root``;
    - character tags from ``root["shared_cast_descriptors"]``, or from
      ``root["shared_descriptor"]`` when that list is missing or empty;
    - optional partner styling (outfits joined with "; ", then the partner pose);
    - ``deps.softcore_pair_tags`` and the fixed tags "sexy" and
      "looking at viewer".

    Unlike the row and hardcore routes, this route does not re-apply
    ``deps.filter_incompatible_route_tags`` after the pair-level additions.
    """
    row = request.row
    root = request.root
    tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
    # The core route returns only its tags, not its dedup set, so rebuild the
    # keys from those tags via deps.tag_key.
    seen = {deps.tag_key(tag) for tag in tags}

    women, men = _pair_counts(row, root)
    for tag in deps.count_tag(women, men):
        deps.add_one(tags, seen, tag)

    for tag in deps.formatter_hint_tags(root):
        deps.add(tags, seen, tag)

    descriptors = root.get("shared_cast_descriptors")
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            for tag in deps.character_tags_from_descriptor(descriptor):
                deps.add_one(tags, seen, tag)
    else:
        # No usable cast list: fall back to the single shared descriptor (may be "").
        descriptor = deps.clean(root.get("shared_descriptor"))
        for tag in deps.character_tags_from_descriptor(descriptor):
            deps.add_one(tags, seen, tag)

    partner = root.get("softcore_partner_styling")
    if isinstance(partner, dict):
        outfits = partner.get("outfits")
        # Normalize before iterating. Otherwise a bare string would be split
        # into single characters, and None or a scalar would raise TypeError.
        if isinstance(outfits, str):
            outfits = [outfits]
        elif not isinstance(outfits, (list, tuple, set, frozenset)):
            outfits = []
        cleaned_outfits = (deps.clean(outfit) for outfit in outfits)
        deps.add(tags, seen, "; ".join(text for text in cleaned_outfits if text))
        deps.add(tags, seen, partner.get("pose"))

    for tag in deps.softcore_pair_tags(row, root):
        deps.add_one(tags, seen, tag)
    deps.add_one(tags, seen, "sexy")
    deps.add_one(tags, seen, "looking at viewer")
    return SDXLTagRoute(tags)
|
|
|
|
|
|
def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the hardcore-side tag list for a paired (soft/hard) prompt.

    Emits, in order:

    - count tags (pair counts from ``_pair_counts``, defaulting to one woman
      and one man when no count is known);
    - character tags (from the root's shared cast descriptors, or the row's
      own character tags when that list is missing or empty);
    - metadata-family tags, then formatter hints (given ``row`` and ``root``),
      then axis-value tags;
    - the hardcore free-text fields;
    - camera tags (with the root's camera directive/config);
    - explicit tags derived from the combined non-empty text fields.

    The result is passed through ``deps.filter_incompatible_route_tags``.
    """
    row = request.row
    root = request.root
    tags: list[str] = []
    seen: set[str] = set()

    women, men = _pair_counts(row, root)
    if not women and not men:
        # No usable count anywhere: fall back to a one-woman, one-man pairing.
        women, men = 1, 1
    for tag in deps.count_tag(women, men):
        deps.add_one(tags, seen, tag)

    descriptors = root.get("shared_cast_descriptors")
    # An empty descriptor list carries no character information. Fall back to
    # the row's own character tags instead of emitting none at all.
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            for tag in deps.character_tags_from_descriptor(descriptor):
                deps.add_one(tags, seen, tag)
    else:
        for tag in deps.normal_character_tags(row):
            deps.add_one(tags, seen, tag)

    for tag in deps.metadata_family_tags(row):
        deps.add_one(tags, seen, tag)
    for tag in deps.formatter_hint_tags(row, root):
        deps.add(tags, seen, tag)
    for tag in deps.axis_value_tags(row):
        deps.add(tags, seen, tag)

    hard_scene = deps.clean(row.get("scene_text"))
    hard_item = deps.clean(row.get("item"))
    hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
    hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
    expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
    composition = _composition_tags_text(deps.clean(row.get("composition")))
    for value in (
        hard_role,
        hard_item,
        hard_clothing,
        hard_scene and f"in {hard_scene}",
        expression,
        composition,
    ):
        deps.add(tags, seen, value)

    for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
        deps.add_one(tags, seen, tag)

    # Skip empty fields so the signal text has no doubled or leading spaces.
    # This matches how the single-row route builds its explicit-signal text.
    signal_parts = (hard_role, hard_item, hard_clothing, expression, composition)
    combined = " ".join(part for part in signal_parts if part)
    for tag in deps.explicit_tags(combined, request.nude_weight):
        deps.add_one(tags, seen, tag)

    tags = deps.filter_incompatible_route_tags(tags, row)
    return SDXLTagRoute(tags)
|
|
|
|
|
|
def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
    """Return only the tag list produced by ``row_core_tags_result``."""
    route = row_core_tags_result(request, deps)
    return route.tags
|
|
|
|
|
|
def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return the softcore pair tags as one comma-separated prompt string."""
    route = soft_tags_result(request, deps)
    return route.as_text()
|
|
|
|
|
|
def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return the hardcore pair tags as one comma-separated prompt string."""
    route = hard_tags_result(request, deps)
    return route.as_text()
|