Extract SDXL tag route assembly
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLRowTagRequest:
|
||||
row: dict[str, Any]
|
||||
nude_weight: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLPairTagRequest:
|
||||
row: dict[str, Any]
|
||||
root: dict[str, Any]
|
||||
nude_weight: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLTagRoute:
|
||||
tags: list[str]
|
||||
|
||||
def as_text(self) -> str:
|
||||
return ", ".join(self.tags)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLTagRouteDependencies:
|
||||
clean: Callable[[Any], str]
|
||||
row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str]
|
||||
tag_key: Callable[[str], str]
|
||||
add: Callable[[list[str], set[str], Any], None]
|
||||
add_one: Callable[[list[str], set[str], str], None]
|
||||
count_tag: Callable[[int, int], list[str]]
|
||||
infer_counts: Callable[[dict[str, Any]], tuple[int, int]]
|
||||
normal_character_tags: Callable[[dict[str, Any]], list[str]]
|
||||
character_tags_from_descriptor: Callable[[Any], list[str]]
|
||||
metadata_family_tags: Callable[[dict[str, Any]], list[str]]
|
||||
formatter_hint_tags: Callable[..., list[str]]
|
||||
camera_tags: Callable[..., list[str]]
|
||||
explicit_tags: Callable[[str, float], list[str]]
|
||||
|
||||
|
||||
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
women, men = deps.infer_counts(row)
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.normal_character_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.metadata_family_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
for tag in deps.formatter_hint_tags(row):
|
||||
deps.add(tags, seen, tag)
|
||||
|
||||
item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean(
|
||||
row.get("custom_item")
|
||||
)
|
||||
pose = deps.row_value(row, "pose", ("Sexual pose", "Pose"))
|
||||
role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene"))
|
||||
expression = deps.row_value(row, "character_expression_text") or deps.row_value(
|
||||
row,
|
||||
"expression",
|
||||
("Facial expressions", "Facial expression"),
|
||||
)
|
||||
composition = deps.row_value(row, "composition", ("Composition",))
|
||||
for value in (
|
||||
item,
|
||||
pose,
|
||||
role_graph,
|
||||
scene and f"in {scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
deps.add(tags, seen, value)
|
||||
for tag in deps.camera_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
combined = " ".join(deps.clean(value) for value in (item, pose, role_graph, row.get("prompt", "")))
|
||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||
deps.add_one(tags, seen, tag)
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
root = request.root
|
||||
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
|
||||
seen = {deps.tag_key(tag) for tag in tags}
|
||||
for tag in deps.formatter_hint_tags(root):
|
||||
deps.add(tags, seen, tag)
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
partner = root.get("softcore_partner_styling")
|
||||
if isinstance(partner, dict):
|
||||
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
|
||||
deps.add(tags, seen, partner.get("pose"))
|
||||
deps.add_one(tags, seen, "sexy")
|
||||
deps.add_one(tags, seen, "looking at viewer")
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
root = request.root
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
|
||||
except (TypeError, ValueError):
|
||||
women, men = 1, 1
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if isinstance(descriptors, list):
|
||||
for descriptor in descriptors:
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
else:
|
||||
for tag in deps.normal_character_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.metadata_family_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
for tag in deps.formatter_hint_tags(row, root):
|
||||
deps.add(tags, seen, tag)
|
||||
|
||||
hard_scene = deps.clean(row.get("scene_text"))
|
||||
hard_item = deps.clean(row.get("item"))
|
||||
hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
|
||||
expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
|
||||
composition = deps.clean(row.get("composition"))
|
||||
for value in (
|
||||
hard_role,
|
||||
hard_item,
|
||||
hard_clothing,
|
||||
hard_scene and f"in {hard_scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
deps.add(tags, seen, value)
|
||||
for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
|
||||
deps.add_one(tags, seen, tag)
|
||||
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""])
|
||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||
deps.add_one(tags, seen, tag)
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
|
||||
return row_core_tags_result(request, deps).tags
|
||||
|
||||
|
||||
def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
|
||||
return soft_tags_result(request, deps).as_text()
|
||||
|
||||
|
||||
def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
|
||||
return hard_tags_result(request, deps).as_text()
|
||||
Reference in New Issue
Block a user