Extract SDXL tag route assembly

This commit is contained in:
2026-06-27 11:26:07 +02:00
parent 09fc31f078
commit 0ccb87799b
5 changed files with 281 additions and 103 deletions
+9 -2
View File
@@ -396,9 +396,16 @@ Keep here:
- trigger behavior; - trigger behavior;
- style and quality presets; - style and quality presets;
- tag ordering; - final style/body/quality prompt assembly;
- weighted explicit tags; - nude-weight setting and explicit-tag helper policy;
- negative-prompt assembly. - negative-prompt assembly.
Already isolated:
- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard
tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`,
`SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps
compatibility wrappers plus final style/quality/trigger assembly.
- metadata-family tag hints from `action_family`, `position_family`, and - metadata-family tag hints from `action_family`, `position_family`, and
`position_keys`. `position_keys`.
- shared row route metadata reads from `route_metadata.py`. - shared row route metadata reads from `route_metadata.py`.
+7 -5
View File
@@ -699,8 +699,10 @@ not parse metadata. That is a wiring/input-hint issue, not a prompt pool issue.
`format_sdxl_prompt` chooses between: `format_sdxl_prompt` chooses between:
- Pair metadata: `_soft_tags` and `_hard_tags`. - Pair metadata: `sdxl_tag_routes.soft_tags_result` and
- Normal metadata row: `_row_core_tags`. `sdxl_tag_routes.hard_tags_result` through compatibility wrappers.
- Normal metadata row: `sdxl_tag_routes.row_core_tags_result` through the
`_row_core_tags` compatibility wrapper.
- Plain text fallback: `_fallback_text_to_sdxl`. - Plain text fallback: `_fallback_text_to_sdxl`.
Use this route for style triggers, weighted tag style, nude weighting, formatter Use this route for style triggers, weighted tag style, nude weighting, formatter
@@ -710,9 +712,9 @@ SDXL field consumption:
| Branch | Reads most from | Key functions | | Branch | Reads most from | Key functions |
| --- | --- | --- | | --- | --- | --- |
| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `_row_core_tags`, `_metadata_family_tags`, `_camera_tags` | | Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `sdxl_tag_routes.row_core_tags_result`, `_metadata_family_tags`, `_camera_tags` |
| Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `_soft_tags` | | Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `sdxl_tag_routes.soft_tags_result` |
| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `_hard_tags`, `_metadata_family_tags` | | Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `sdxl_tag_routes.hard_tags_result`, `_metadata_family_tags` |
| Text fallback | `source_text`, preserve-trigger setting, shared field-label stripping | `_fallback_text_to_sdxl` | | Text fallback | `source_text`, preserve-trigger setting, shared field-label stripping | `_fallback_text_to_sdxl` |
SDXL is the right place for model trigger handling, tag ordering, weight syntax, SDXL is the right place for model trigger handling, tag ordering, weight syntax,
+32 -96
View File
@@ -6,11 +6,13 @@ from typing import Any
try: try:
from . import formatter_input as input_policy from . import formatter_input as input_policy
from . import route_metadata as route_metadata_policy from . import route_metadata as route_metadata_policy
from . import sdxl_tag_routes
from . import sdxl_presets as sdxl_policy from . import sdxl_presets as sdxl_policy
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
except ImportError: # Allows local smoke tests with `python -c`. except ImportError: # Allows local smoke tests with `python -c`.
import formatter_input as input_policy import formatter_input as input_policy
import route_metadata as route_metadata_policy import route_metadata as route_metadata_policy
import sdxl_tag_routes
import sdxl_presets as sdxl_policy import sdxl_presets as sdxl_policy
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
@@ -283,43 +285,29 @@ def _explicit_tags(text: str, nude_weight: float) -> list[str]:
return tags return tags
def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
return sdxl_tag_routes.SDXLTagRouteDependencies(
clean=_clean,
row_value=_row_value,
tag_key=_tag_key,
add=_add,
add_one=_add_one,
count_tag=_count_tag,
infer_counts=_infer_counts,
normal_character_tags=_normal_character_tags,
character_tags_from_descriptor=_character_tags_from_descriptor,
metadata_family_tags=_metadata_family_tags,
formatter_hint_tags=_formatter_hint_tags,
camera_tags=_camera_tags,
explicit_tags=_explicit_tags,
)
def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]: def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
tags: list[str] = [] return sdxl_tag_routes.row_core_tags(
seen: set[str] = set() sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight),
women, men = _infer_counts(row) _sdxl_tag_route_dependencies(),
for tag in _count_tag(women, men): )
_add_one(tags, seen, tag)
for tag in _normal_character_tags(row):
_add_one(tags, seen, tag)
for tag in _metadata_family_tags(row):
_add_one(tags, seen, tag)
for tag in _formatter_hint_tags(row):
_add(tags, seen, tag)
item = _row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or _clean(row.get("custom_item"))
pose = _row_value(row, "pose", ("Sexual pose", "Pose"))
role_graph = _clean(row.get("source_role_graph") or row.get("role_graph"))
scene = _row_value(row, "scene_text", ("Setting", "Scene")) or _clean(row.get("scene"))
expression = _row_value(row, "character_expression_text") or _row_value(row, "expression", ("Facial expressions", "Facial expression"))
composition = _row_value(row, "composition", ("Composition",))
for value in (
item,
pose,
role_graph,
scene and f"in {scene}",
expression,
composition,
):
_add(tags, seen, value)
for tag in _camera_tags(row):
_add_one(tags, seen, tag)
combined = " ".join(_clean(value) for value in (item, pose, role_graph, row.get("prompt", "")))
for tag in _explicit_tags(combined, nude_weight):
_add_one(tags, seen, tag)
return tags
def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str: def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str:
@@ -341,69 +329,17 @@ def _quality_tail(quality_preset: str, custom_quality: str) -> str:
def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str: def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
tags = _row_core_tags(row, nude_weight) return sdxl_tag_routes.soft_tags(
seen = {_tag_key(tag) for tag in tags} sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
for tag in _formatter_hint_tags(root): _sdxl_tag_route_dependencies(),
_add(tags, seen, tag) )
descriptor = _clean(root.get("shared_descriptor"))
if descriptor and not any("woman" in _tag_key(tag) for tag in tags):
for tag in _character_tags_from_descriptor(descriptor):
_add_one(tags, seen, tag)
partner = root.get("softcore_partner_styling")
if isinstance(partner, dict):
_add(tags, seen, "; ".join(_clean(item) for item in partner.get("outfits", []) if _clean(item)))
_add(tags, seen, partner.get("pose"))
_add_one(tags, seen, "sexy")
_add_one(tags, seen, "looking at viewer")
return ", ".join(tags)
def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str: def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
tags: list[str] = [] return sdxl_tag_routes.hard_tags(
seen: set[str] = set() sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
try: _sdxl_tag_route_dependencies(),
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1) )
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
except (TypeError, ValueError):
women, men = 1, 1
for tag in _count_tag(women, men):
_add_one(tags, seen, tag)
descriptors = root.get("shared_cast_descriptors")
if isinstance(descriptors, list):
for descriptor in descriptors:
for tag in _character_tags_from_descriptor(descriptor):
_add_one(tags, seen, tag)
else:
for tag in _normal_character_tags(row):
_add_one(tags, seen, tag)
for tag in _metadata_family_tags(row):
_add_one(tags, seen, tag)
for tag in _formatter_hint_tags(row, root):
_add(tags, seen, tag)
hard_scene = _clean(row.get("scene_text"))
hard_item = _clean(row.get("item"))
hard_role = _clean(row.get("source_role_graph") or row.get("role_graph"))
hard_clothing = _clean(root.get("hardcore_clothing_state"))
expression = _clean(row.get("character_expression_text") or row.get("expression"))
composition = _clean(row.get("composition"))
for value in (
hard_role,
hard_item,
hard_clothing,
hard_scene and f"in {hard_scene}",
expression,
composition,
):
_add(tags, seen, value)
for tag in _camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
_add_one(tags, seen, tag)
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""])
for tag in _explicit_tags(combined, nude_weight):
_add_one(tags, seen, tag)
return ", ".join(tags)
def _assemble_prompt( def _assemble_prompt(
+170
View File
@@ -0,0 +1,170 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
@dataclass(frozen=True)
class SDXLRowTagRequest:
row: dict[str, Any]
nude_weight: float
@dataclass(frozen=True)
class SDXLPairTagRequest:
row: dict[str, Any]
root: dict[str, Any]
nude_weight: float
@dataclass(frozen=True)
class SDXLTagRoute:
tags: list[str]
def as_text(self) -> str:
return ", ".join(self.tags)
@dataclass(frozen=True)
class SDXLTagRouteDependencies:
clean: Callable[[Any], str]
row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str]
tag_key: Callable[[str], str]
add: Callable[[list[str], set[str], Any], None]
add_one: Callable[[list[str], set[str], str], None]
count_tag: Callable[[int, int], list[str]]
infer_counts: Callable[[dict[str, Any]], tuple[int, int]]
normal_character_tags: Callable[[dict[str, Any]], list[str]]
character_tags_from_descriptor: Callable[[Any], list[str]]
metadata_family_tags: Callable[[dict[str, Any]], list[str]]
formatter_hint_tags: Callable[..., list[str]]
camera_tags: Callable[..., list[str]]
explicit_tags: Callable[[str, float], list[str]]
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
row = request.row
tags: list[str] = []
seen: set[str] = set()
women, men = deps.infer_counts(row)
for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag)
for tag in deps.normal_character_tags(row):
deps.add_one(tags, seen, tag)
for tag in deps.metadata_family_tags(row):
deps.add_one(tags, seen, tag)
for tag in deps.formatter_hint_tags(row):
deps.add(tags, seen, tag)
item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean(
row.get("custom_item")
)
pose = deps.row_value(row, "pose", ("Sexual pose", "Pose"))
role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene"))
expression = deps.row_value(row, "character_expression_text") or deps.row_value(
row,
"expression",
("Facial expressions", "Facial expression"),
)
composition = deps.row_value(row, "composition", ("Composition",))
for value in (
item,
pose,
role_graph,
scene and f"in {scene}",
expression,
composition,
):
deps.add(tags, seen, value)
for tag in deps.camera_tags(row):
deps.add_one(tags, seen, tag)
combined = " ".join(deps.clean(value) for value in (item, pose, role_graph, row.get("prompt", "")))
for tag in deps.explicit_tags(combined, request.nude_weight):
deps.add_one(tags, seen, tag)
return SDXLTagRoute(tags)
def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
row = request.row
root = request.root
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
seen = {deps.tag_key(tag) for tag in tags}
for tag in deps.formatter_hint_tags(root):
deps.add(tags, seen, tag)
descriptor = deps.clean(root.get("shared_descriptor"))
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag)
partner = root.get("softcore_partner_styling")
if isinstance(partner, dict):
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
deps.add(tags, seen, partner.get("pose"))
deps.add_one(tags, seen, "sexy")
deps.add_one(tags, seen, "looking at viewer")
return SDXLTagRoute(tags)
def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
row = request.row
root = request.root
tags: list[str] = []
seen: set[str] = set()
try:
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
except (TypeError, ValueError):
women, men = 1, 1
for tag in deps.count_tag(women, men):
deps.add_one(tags, seen, tag)
descriptors = root.get("shared_cast_descriptors")
if isinstance(descriptors, list):
for descriptor in descriptors:
for tag in deps.character_tags_from_descriptor(descriptor):
deps.add_one(tags, seen, tag)
else:
for tag in deps.normal_character_tags(row):
deps.add_one(tags, seen, tag)
for tag in deps.metadata_family_tags(row):
deps.add_one(tags, seen, tag)
for tag in deps.formatter_hint_tags(row, root):
deps.add(tags, seen, tag)
hard_scene = deps.clean(row.get("scene_text"))
hard_item = deps.clean(row.get("item"))
hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
composition = deps.clean(row.get("composition"))
for value in (
hard_role,
hard_item,
hard_clothing,
hard_scene and f"in {hard_scene}",
expression,
composition,
):
deps.add(tags, seen, value)
for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
deps.add_one(tags, seen, tag)
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""])
for tag in deps.explicit_tags(combined, request.nude_weight):
deps.add_one(tags, seen, tag)
return SDXLTagRoute(tags)
def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
return row_core_tags_result(request, deps).tags
def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
return soft_tags_result(request, deps).as_text()
def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
return hard_tags_result(request, deps).as_text()
+63
View File
@@ -73,6 +73,7 @@ import row_subject_route # noqa: E402
import server_routes # noqa: E402 import server_routes # noqa: E402
import sdxl_formatter # noqa: E402 import sdxl_formatter # noqa: E402
import sdxl_presets # noqa: E402 import sdxl_presets # noqa: E402
import sdxl_tag_routes # noqa: E402
import seed_config # noqa: E402 import seed_config # noqa: E402
import krea_pov # noqa: E402 import krea_pov # noqa: E402
import subject_context # noqa: E402 import subject_context # noqa: E402
@@ -2336,6 +2337,67 @@ def smoke_sdxl_presets_policy() -> None:
_expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail") _expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail")
def smoke_sdxl_tag_routes() -> None:
row = _fixture_hardcore_row(
formatter_hints={
"all": ["shared route anchor"],
"sdxl": ["sdxl route tag"],
}
)
deps = sdxl_formatter._sdxl_tag_route_dependencies()
typed_row = sdxl_tag_routes.row_core_tags_result(
sdxl_tag_routes.SDXLRowTagRequest(row, 1.29),
deps,
)
_expect(
typed_row.tags == sdxl_formatter._row_core_tags(row, 1.29),
"Typed SDXL row tag route should match legacy wrapper output",
)
_expect("sdxl route tag" in typed_row.as_text(), "Typed SDXL row tag route lost route-specific formatter hint")
pair = pb.build_insta_of_pair(
row_number=1,
start_index=1,
seed=3511,
ethnicity="any",
figure="random",
no_plus_women=False,
no_black=False,
trigger=Trigger,
prepend_trigger_to_prompt=True,
options_json=_insta_options(hardcore_clothing_continuity="partially_removed"),
character_cast=_character_cast(),
hardcore_position_config=_action_filter("penetration_only"),
)
_expect_pair(pair, "sdxl_tag_routes_pair")
soft_row = pair.get("softcore_row") if isinstance(pair.get("softcore_row"), dict) else {}
hard_row = pair.get("hardcore_row") if isinstance(pair.get("hardcore_row"), dict) else {}
typed_soft = sdxl_tag_routes.soft_tags_result(
sdxl_tag_routes.SDXLPairTagRequest(soft_row, pair, 1.29),
deps,
)
typed_hard = sdxl_tag_routes.hard_tags_result(
sdxl_tag_routes.SDXLPairTagRequest(hard_row, pair, 1.29),
deps,
)
_expect(
typed_soft.as_text() == sdxl_formatter._soft_tags(soft_row, pair, 1.29),
"Typed SDXL pair soft tag route should match legacy wrapper output",
)
_expect(
typed_hard.as_text() == sdxl_formatter._hard_tags(hard_row, pair, 1.29),
"Typed SDXL pair hard tag route should match legacy wrapper output",
)
formatted = sdxl_formatter.format_sdxl_prompt(
"",
metadata_json=_json(pair),
target="hardcore",
trigger=SdxlTrigger,
prepend_trigger=True,
)
_expect("sdxl(insta_of_pair)" in formatted.get("method", ""), "SDXL pair formatter route changed method")
def smoke_hardcore_position_config_policy() -> None: def smoke_hardcore_position_config_policy() -> None:
_expect( _expect(
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES, pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
@@ -5090,6 +5152,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
("formatter_cast_policy", smoke_formatter_cast_policy), ("formatter_cast_policy", smoke_formatter_cast_policy),
("caption_policy", smoke_caption_policy), ("caption_policy", smoke_caption_policy),
("sdxl_presets_policy", smoke_sdxl_presets_policy), ("sdxl_presets_policy", smoke_sdxl_presets_policy),
("sdxl_tag_routes", smoke_sdxl_tag_routes),
("hardcore_position_config_policy", smoke_hardcore_position_config_policy), ("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
("row_route_metadata_policy", smoke_row_route_metadata_policy), ("row_route_metadata_policy", smoke_row_route_metadata_policy),
("category_library_route", smoke_category_library_route), ("category_library_route", smoke_category_library_route),