from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Callable, Iterable

# Whole-word matchers used to infer cast counts from free-text descriptors.
# r"\bman\b" cannot match inside "woman": there is no word boundary between
# "wo" and "man".
_WOMAN_PATTERN = re.compile(r"\bwoman\b")
_MAN_PATTERN = re.compile(r"\bman\b")


@dataclass(frozen=True)
class SDXLRowTagRequest:
    """Input for building tags from a single row."""

    row: dict[str, Any]
    # Forwarded unchanged to ``deps.explicit_tags``.
    nude_weight: float


@dataclass(frozen=True)
class SDXLPairTagRequest:
    """Input for building soft/hard tags from a row plus its shared root record."""

    row: dict[str, Any]
    # Shared record read for cast descriptors, counts, partner styling and
    # camera settings.
    root: dict[str, Any]
    # Forwarded unchanged to ``deps.explicit_tags``.
    nude_weight: float


@dataclass(frozen=True)
class SDXLTagRoute:
    """Ordered tag list produced by one of the routing functions."""

    tags: list[str]

    def as_text(self) -> str:
        """Return the tags joined with ``", "``."""
        return ", ".join(self.tags)


@dataclass(frozen=True)
class SDXLTagRouteDependencies:
    """Bundle of injected callables used by the tag routes.

    The routes never normalize or deduplicate tags themselves; they delegate to
    ``add`` / ``add_one`` with a shared ``(tags, seen)`` pair.
    """

    # Normalizes an arbitrary value to a string (may be called with None).
    clean: Callable[[Any], str]
    # (row, key, fallback labels) -> string; called with only (row, key) for
    # "character_expression_text".
    row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str]
    # Key used for dedup; soft_tags_result rebuilds ``seen`` with it.
    tag_key: Callable[[str], str]
    # Adds an arbitrary value (None, "", free text); presumably normalizes or
    # splits it into tags — TODO confirm against the implementation.
    add: Callable[[list[str], set[str], Any], None]
    # Adds a single tag string.
    add_one: Callable[[list[str], set[str], str], None]
    # (women, men) -> count tags.
    count_tag: Callable[[int, int], list[str]]
    # row -> (women, men).
    infer_counts: Callable[[dict[str, Any]], tuple[int, int]]
    normal_character_tags: Callable[[dict[str, Any]], list[str]]
    character_tags_from_descriptor: Callable[[Any], list[str]]
    metadata_family_tags: Callable[[dict[str, Any]], list[str]]
    # Called as (row,), (root,) or (row, root) by the different routes.
    formatter_hint_tags: Callable[..., list[str]]
    # Called as (row,) or (row, camera_directive, camera_config).
    camera_tags: Callable[..., list[str]]
    # (combined signal text, nude_weight) -> tags.
    explicit_tags: Callable[[str, float], list[str]]
    softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]


def _add_each(
    deps: SDXLTagRouteDependencies,
    tags: list[str],
    seen: set[str],
    values: Iterable[str],
) -> None:
    """Feed each item of *values*, in order, to ``deps.add_one``."""
    for value in values:
        deps.add_one(tags, seen, value)


def _descriptor_counts(root: dict[str, Any]) -> tuple[int, int]:
    """Count women and men mentioned in ``root["shared_cast_descriptors"]``.

    Each descriptor is stringified and lower-cased. A descriptor containing the
    whole word "woman" counts as one woman. Otherwise, if it contains the whole
    word "man", it counts as one man. Returns ``(0, 0)`` when the key is missing
    or not a list.
    """
    descriptors = root.get("shared_cast_descriptors")
    if not isinstance(descriptors, list):
        return 0, 0
    women = 0
    men = 0
    for descriptor in descriptors:
        text = str(descriptor).lower()
        # A descriptor counts toward at most one bucket; "woman" is checked first.
        if _WOMAN_PATTERN.search(text):
            women += 1
        elif _MAN_PATTERN.search(text):
            men += 1
    return women, men


def _pair_counts(row: dict[str, Any], root: dict[str, Any]) -> tuple[int, int]:
    """Return ``(women, men)`` for a row/root pair.

    Counts come from ``root["hardcore_women_count"]`` / ``root["hardcore_men_count"]``.
    Any falsy root value falls back to ``row["women_count"]`` / ``row["men_count"]``.
    If either value fails ``int()`` conversion, both are treated as 0. When both
    counts end up 0, they are inferred from ``root["shared_cast_descriptors"]``.
    """
    try:
        women = int(root.get("hardcore_women_count") or row.get("women_count") or 0)
        men = int(root.get("hardcore_men_count") or row.get("men_count") or 0)
    except (TypeError, ValueError):
        women, men = 0, 0
    if women or men:
        return women, men
    return _descriptor_counts(root)


def _row_explicit_signal_text(
    row: dict[str, Any],
    *,
    item: str,
    pose: str,
    role_graph: str,
    expression: str,
    composition: str,
    deps: SDXLTagRouteDependencies,
) -> str:
    """Join the row's explicit-signal fields into one space-separated string.

    Combines the already-extracted item/pose/role graph/expression/composition
    with the row's clothing fields and ``scene_kind``. Every value is passed
    through ``deps.clean`` once, and empty results are dropped.
    """
    values = (
        item,
        pose,
        role_graph,
        row.get("hardcore_clothing_state"),
        row.get("clothing_state"),
        row.get("clothing"),
        row.get("scene_kind"),
        expression,
        composition,
    )
    cleaned = (deps.clean(value) for value in values)
    return " ".join(text for text in cleaned if text)


def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build the base tag list for a single row.

    Tags are added in this order:
    - count tags from ``deps.infer_counts``;
    - character tags and metadata family tags;
    - formatter hints;
    - the row's item, pose, role graph, scene (as ``"in <scene>"``),
      expression and composition;
    - camera tags;
    - the tags ``deps.explicit_tags`` derives from the combined signal text
      and ``request.nude_weight``.
    """
    row = request.row
    tags: list[str] = []
    seen: set[str] = set()

    women, men = deps.infer_counts(row)
    _add_each(deps, tags, seen, deps.count_tag(women, men))
    _add_each(deps, tags, seen, deps.normal_character_tags(row))
    _add_each(deps, tags, seen, deps.metadata_family_tags(row))
    for hint in deps.formatter_hint_tags(row):
        deps.add(tags, seen, hint)

    item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean(
        row.get("custom_item")
    )
    pose = deps.row_value(row, "pose", ("Sexual pose", "Pose"))
    role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
    scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene"))
    expression = deps.row_value(row, "character_expression_text") or deps.row_value(
        row,
        "expression",
        ("Facial expressions", "Facial expression"),
    )
    composition = deps.row_value(row, "composition", ("Composition",))

    # An empty scene stays "" (not "in "), so deps.add receives an empty value.
    for value in (item, pose, role_graph, scene and f"in {scene}", expression, composition):
        deps.add(tags, seen, value)
    _add_each(deps, tags, seen, deps.camera_tags(row))

    combined = _row_explicit_signal_text(
        row,
        item=item,
        pose=pose,
        role_graph=role_graph,
        expression=expression,
        composition=composition,
        deps=deps,
    )
    _add_each(deps, tags, seen, deps.explicit_tags(combined, request.nude_weight))
    return SDXLTagRoute(tags)


def _partner_outfits_text(partner: dict[str, Any], deps: SDXLTagRouteDependencies) -> str:
    """Join the partner's cleaned, non-empty outfits with ``"; "``.

    A bare string is treated as a single outfit. Any value other than a list,
    tuple or string (including a missing key or None) yields ``""``.
    """
    outfits = partner.get("outfits")
    if isinstance(outfits, str):
        outfits = [outfits]
    elif not isinstance(outfits, (list, tuple)):
        outfits = []
    cleaned = (deps.clean(outfit) for outfit in outfits)
    return "; ".join(text for text in cleaned if text)


def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build soft-route tags for a row/root pair.

    Starts from :func:`row_core_tags_result`, then adds, in order:
    - pair-level count tags;
    - the root's formatter hints;
    - cast descriptor tags, from a non-empty ``shared_cast_descriptors`` list,
      otherwise from ``root["shared_descriptor"]``;
    - partner styling (outfits and pose), when present as a dict;
    - ``deps.softcore_pair_tags``;
    - two fixed trailing tags.
    """
    row = request.row
    root = request.root
    tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
    # Rebuild the dedup set from the core tags; assumes deps.add_one records
    # deps.tag_key(tag) in ``seen`` — TODO confirm.
    seen = {deps.tag_key(tag) for tag in tags}

    women, men = _pair_counts(row, root)
    _add_each(deps, tags, seen, deps.count_tag(women, men))
    for hint in deps.formatter_hint_tags(root):
        deps.add(tags, seen, hint)

    descriptors = root.get("shared_cast_descriptors")
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            _add_each(deps, tags, seen, deps.character_tags_from_descriptor(descriptor))
    else:
        descriptor = deps.clean(root.get("shared_descriptor"))
        _add_each(deps, tags, seen, deps.character_tags_from_descriptor(descriptor))

    partner = root.get("softcore_partner_styling")
    if isinstance(partner, dict):
        deps.add(tags, seen, _partner_outfits_text(partner, deps))
        deps.add(tags, seen, partner.get("pose"))

    _add_each(deps, tags, seen, deps.softcore_pair_tags(row, root))
    deps.add_one(tags, seen, "sexy")
    deps.add_one(tags, seen, "looking at viewer")
    return SDXLTagRoute(tags)


def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
    """Build hard-route tags for a row/root pair.

    Unlike :func:`row_core_tags_result`, this route has four differences:
    - it reads raw row fields rather than ``deps.row_value``;
    - it takes the clothing state from the root;
    - it passes the root's camera directive/config to ``deps.camera_tags``;
    - it assumes one woman and one man when no counts can be determined.
    """
    row = request.row
    root = request.root
    tags: list[str] = []
    seen: set[str] = set()

    women, men = _pair_counts(row, root)
    if not women and not men:
        women, men = 1, 1
    _add_each(deps, tags, seen, deps.count_tag(women, men))

    descriptors = root.get("shared_cast_descriptors")
    # Only use the descriptor list when it actually has entries; an empty list
    # previously suppressed the row's own character tags entirely.
    if isinstance(descriptors, list) and descriptors:
        for descriptor in descriptors:
            _add_each(deps, tags, seen, deps.character_tags_from_descriptor(descriptor))
    else:
        _add_each(deps, tags, seen, deps.normal_character_tags(row))
    _add_each(deps, tags, seen, deps.metadata_family_tags(row))
    for hint in deps.formatter_hint_tags(row, root):
        deps.add(tags, seen, hint)

    hard_scene = deps.clean(row.get("scene_text"))
    hard_item = deps.clean(row.get("item"))
    hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
    hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
    expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
    composition = deps.clean(row.get("composition"))

    for value in (
        hard_role,
        hard_item,
        hard_clothing,
        hard_scene and f"in {hard_scene}",
        expression,
        composition,
    ):
        deps.add(tags, seen, value)
    camera = deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config"))
    _add_each(deps, tags, seen, camera)

    # Skip empty fields so the signal text has no stray spaces, matching
    # _row_explicit_signal_text.
    combined = " ".join(
        text for text in (hard_role, hard_item, hard_clothing, expression, composition) if text
    )
    _add_each(deps, tags, seen, deps.explicit_tags(combined, request.nude_weight))
    return SDXLTagRoute(tags)


def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
    """Return the tag list from :func:`row_core_tags_result`."""
    return row_core_tags_result(request, deps).tags


def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return :func:`soft_tags_result` joined with ``", "``."""
    return soft_tags_result(request, deps).as_text()


def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
    """Return :func:`hard_tags_result` joined with ``", "``."""
    return hard_tags_result(request, deps).as_text()