Use hardcore family metadata in SDXL and captions

This commit is contained in:
2026-06-26 16:43:31 +02:00
parent 8668dfec9d
commit 2f7c359fab
5 changed files with 108 additions and 16 deletions
+50
View File
@@ -5,8 +5,10 @@ import re
from typing import Any
try:
from .hardcore_action_metadata import normalize_hardcore_action_family
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
except ImportError: # Allows local smoke tests with `python -c`.
from hardcore_action_metadata import normalize_hardcore_action_family
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
@@ -39,6 +41,28 @@ SDXL_DEFAULT_NEGATIVE = (
"watermark, signature, text, logo, blurry, jpeg artifacts, censored, mosaic censor"
)
SDXL_ACTION_FAMILY_TAGS = {
"foreplay": ("foreplay", "body contact"),
"outercourse": ("outercourse", "non-penetrative sex"),
"oral": ("oral sex",),
"penetration": ("penetrative sex", "penetration"),
"toy_double": ("double penetration", "toy-assisted sex"),
"climax": ("climax", "semen"),
}
SDXL_POSITION_FAMILY_TAGS = {
"penetrative": ("penetrative sex",),
"foreplay": ("foreplay",),
"interaction": ("interaction",),
"manual": ("manual stimulation",),
"oral": ("oral sex",),
"outercourse": ("outercourse",),
"anal": ("anal sex",),
"climax": ("climax",),
"threesome": ("threesome",),
"group": ("group sex",),
}
PROMPT_FIELD_LABELS = (
"Ages",
"Body types",
@@ -183,6 +207,26 @@ def _add_one(tags: list[str], seen: set[str], tag: str) -> None:
seen.add(key)
def _metadata_family_tags(row: dict[str, Any]) -> list[str]:
tags: list[str] = []
action_family = normalize_hardcore_action_family(row.get("action_family"))
tags.extend(SDXL_ACTION_FAMILY_TAGS.get(action_family, ()))
position_family = _clean(row.get("position_family")).lower()
tags.extend(SDXL_POSITION_FAMILY_TAGS.get(position_family, ()))
position_keys = row.get("position_keys")
if isinstance(position_keys, list):
keys = position_keys
else:
keys = [row.get("position_key")]
for key in keys:
key_text = _clean(key)
if key_text:
tags.append(key_text.replace("_", " "))
return tags
def _combine_tags(*parts: Any) -> str:
tags: list[str] = []
seen: set[str] = set()
@@ -332,6 +376,9 @@ def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
for tag in _normal_character_tags(row):
_add_one(tags, seen, tag)
for tag in _metadata_family_tags(row):
_add_one(tags, seen, tag)
item = _row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or _clean(row.get("custom_item"))
pose = _row_value(row, "pose", ("Sexual pose", "Pose"))
role_graph = _clean(row.get("source_role_graph") or row.get("role_graph"))
@@ -404,6 +451,9 @@ def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) ->
for tag in _normal_character_tags(row):
_add_one(tags, seen, tag)
for tag in _metadata_family_tags(row):
_add_one(tags, seen, tag)
hard_scene = _clean(row.get("scene_text"))
hard_item = _clean(row.get("item"))
hard_role = _clean(row.get("source_role_graph") or row.get("role_graph"))