279 lines
9.8 KiB
Python
279 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import formatter_input as input_policy
|
|
from . import item_axis_policy
|
|
from . import route_metadata as route_metadata_policy
|
|
from . import sdxl_presets as sdxl_policy
|
|
from . import sdxl_tag_routes
|
|
from . import softcore_text_policy
|
|
except ImportError: # Allows local smoke tests with `python -c`.
|
|
import formatter_input as input_policy
|
|
import item_axis_policy
|
|
import route_metadata as route_metadata_policy
|
|
import sdxl_presets as sdxl_policy
|
|
import sdxl_tag_routes
|
|
import softcore_text_policy
|
|
|
|
|
|
# Prompt field labels reported by ``input_policy``. Computed once at import
# time and passed as ``field_labels=`` to the input_policy helpers used by
# this module.
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
|
|
|
|
|
|
def clean(value: Any) -> str:
    """Return ``input_policy.clean_text(value)``.

    Module-local alias so every helper here (and the route dependency bundle)
    goes through one text-cleaning entry point.
    """
    cleaned = input_policy.clean_text(value)
    return cleaned
|
|
|
|
|
|
def prompt_field(text: str, label: str) -> str:
    """Delegate to ``input_policy.prompt_field`` for *label* within *text*.

    The module-wide ``PROMPT_FIELD_LABELS`` are always supplied as the known
    field labels.
    """
    known_labels = PROMPT_FIELD_LABELS
    return input_policy.prompt_field(text, label, field_labels=known_labels)
|
|
|
|
|
|
def row_value(row: dict[str, Any], key: str, labels: tuple[str, ...] = ()) -> str:
    """Delegate to ``input_policy.row_value`` for *key* (and *labels*) in *row*.

    The module-wide ``PROMPT_FIELD_LABELS`` are always supplied as the known
    field labels.
    """
    known_labels = PROMPT_FIELD_LABELS
    return input_policy.row_value(row, key, labels, field_labels=known_labels)
|
|
|
|
|
|
# Cast labels ("Woman A", "Man B's") collapse to generic nouns, applied in order.
_CAST_LABEL_SUBSTITUTIONS = (
    (re.compile(r"\bWoman [A-Z]'s\b"), "woman's"),
    (re.compile(r"\bMan [A-Z]'s\b"), "man's"),
    (re.compile(r"\bWoman [A-Z]\b"), "woman"),
    (re.compile(r"\bMan [A-Z]\b"), "man"),
)
# Section prefixes that carry no tag content of their own.
_SECTION_LABEL_RE = re.compile(
    r"\b(?:Clothing state|Visual clothing state|visible remaining styling|teaser outfit detail|softcore visual reference|Sexual scene|Role graph):\s*",
    re.IGNORECASE,
)
# "and"/"with" are treated as tag separators, like "," and ";".
_CONJUNCTION_RE = re.compile(r"\b(?:and|with)\b", re.IGNORECASE)
_TAG_SEPARATOR_RE = re.compile(r"\s*[,;]\s*")


def split_tag_text(text: Any) -> list[str]:
    """Split free-form prompt text into individual tag strings.

    Steps, in order:
      1. Prompt field labels are stripped via ``input_policy``.
      2. Cast labels such as "Woman A" / "Man B's" become "woman" / "man's".
      3. Known section prefixes ("Clothing state:", "Role graph:", ...) are
         removed.
      4. The words "and"/"with" act as separators alongside "," and ";".

    Each piece is cleaned once and trimmed of spaces and periods. Empty pieces
    are dropped. Returns ``[]`` when *text* cleans to an empty string.
    """
    text = clean(text)
    if not text:
        return []
    text = input_policy.strip_prompt_field_labels(text, field_labels=PROMPT_FIELD_LABELS)
    for pattern, replacement in _CAST_LABEL_SUBSTITUTIONS:
        text = pattern.sub(replacement, text)
    text = _SECTION_LABEL_RE.sub("", text)
    text = _CONJUNCTION_RE.sub(",", text)

    tags: list[str] = []
    for part in _TAG_SEPARATOR_RE.split(text):
        # Clean each piece once (previously computed twice per part).
        tag = clean(part).strip(" .")
        if tag:
            tags.append(tag)
    return tags
|
|
|
|
|
|
def tag_key(tag: str) -> str:
    """Return the case-insensitive de-duplication key for *tag*.

    A fully weighted tag of the form ``(text:1.15)`` reduces to ``text``.
    Any remaining surrounding parentheses and spaces are stripped.
    """
    lowered = clean(tag).lower()
    unweighted = re.sub(r"^\((.*?):[0-9.]+\)$", r"\1", lowered)
    return unweighted.strip("() ")
|
|
|
|
|
|
def add(tags: list[str], seen: set[str], value: Any) -> None:
    """Split *value* into tags and append each whose key has not been seen.

    Keys come from ``tag_key``. Empty keys are skipped. *tags* and *seen* are
    mutated in place.
    """
    for candidate in split_tag_text(value):
        candidate_key = tag_key(candidate)
        if not candidate_key or candidate_key in seen:
            continue
        seen.add(candidate_key)
        tags.append(candidate)
|
|
|
|
|
|
def add_one(tags: list[str], seen: set[str], tag: str) -> None:
    """Append *tag* as a single unit (no splitting) if its key is new.

    The tag is cleaned and trimmed of spaces and commas first. Empty tags and
    empty keys are ignored. *tags* and *seen* are mutated in place.
    """
    normalized = clean(tag).strip(" ,")
    normalized_key = tag_key(normalized)
    if not (normalized and normalized_key) or normalized_key in seen:
        return
    tags.append(normalized)
    seen.add(normalized_key)
|
|
|
|
|
|
def metadata_family_tags(row: dict[str, Any]) -> list[str]:
    """Return tags derived from the row's route metadata.

    Order is:
      1. The SDXL action-family tags.
      2. The SDXL position-family tags.
      3. Each non-empty position key (unknown keys included), with underscores
         turned into spaces.

    The result is not de-duplicated.
    """
    action_family = route_metadata_policy.row_action_family(row)
    tags = list(sdxl_policy.SDXL_ACTION_FAMILY_TAGS.get(action_family, ()))

    position_family = route_metadata_policy.row_position_family(row)
    tags += sdxl_policy.SDXL_POSITION_FAMILY_TAGS.get(position_family, ())

    position_keys = route_metadata_policy.row_position_keys(row, include_unknown=True)
    cleaned_keys = (clean(raw_key) for raw_key in position_keys)
    tags.extend(key_text.replace("_", " ") for key_text in cleaned_keys if key_text)
    return tags
|
|
|
|
|
|
def formatter_hint_tags(*rows: dict[str, Any]) -> list[str]:
    """Collect the distinct SDXL formatter hints of every dict in *rows*.

    Non-dict arguments are skipped. Each hint is cleaned and trimmed of spaces,
    commas and periods. Empty hints and exact repeats are dropped, and
    first-seen order is kept.
    """
    hints: list[str] = []
    dict_rows = (candidate for candidate in rows if isinstance(candidate, dict))
    for row in dict_rows:
        for raw_hint in route_metadata_policy.row_formatter_hints(row, "sdxl"):
            hint = clean(raw_hint).strip(" ,.")
            if hint and hint not in hints:
                hints.append(hint)
    return hints
|
|
|
|
|
|
def axis_value_tags(row: dict[str, Any]) -> list[str]:
    """Return de-duplicated tags split from every axis value text of *row*."""
    collected: list[str] = []
    seen_keys: set[str] = set()
    for axis_text in item_axis_policy.row_axis_value_texts(row):
        add(collected, seen_keys, axis_text)
    return collected
|
|
|
|
|
|
def combine_tags(*parts: Any) -> str:
    """Split every part into tags, de-duplicate by ``tag_key``, join with ", "."""
    ordered: list[str] = []
    keys: set[str] = set()
    for fragment in parts:
        add(ordered, keys, fragment)
    return ", ".join(ordered)
|
|
|
|
|
|
def combine_negative(*parts: Any) -> str:
    """Combine negative-prompt parts, skipping those that clean to empty."""
    non_empty = [fragment for fragment in parts if clean(fragment)]
    return combine_tags(*non_empty)
|
|
|
|
|
|
def count_tag(women_count: int = 0, men_count: int = 0) -> list[str]:
    """Return subject-count tags such as ``"1woman"``, ``"2women"``, ``"1man"``.

    Women come before men. Counts that are not positive contribute no tag.
    """
    spec = (
        (women_count, "woman", "women"),
        (men_count, "man", "men"),
    )
    return [
        f"{count}{singular if count == 1 else plural}"
        for count, singular, plural in spec
        if count > 0
    ]
|
|
|
|
|
|
def infer_counts(row: dict[str, Any]) -> tuple[int, int]:
    """Return ``(women, men)`` for *row*.

    Explicit ``women_count`` / ``men_count`` fields win when either is
    non-zero. If either value cannot be converted with ``int()``, both fall
    back to 0.

    Otherwise the counts are guessed from the lower-cased ``subject_type`` (or
    ``primary_subject``) plus ``subject_phrase`` text:
      - "two women" -> (2, 0)
      - "two men" -> (0, 2)
      - a mixed pair -> (1, 1)
      - "group" -> (2, 2)
      - a man-only subject -> (0, 1)
      - anything else -> one woman, (1, 0)
    """
    try:
        women = int(row.get("women_count") or 0)
        men = int(row.get("men_count") or 0)
    except (TypeError, ValueError):
        women = men = 0
    if women or men:
        return women, men

    subject = clean(row.get("subject_type") or row.get("primary_subject")).lower()
    phrase = clean(row.get("subject_phrase")).lower()
    text = f"{subject} {phrase}"

    if "two women" in text:
        return 2, 0
    if "two men" in text:
        return 0, 2
    # A mixed pair is either spelled out ("woman and ...") or carries both a
    # "Woman X" and a "Man X" cast label. The former substring test
    # `"woman a" in text and "man a" in text` was broken twice over:
    #   - "man a" is a substring of every "woman a", so the second check was
    #     always true.
    #   - "woman a" matches "woman alone", "woman applying ...".
    # Word boundaries fix both problems.
    if "woman and" in text or (
        re.search(r"\bwoman [a-z]\b", text) and re.search(r"\bman [a-z]\b", text)
    ):
        return 1, 1
    if "group" in text:
        return 2, 2
    if "man" in text and "woman" not in text:
        return 0, 1
    return 1, 0
|
|
|
|
|
|
def character_tags_from_descriptor(descriptor: Any) -> list[str]:
    """Turn a cast descriptor sentence into character tags.

    Cast labels ("Woman A / primary creator:", "Man B:") are removed and
    "alongside" becomes a separator before splitting. In each resulting tag,
    the word "figure" becomes "build" and "adult adult" collapses to "adult".
    """
    text = clean(descriptor)
    prefix_rules = (
        (r"\bWoman [A-Z]\s*/\s*primary creator:\s*", "", 0),
        (r"\b(?:Woman|Man) [A-Z]:\s*", "", 0),
        (r"\balongside\b", ",", re.IGNORECASE),
    )
    for pattern, replacement, flags in prefix_rules:
        text = re.sub(pattern, replacement, text, flags=flags)
    return [
        re.sub(r"\bfigure\b", "build", part, flags=re.IGNORECASE).replace("adult adult", "adult")
        for part in split_tag_text(text)
    ]
|
|
|
|
|
|
def normal_character_tags(row: dict[str, Any]) -> list[str]:
    """Return character tags for *row*.

    If ``cast_descriptor_text`` is present, it is parsed with
    ``character_tags_from_descriptor``. Otherwise the age, subject, body, skin,
    hair and eye fields are collected, each from the first truthy key in its
    fallback list. Empty values and the bare values "woman", "man" and
    "single_any" are dropped.
    """
    descriptor = clean(row.get("cast_descriptor_text"))
    if descriptor:
        return character_tags_from_descriptor(descriptor)

    def pick(*keys: str) -> Any:
        # Same result as `row.get(k1) or row.get(k2) or ...`: the first truthy
        # value, or the last lookup when none is truthy.
        value = None
        for key in keys:
            value = row.get(key)
            if value:
                break
        return value

    field_fallbacks = (
        ("age", "age_band"),
        ("subject_phrase", "subject_type", "primary_subject"),
        ("body_phrase", "body", "body_type"),
        ("skin",),
        ("hair",),
        ("eyes",),
    )
    generic_values = ("woman", "man", "single_any")
    tags: list[str] = []
    for keys in field_fallbacks:
        text = clean(pick(*keys))
        if text and text not in generic_values:
            tags.append(text)
    return tags
|
|
|
|
|
|
def camera_tags_from_config(config: Any) -> list[str]:
    """Build camera tags from a camera config dict.

    Returns ``[]`` in three cases: *config* is not a dict,
    ``camera_detail == "off"``, or ``camera_mode == "disabled"``.

    Otherwise the tags are, in order:
      1. Tags split from ``custom_camera_prompt``.
      2. Tags split from the orbit direction, elevation and distance labels.
      3. The angle / shot size / distance / lens / orientation / subject focus
         settings, with underscores turned into spaces.

    Values equal to "auto" (and empty values) are skipped.
    """
    if not isinstance(config, dict):
        return []
    detail_off = clean(config.get("camera_detail")) == "off"
    if detail_off or clean(config.get("camera_mode")) == "disabled":
        return []

    tags = split_tag_text(clean(config.get("custom_camera_prompt")))

    orbit_labels = [
        clean(config.get(name))
        for name in ("orbit_direction", "orbit_elevation_label", "orbit_distance_label")
    ]
    for label in orbit_labels:
        if label and label != "auto":
            tags.extend(split_tag_text(label))

    setting_names = ("angle", "shot_size", "distance", "lens", "orientation", "subject_focus")
    for name in setting_names:
        setting = clean(config.get(name)).replace("_", " ")
        if setting and setting != "auto":
            tags.append(setting)
    return tags
|
|
|
|
|
|
def camera_tags(row: dict[str, Any], directive: Any = "", config: Any = None) -> list[str]:
    """Collect and normalize camera tags for *row*.

    Sources, in order:
      1. *directive*.
      2. The camera config (*config*, or ``row["camera_config"]`` when *config*
         is None).
      3. ``row["camera_directive"]``.

    Front views are rewritten to the weighted ``(front facing:1.15)`` tag.
    Left/right side views become plain ``side view``. No de-duplication is
    done here.
    """
    raw_tags = split_tag_text(directive)
    raw_tags.extend(camera_tags_from_config(row.get("camera_config") if config is None else config))
    camera_directive = clean(row.get("camera_directive"))
    if camera_directive:
        raw_tags.extend(split_tag_text(camera_directive))

    # Applied sequentially as plain substring replacements.
    # NOTE(review): being substring-based, a label like "90-degree front view"
    # would become "9(front facing:1.15)"; confirm such labels never occur.
    rewrites = (
        ("0-degree front view", "(front facing:1.15)"),
        ("front view", "(front facing:1.15)"),
        ("right side view", "side view"),
        ("left side view", "side view"),
    )
    normalized: list[str] = []
    for tag in raw_tags:
        for old, new in rewrites:
            tag = tag.replace(old, new)
        normalized.append(tag)
    return normalized
|
|
|
|
|
|
def explicit_tags(text: str, nude_weight: float) -> list[str]:
    """Map explicit wording in *text* to SDXL explicit-content tags.

    Matching is case-insensitive. The nudity tag carries *nude_weight* with two
    decimals, e.g. ``"(naked:1.20)"``. Each tag appears at most once, in the
    order the checks below run.

    Short ambiguous words are matched as whole words so that ordinary
    vocabulary does not inject explicit tags. Examples: "floral", "coral",
    "canal", "analysis", "cocktail", "peacock", "basement", "modicum",
    "Portsmouth".
    """
    lower = text.lower()
    tags: list[str] = []
    if any(token in lower for token in ("fully nude", "fully exposed", "naked", "bare skin unobstructed", "explicit_nude")):
        tags.append(f"(naked:{nude_weight:.2f})")
    if any(token in lower for token in ("nipples", "breasts exposed", "bare breasts", "nipple")):
        tags.append("nipples")
    if any(token in lower for token in ("pussy", "vulva", "genitals")):
        tags.append("pussy")
    # Whole-word "cock(s)": a bare substring matched "cocktail", "peacock", ...
    if "penis" in lower or re.search(r"\bcocks?\b", lower):
        tags.append("penis")
    if "penetration" in lower or "thrust" in lower:
        tags.append("penetration")
    if "vaginal" in lower:
        tags.append("pussy")
    # "oral" is a substring of "floral"/"coral"/"moral"; "mouth" of "Portsmouth".
    if re.search(r"\boral(?:ly)?\b|\bmouth", lower):
        tags.append("oral sex")
    # "anal" is a substring of "canal", "analysis", "banal", ...
    if re.search(r"\banal(?:ly)?\b", lower):
        tags.append("anal sex")
    # "semen" hides inside "basement". The old "cum " probe matched "modicum "
    # but missed "cum" before punctuation or at the end of the text.
    if re.search(r"\bsemen\b", lower) or "ejaculates" in lower or re.search(r"\bcum\b", lower):
        tags.append("semen")
    # "vaginal" can re-emit "pussy"; keep only the first occurrence of each tag.
    return list(dict.fromkeys(tags))
|
|
|
|
|
|
def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]:
    """Return the tags marking a softcore teaser frame.

    Always starts with "softcore teaser" and the softcore style tag.

    Adds "same-cast creator frame" in either of these cases:
      - ``options.softcore_cast`` is "same_as_hardcore" (case-insensitive);
      - ``root["shared_cast_descriptors"]`` is truthy.

    Otherwise adds "solo creator frame" when the row's subject type (or
    primary subject) mentions "solo".
    """
    base = ["softcore teaser", softcore_text_policy.softcore_style_tag()]
    options = root.get("options")
    if not isinstance(options, dict):
        options = {}
    cast_mode = clean(options.get("softcore_cast")).lower()
    if cast_mode == "same_as_hardcore" or root.get("shared_cast_descriptors"):
        return [*base, "same-cast creator frame"]
    subject = clean(row.get("subject_type") or row.get("primary_subject")).lower()
    if "solo" in subject:
        return [*base, "solo creator frame"]
    return base
|
|
|
|
|
|
def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
    """Bundle this module's tag helpers into ``SDXLTagRouteDependencies``.

    Each helper is passed by keyword under its own function name.
    """
    helpers = {
        "clean": clean,
        "row_value": row_value,
        "tag_key": tag_key,
        "add": add,
        "add_one": add_one,
        "count_tag": count_tag,
        "infer_counts": infer_counts,
        "normal_character_tags": normal_character_tags,
        "character_tags_from_descriptor": character_tags_from_descriptor,
        "metadata_family_tags": metadata_family_tags,
        "formatter_hint_tags": formatter_hint_tags,
        "axis_value_tags": axis_value_tags,
        "camera_tags": camera_tags,
        "explicit_tags": explicit_tags,
        "softcore_pair_tags": softcore_pair_tags,
    }
    return sdxl_tag_routes.SDXLTagRouteDependencies(**helpers)
|