Extract SDXL tag route assembly
This commit is contained in:
@@ -396,9 +396,16 @@ Keep here:
|
||||
|
||||
- trigger behavior;
|
||||
- style and quality presets;
|
||||
- tag ordering;
|
||||
- weighted explicit tags;
|
||||
- final style/body/quality prompt assembly;
|
||||
- nude-weight setting and explicit-tag helper policy;
|
||||
- negative-prompt assembly.
|
||||
|
||||
Already isolated:
|
||||
|
||||
- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard
|
||||
tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`,
|
||||
`SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps
|
||||
compatibility wrappers plus final style/quality/trigger assembly.
|
||||
- metadata-family tag hints from `action_family`, `position_family`, and
|
||||
`position_keys`.
|
||||
- shared row route metadata reads from `route_metadata.py`.
|
||||
|
||||
@@ -699,8 +699,10 @@ not parse metadata. That is a wiring/input-hint issue, not a prompt pool issue.
|
||||
|
||||
`format_sdxl_prompt` chooses between:
|
||||
|
||||
- Pair metadata: `_soft_tags` and `_hard_tags`.
|
||||
- Normal metadata row: `_row_core_tags`.
|
||||
- Pair metadata: `sdxl_tag_routes.soft_tags_result` and
|
||||
`sdxl_tag_routes.hard_tags_result` through compatibility wrappers.
|
||||
- Normal metadata row: `sdxl_tag_routes.row_core_tags_result` through the
|
||||
`_row_core_tags` compatibility wrapper.
|
||||
- Plain text fallback: `_fallback_text_to_sdxl`.
|
||||
|
||||
Use this route for style triggers, weighted tag style, nude weighting, formatter
|
||||
@@ -710,9 +712,9 @@ SDXL field consumption:
|
||||
|
||||
| Branch | Reads most from | Key functions |
|
||||
| --- | --- | --- |
|
||||
| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `_row_core_tags`, `_metadata_family_tags`, `_camera_tags` |
|
||||
| Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `_soft_tags` |
|
||||
| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `_hard_tags`, `_metadata_family_tags` |
|
||||
| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `sdxl_tag_routes.row_core_tags_result`, `_metadata_family_tags`, `_camera_tags` |
|
||||
| Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `sdxl_tag_routes.soft_tags_result` |
|
||||
| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `sdxl_tag_routes.hard_tags_result`, `_metadata_family_tags` |
|
||||
| Text fallback | `source_text`, preserve-trigger setting, shared field-label stripping | `_fallback_text_to_sdxl` |
|
||||
|
||||
SDXL is the right place for model trigger handling, tag ordering, weight syntax,
|
||||
|
||||
+32
-96
@@ -6,11 +6,13 @@ from typing import Any
|
||||
try:
|
||||
from . import formatter_input as input_policy
|
||||
from . import route_metadata as route_metadata_policy
|
||||
from . import sdxl_tag_routes
|
||||
from . import sdxl_presets as sdxl_policy
|
||||
from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import formatter_input as input_policy
|
||||
import route_metadata as route_metadata_policy
|
||||
import sdxl_tag_routes
|
||||
import sdxl_presets as sdxl_policy
|
||||
from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt
|
||||
|
||||
@@ -283,43 +285,29 @@ def _explicit_tags(text: str, nude_weight: float) -> list[str]:
|
||||
return tags
|
||||
|
||||
|
||||
def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
|
||||
return sdxl_tag_routes.SDXLTagRouteDependencies(
|
||||
clean=_clean,
|
||||
row_value=_row_value,
|
||||
tag_key=_tag_key,
|
||||
add=_add,
|
||||
add_one=_add_one,
|
||||
count_tag=_count_tag,
|
||||
infer_counts=_infer_counts,
|
||||
normal_character_tags=_normal_character_tags,
|
||||
character_tags_from_descriptor=_character_tags_from_descriptor,
|
||||
metadata_family_tags=_metadata_family_tags,
|
||||
formatter_hint_tags=_formatter_hint_tags,
|
||||
camera_tags=_camera_tags,
|
||||
explicit_tags=_explicit_tags,
|
||||
)
|
||||
|
||||
|
||||
def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]:
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
women, men = _infer_counts(row)
|
||||
for tag in _count_tag(women, men):
|
||||
_add_one(tags, seen, tag)
|
||||
|
||||
for tag in _normal_character_tags(row):
|
||||
_add_one(tags, seen, tag)
|
||||
|
||||
for tag in _metadata_family_tags(row):
|
||||
_add_one(tags, seen, tag)
|
||||
for tag in _formatter_hint_tags(row):
|
||||
_add(tags, seen, tag)
|
||||
|
||||
item = _row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or _clean(row.get("custom_item"))
|
||||
pose = _row_value(row, "pose", ("Sexual pose", "Pose"))
|
||||
role_graph = _clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
scene = _row_value(row, "scene_text", ("Setting", "Scene")) or _clean(row.get("scene"))
|
||||
expression = _row_value(row, "character_expression_text") or _row_value(row, "expression", ("Facial expressions", "Facial expression"))
|
||||
composition = _row_value(row, "composition", ("Composition",))
|
||||
for value in (
|
||||
item,
|
||||
pose,
|
||||
role_graph,
|
||||
scene and f"in {scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
_add(tags, seen, value)
|
||||
for tag in _camera_tags(row):
|
||||
_add_one(tags, seen, tag)
|
||||
|
||||
combined = " ".join(_clean(value) for value in (item, pose, role_graph, row.get("prompt", "")))
|
||||
for tag in _explicit_tags(combined, nude_weight):
|
||||
_add_one(tags, seen, tag)
|
||||
return tags
|
||||
return sdxl_tag_routes.row_core_tags(
|
||||
sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight),
|
||||
_sdxl_tag_route_dependencies(),
|
||||
)
|
||||
|
||||
|
||||
def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str:
|
||||
@@ -341,69 +329,17 @@ def _quality_tail(quality_preset: str, custom_quality: str) -> str:
|
||||
|
||||
|
||||
def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
|
||||
tags = _row_core_tags(row, nude_weight)
|
||||
seen = {_tag_key(tag) for tag in tags}
|
||||
for tag in _formatter_hint_tags(root):
|
||||
_add(tags, seen, tag)
|
||||
descriptor = _clean(root.get("shared_descriptor"))
|
||||
if descriptor and not any("woman" in _tag_key(tag) for tag in tags):
|
||||
for tag in _character_tags_from_descriptor(descriptor):
|
||||
_add_one(tags, seen, tag)
|
||||
partner = root.get("softcore_partner_styling")
|
||||
if isinstance(partner, dict):
|
||||
_add(tags, seen, "; ".join(_clean(item) for item in partner.get("outfits", []) if _clean(item)))
|
||||
_add(tags, seen, partner.get("pose"))
|
||||
_add_one(tags, seen, "sexy")
|
||||
_add_one(tags, seen, "looking at viewer")
|
||||
return ", ".join(tags)
|
||||
return sdxl_tag_routes.soft_tags(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
|
||||
_sdxl_tag_route_dependencies(),
|
||||
)
|
||||
|
||||
|
||||
def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str:
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
|
||||
except (TypeError, ValueError):
|
||||
women, men = 1, 1
|
||||
for tag in _count_tag(women, men):
|
||||
_add_one(tags, seen, tag)
|
||||
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if isinstance(descriptors, list):
|
||||
for descriptor in descriptors:
|
||||
for tag in _character_tags_from_descriptor(descriptor):
|
||||
_add_one(tags, seen, tag)
|
||||
else:
|
||||
for tag in _normal_character_tags(row):
|
||||
_add_one(tags, seen, tag)
|
||||
|
||||
for tag in _metadata_family_tags(row):
|
||||
_add_one(tags, seen, tag)
|
||||
for tag in _formatter_hint_tags(row, root):
|
||||
_add(tags, seen, tag)
|
||||
|
||||
hard_scene = _clean(row.get("scene_text"))
|
||||
hard_item = _clean(row.get("item"))
|
||||
hard_role = _clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
hard_clothing = _clean(root.get("hardcore_clothing_state"))
|
||||
expression = _clean(row.get("character_expression_text") or row.get("expression"))
|
||||
composition = _clean(row.get("composition"))
|
||||
for value in (
|
||||
hard_role,
|
||||
hard_item,
|
||||
hard_clothing,
|
||||
hard_scene and f"in {hard_scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
_add(tags, seen, value)
|
||||
for tag in _camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
|
||||
_add_one(tags, seen, tag)
|
||||
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""])
|
||||
for tag in _explicit_tags(combined, nude_weight):
|
||||
_add_one(tags, seen, tag)
|
||||
return ", ".join(tags)
|
||||
return sdxl_tag_routes.hard_tags(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight),
|
||||
_sdxl_tag_route_dependencies(),
|
||||
)
|
||||
|
||||
|
||||
def _assemble_prompt(
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLRowTagRequest:
|
||||
row: dict[str, Any]
|
||||
nude_weight: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLPairTagRequest:
|
||||
row: dict[str, Any]
|
||||
root: dict[str, Any]
|
||||
nude_weight: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLTagRoute:
|
||||
tags: list[str]
|
||||
|
||||
def as_text(self) -> str:
|
||||
return ", ".join(self.tags)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SDXLTagRouteDependencies:
|
||||
clean: Callable[[Any], str]
|
||||
row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str]
|
||||
tag_key: Callable[[str], str]
|
||||
add: Callable[[list[str], set[str], Any], None]
|
||||
add_one: Callable[[list[str], set[str], str], None]
|
||||
count_tag: Callable[[int, int], list[str]]
|
||||
infer_counts: Callable[[dict[str, Any]], tuple[int, int]]
|
||||
normal_character_tags: Callable[[dict[str, Any]], list[str]]
|
||||
character_tags_from_descriptor: Callable[[Any], list[str]]
|
||||
metadata_family_tags: Callable[[dict[str, Any]], list[str]]
|
||||
formatter_hint_tags: Callable[..., list[str]]
|
||||
camera_tags: Callable[..., list[str]]
|
||||
explicit_tags: Callable[[str, float], list[str]]
|
||||
|
||||
|
||||
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
women, men = deps.infer_counts(row)
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.normal_character_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.metadata_family_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
for tag in deps.formatter_hint_tags(row):
|
||||
deps.add(tags, seen, tag)
|
||||
|
||||
item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean(
|
||||
row.get("custom_item")
|
||||
)
|
||||
pose = deps.row_value(row, "pose", ("Sexual pose", "Pose"))
|
||||
role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene"))
|
||||
expression = deps.row_value(row, "character_expression_text") or deps.row_value(
|
||||
row,
|
||||
"expression",
|
||||
("Facial expressions", "Facial expression"),
|
||||
)
|
||||
composition = deps.row_value(row, "composition", ("Composition",))
|
||||
for value in (
|
||||
item,
|
||||
pose,
|
||||
role_graph,
|
||||
scene and f"in {scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
deps.add(tags, seen, value)
|
||||
for tag in deps.camera_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
combined = " ".join(deps.clean(value) for value in (item, pose, role_graph, row.get("prompt", "")))
|
||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||
deps.add_one(tags, seen, tag)
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
root = request.root
|
||||
tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags
|
||||
seen = {deps.tag_key(tag) for tag in tags}
|
||||
for tag in deps.formatter_hint_tags(root):
|
||||
deps.add(tags, seen, tag)
|
||||
descriptor = deps.clean(root.get("shared_descriptor"))
|
||||
if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags):
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
partner = root.get("softcore_partner_styling")
|
||||
if isinstance(partner, dict):
|
||||
deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item)))
|
||||
deps.add(tags, seen, partner.get("pose"))
|
||||
deps.add_one(tags, seen, "sexy")
|
||||
deps.add_one(tags, seen, "looking at viewer")
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
|
||||
row = request.row
|
||||
root = request.root
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
try:
|
||||
women = int(root.get("hardcore_women_count") or row.get("women_count") or 1)
|
||||
men = int(root.get("hardcore_men_count") or row.get("men_count") or 1)
|
||||
except (TypeError, ValueError):
|
||||
women, men = 1, 1
|
||||
for tag in deps.count_tag(women, men):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
descriptors = root.get("shared_cast_descriptors")
|
||||
if isinstance(descriptors, list):
|
||||
for descriptor in descriptors:
|
||||
for tag in deps.character_tags_from_descriptor(descriptor):
|
||||
deps.add_one(tags, seen, tag)
|
||||
else:
|
||||
for tag in deps.normal_character_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
|
||||
for tag in deps.metadata_family_tags(row):
|
||||
deps.add_one(tags, seen, tag)
|
||||
for tag in deps.formatter_hint_tags(row, root):
|
||||
deps.add(tags, seen, tag)
|
||||
|
||||
hard_scene = deps.clean(row.get("scene_text"))
|
||||
hard_item = deps.clean(row.get("item"))
|
||||
hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
|
||||
hard_clothing = deps.clean(root.get("hardcore_clothing_state"))
|
||||
expression = deps.clean(row.get("character_expression_text") or row.get("expression"))
|
||||
composition = deps.clean(row.get("composition"))
|
||||
for value in (
|
||||
hard_role,
|
||||
hard_item,
|
||||
hard_clothing,
|
||||
hard_scene and f"in {hard_scene}",
|
||||
expression,
|
||||
composition,
|
||||
):
|
||||
deps.add(tags, seen, value)
|
||||
for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")):
|
||||
deps.add_one(tags, seen, tag)
|
||||
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""])
|
||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||
deps.add_one(tags, seen, tag)
|
||||
return SDXLTagRoute(tags)
|
||||
|
||||
|
||||
def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]:
|
||||
return row_core_tags_result(request, deps).tags
|
||||
|
||||
|
||||
def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
|
||||
return soft_tags_result(request, deps).as_text()
|
||||
|
||||
|
||||
def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str:
|
||||
return hard_tags_result(request, deps).as_text()
|
||||
@@ -73,6 +73,7 @@ import row_subject_route # noqa: E402
|
||||
import server_routes # noqa: E402
|
||||
import sdxl_formatter # noqa: E402
|
||||
import sdxl_presets # noqa: E402
|
||||
import sdxl_tag_routes # noqa: E402
|
||||
import seed_config # noqa: E402
|
||||
import krea_pov # noqa: E402
|
||||
import subject_context # noqa: E402
|
||||
@@ -2336,6 +2337,67 @@ def smoke_sdxl_presets_policy() -> None:
|
||||
_expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail")
|
||||
|
||||
|
||||
def smoke_sdxl_tag_routes() -> None:
|
||||
row = _fixture_hardcore_row(
|
||||
formatter_hints={
|
||||
"all": ["shared route anchor"],
|
||||
"sdxl": ["sdxl route tag"],
|
||||
}
|
||||
)
|
||||
deps = sdxl_formatter._sdxl_tag_route_dependencies()
|
||||
typed_row = sdxl_tag_routes.row_core_tags_result(
|
||||
sdxl_tag_routes.SDXLRowTagRequest(row, 1.29),
|
||||
deps,
|
||||
)
|
||||
_expect(
|
||||
typed_row.tags == sdxl_formatter._row_core_tags(row, 1.29),
|
||||
"Typed SDXL row tag route should match legacy wrapper output",
|
||||
)
|
||||
_expect("sdxl route tag" in typed_row.as_text(), "Typed SDXL row tag route lost route-specific formatter hint")
|
||||
|
||||
pair = pb.build_insta_of_pair(
|
||||
row_number=1,
|
||||
start_index=1,
|
||||
seed=3511,
|
||||
ethnicity="any",
|
||||
figure="random",
|
||||
no_plus_women=False,
|
||||
no_black=False,
|
||||
trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
options_json=_insta_options(hardcore_clothing_continuity="partially_removed"),
|
||||
character_cast=_character_cast(),
|
||||
hardcore_position_config=_action_filter("penetration_only"),
|
||||
)
|
||||
_expect_pair(pair, "sdxl_tag_routes_pair")
|
||||
soft_row = pair.get("softcore_row") if isinstance(pair.get("softcore_row"), dict) else {}
|
||||
hard_row = pair.get("hardcore_row") if isinstance(pair.get("hardcore_row"), dict) else {}
|
||||
typed_soft = sdxl_tag_routes.soft_tags_result(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(soft_row, pair, 1.29),
|
||||
deps,
|
||||
)
|
||||
typed_hard = sdxl_tag_routes.hard_tags_result(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(hard_row, pair, 1.29),
|
||||
deps,
|
||||
)
|
||||
_expect(
|
||||
typed_soft.as_text() == sdxl_formatter._soft_tags(soft_row, pair, 1.29),
|
||||
"Typed SDXL pair soft tag route should match legacy wrapper output",
|
||||
)
|
||||
_expect(
|
||||
typed_hard.as_text() == sdxl_formatter._hard_tags(hard_row, pair, 1.29),
|
||||
"Typed SDXL pair hard tag route should match legacy wrapper output",
|
||||
)
|
||||
formatted = sdxl_formatter.format_sdxl_prompt(
|
||||
"",
|
||||
metadata_json=_json(pair),
|
||||
target="hardcore",
|
||||
trigger=SdxlTrigger,
|
||||
prepend_trigger=True,
|
||||
)
|
||||
_expect("sdxl(insta_of_pair)" in formatted.get("method", ""), "SDXL pair formatter route changed method")
|
||||
|
||||
|
||||
def smoke_hardcore_position_config_policy() -> None:
|
||||
_expect(
|
||||
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
||||
@@ -5090,6 +5152,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("formatter_cast_policy", smoke_formatter_cast_policy),
|
||||
("caption_policy", smoke_caption_policy),
|
||||
("sdxl_presets_policy", smoke_sdxl_presets_policy),
|
||||
("sdxl_tag_routes", smoke_sdxl_tag_routes),
|
||||
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
||||
("row_route_metadata_policy", smoke_row_route_metadata_policy),
|
||||
("category_library_route", smoke_category_library_route),
|
||||
|
||||
Reference in New Issue
Block a user