from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable @dataclass(frozen=True) class SDXLRowTagRequest: row: dict[str, Any] nude_weight: float @dataclass(frozen=True) class SDXLPairTagRequest: row: dict[str, Any] root: dict[str, Any] nude_weight: float @dataclass(frozen=True) class SDXLTagRoute: tags: list[str] def as_text(self) -> str: return ", ".join(self.tags) @dataclass(frozen=True) class SDXLTagRouteDependencies: clean: Callable[[Any], str] row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str] tag_key: Callable[[str], str] add: Callable[[list[str], set[str], Any], None] add_one: Callable[[list[str], set[str], str], None] count_tag: Callable[[int, int], list[str]] infer_counts: Callable[[dict[str, Any]], tuple[int, int]] normal_character_tags: Callable[[dict[str, Any]], list[str]] character_tags_from_descriptor: Callable[[Any], list[str]] metadata_family_tags: Callable[[dict[str, Any]], list[str]] formatter_hint_tags: Callable[..., list[str]] camera_tags: Callable[..., list[str]] explicit_tags: Callable[[str, float], list[str]] def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: row = request.row tags: list[str] = [] seen: set[str] = set() women, men = deps.infer_counts(row) for tag in deps.count_tag(women, men): deps.add_one(tags, seen, tag) for tag in deps.normal_character_tags(row): deps.add_one(tags, seen, tag) for tag in deps.metadata_family_tags(row): deps.add_one(tags, seen, tag) for tag in deps.formatter_hint_tags(row): deps.add(tags, seen, tag) item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean( row.get("custom_item") ) pose = deps.row_value(row, "pose", ("Sexual pose", "Pose")) role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph")) scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene")) expression = deps.row_value(row, "character_expression_text") or deps.row_value( row, "expression", ("Facial expressions", "Facial expression"), ) composition = deps.row_value(row, "composition", ("Composition",)) for value in ( item, pose, role_graph, scene and f"in {scene}", expression, composition, ): deps.add(tags, seen, value) for tag in deps.camera_tags(row): deps.add_one(tags, seen, tag) combined = " ".join(deps.clean(value) for value in (item, pose, role_graph, row.get("prompt", ""))) for tag in deps.explicit_tags(combined, request.nude_weight): deps.add_one(tags, seen, tag) return SDXLTagRoute(tags) def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: row = request.row root = request.root tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags seen = {deps.tag_key(tag) for tag in tags} for tag in deps.formatter_hint_tags(root): deps.add(tags, seen, tag) descriptor = deps.clean(root.get("shared_descriptor")) if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags): for tag in deps.character_tags_from_descriptor(descriptor): deps.add_one(tags, seen, tag) partner = root.get("softcore_partner_styling") if isinstance(partner, dict): deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item))) deps.add(tags, seen, partner.get("pose")) deps.add_one(tags, seen, "sexy") deps.add_one(tags, seen, "looking at viewer") return SDXLTagRoute(tags) def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: row = request.row root = request.root tags: list[str] = [] seen: set[str] = set() try: women = int(root.get("hardcore_women_count") or row.get("women_count") or 1) men = int(root.get("hardcore_men_count") or row.get("men_count") or 1) except (TypeError, ValueError): women, men = 1, 1 for tag in deps.count_tag(women, men): deps.add_one(tags, seen, tag) descriptors = root.get("shared_cast_descriptors") if isinstance(descriptors, list): for descriptor in descriptors: for tag in deps.character_tags_from_descriptor(descriptor): deps.add_one(tags, seen, tag) else: for tag in deps.normal_character_tags(row): deps.add_one(tags, seen, tag) for tag in deps.metadata_family_tags(row): deps.add_one(tags, seen, tag) for tag in deps.formatter_hint_tags(row, root): deps.add(tags, seen, tag) hard_scene = deps.clean(row.get("scene_text")) hard_item = deps.clean(row.get("item")) hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph")) hard_clothing = deps.clean(root.get("hardcore_clothing_state")) expression = deps.clean(row.get("character_expression_text") or row.get("expression")) composition = deps.clean(row.get("composition")) for value in ( hard_role, hard_item, hard_clothing, hard_scene and f"in {hard_scene}", expression, composition, ): deps.add(tags, seen, value) for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")): deps.add_one(tags, seen, tag) combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""]) for tag in deps.explicit_tags(combined, request.nude_weight): deps.add_one(tags, seen, tag) return SDXLTagRoute(tags) def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]: return row_core_tags_result(request, deps).tags def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str: return soft_tags_result(request, deps).as_text() def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str: return hard_tags_result(request, deps).as_text()