diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 9cdb6fa..4fec46c 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -396,9 +396,16 @@ Keep here: - trigger behavior; - style and quality presets; -- tag ordering; -- weighted explicit tags; +- final style/body/quality prompt assembly; +- nude-weight setting and explicit-tag helper policy; - negative-prompt assembly. + +Already isolated: + +- `sdxl_tag_routes.py` owns normal metadata row tags and Insta/OF pair soft/hard + tag extraction behind `SDXLRowTagRequest`, `SDXLPairTagRequest`, + `SDXLTagRouteDependencies`, and `SDXLTagRoute`; `sdxl_formatter.py` keeps + compatibility wrappers plus final style/quality/trigger assembly. - metadata-family tag hints from `action_family`, `position_family`, and `position_keys`. - shared row route metadata reads from `route_metadata.py`. diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 0ae44c6..81874ec 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -699,8 +699,10 @@ not parse metadata. That is a wiring/input-hint issue, not a prompt pool issue. `format_sdxl_prompt` chooses between: -- Pair metadata: `_soft_tags` and `_hard_tags`. -- Normal metadata row: `_row_core_tags`. +- Pair metadata: `sdxl_tag_routes.soft_tags_result` and + `sdxl_tag_routes.hard_tags_result` through compatibility wrappers. +- Normal metadata row: `sdxl_tag_routes.row_core_tags_result` through the + `_row_core_tags` compatibility wrapper. - Plain text fallback: `_fallback_text_to_sdxl`. Use this route for style triggers, weighted tag style, nude weighting, formatter @@ -710,9 +712,9 @@ SDXL field consumption: | Branch | Reads most from | Key functions | | --- | --- | --- | -| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `_row_core_tags`, `_metadata_family_tags`, `_camera_tags` | -| Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `_soft_tags` | -| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `_hard_tags`, `_metadata_family_tags` | +| Normal metadata | cast descriptors, age/body/skin/hair/eyes, `action_family`, `position_family`, `position_keys`, item, role graph, scene, camera config/directive | `sdxl_tag_routes.row_core_tags_result`, `_metadata_family_tags`, `_camera_tags` | +| Pair softcore | `softcore_row`, pair partner styling, root soft camera config | `sdxl_tag_routes.soft_tags_result` | +| Pair hardcore | `hardcore_row`, `action_family`, `position_family`, `position_keys`, `hardcore_clothing_state`, hard camera fields, hard prompt text | `sdxl_tag_routes.hard_tags_result`, `_metadata_family_tags` | | Text fallback | `source_text`, preserve-trigger setting, shared field-label stripping | `_fallback_text_to_sdxl` | SDXL is the right place for model trigger handling, tag ordering, weight syntax, diff --git a/sdxl_formatter.py b/sdxl_formatter.py index fe3bc76..2c94c97 100644 --- a/sdxl_formatter.py +++ b/sdxl_formatter.py @@ -6,11 +6,13 @@ from typing import Any try: from . import formatter_input as input_policy from . import route_metadata as route_metadata_policy + from . import sdxl_tag_routes from . import sdxl_presets as sdxl_policy from .prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt except ImportError: # Allows local smoke tests with `python -c`. import formatter_input as input_policy import route_metadata as route_metadata_policy + import sdxl_tag_routes import sdxl_presets as sdxl_policy from prompt_hygiene import sanitize_negative_text, sanitize_tag_prompt @@ -283,43 +285,29 @@ def _explicit_tags(text: str, nude_weight: float) -> list[str]: return tags +def _sdxl_tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies: + return sdxl_tag_routes.SDXLTagRouteDependencies( + clean=_clean, + row_value=_row_value, + tag_key=_tag_key, + add=_add, + add_one=_add_one, + count_tag=_count_tag, + infer_counts=_infer_counts, + normal_character_tags=_normal_character_tags, + character_tags_from_descriptor=_character_tags_from_descriptor, + metadata_family_tags=_metadata_family_tags, + formatter_hint_tags=_formatter_hint_tags, + camera_tags=_camera_tags, + explicit_tags=_explicit_tags, + ) + + def _row_core_tags(row: dict[str, Any], nude_weight: float) -> list[str]: - tags: list[str] = [] - seen: set[str] = set() - women, men = _infer_counts(row) - for tag in _count_tag(women, men): - _add_one(tags, seen, tag) - - for tag in _normal_character_tags(row): - _add_one(tags, seen, tag) - - for tag in _metadata_family_tags(row): - _add_one(tags, seen, tag) - for tag in _formatter_hint_tags(row): - _add(tags, seen, tag) - - item = _row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or _clean(row.get("custom_item")) - pose = _row_value(row, "pose", ("Sexual pose", "Pose")) - role_graph = _clean(row.get("source_role_graph") or row.get("role_graph")) - scene = _row_value(row, "scene_text", ("Setting", "Scene")) or _clean(row.get("scene")) - expression = _row_value(row, "character_expression_text") or _row_value(row, "expression", ("Facial expressions", "Facial expression")) - composition = _row_value(row, "composition", ("Composition",)) - for value in ( - item, - pose, - role_graph, - scene and f"in {scene}", - expression, - composition, - ): - _add(tags, seen, value) - for tag in _camera_tags(row): - _add_one(tags, seen, tag) - - combined = " ".join(_clean(value) for value in (item, pose, role_graph, row.get("prompt", ""))) - for tag in _explicit_tags(combined, nude_weight): - _add_one(tags, seen, tag) - return tags + return sdxl_tag_routes.row_core_tags( + sdxl_tag_routes.SDXLRowTagRequest(row, nude_weight), + _sdxl_tag_route_dependencies(), + ) def _style_prefix(style_preset: str, trigger: str, prepend_trigger: bool, custom_style: str) -> str: @@ -341,69 +329,17 @@ def _quality_tail(quality_preset: str, custom_quality: str) -> str: def _soft_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str: - tags = _row_core_tags(row, nude_weight) - seen = {_tag_key(tag) for tag in tags} - for tag in _formatter_hint_tags(root): - _add(tags, seen, tag) - descriptor = _clean(root.get("shared_descriptor")) - if descriptor and not any("woman" in _tag_key(tag) for tag in tags): - for tag in _character_tags_from_descriptor(descriptor): - _add_one(tags, seen, tag) - partner = root.get("softcore_partner_styling") - if isinstance(partner, dict): - _add(tags, seen, "; ".join(_clean(item) for item in partner.get("outfits", []) if _clean(item))) - _add(tags, seen, partner.get("pose")) - _add_one(tags, seen, "sexy") - _add_one(tags, seen, "looking at viewer") - return ", ".join(tags) + return sdxl_tag_routes.soft_tags( + sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight), + _sdxl_tag_route_dependencies(), + ) def _hard_tags(row: dict[str, Any], root: dict[str, Any], nude_weight: float) -> str: - tags: list[str] = [] - seen: set[str] = set() - try: - women = int(root.get("hardcore_women_count") or row.get("women_count") or 1) - men = int(root.get("hardcore_men_count") or row.get("men_count") or 1) - except (TypeError, ValueError): - women, men = 1, 1 - for tag in _count_tag(women, men): - _add_one(tags, seen, tag) - - descriptors = root.get("shared_cast_descriptors") - if isinstance(descriptors, list): - for descriptor in descriptors: - for tag in _character_tags_from_descriptor(descriptor): - _add_one(tags, seen, tag) - else: - for tag in _normal_character_tags(row): - _add_one(tags, seen, tag) - - for tag in _metadata_family_tags(row): - _add_one(tags, seen, tag) - for tag in _formatter_hint_tags(row, root): - _add(tags, seen, tag) - - hard_scene = _clean(row.get("scene_text")) - hard_item = _clean(row.get("item")) - hard_role = _clean(row.get("source_role_graph") or row.get("role_graph")) - hard_clothing = _clean(root.get("hardcore_clothing_state")) - expression = _clean(row.get("character_expression_text") or row.get("expression")) - composition = _clean(row.get("composition")) - for value in ( - hard_role, - hard_item, - hard_clothing, - hard_scene and f"in {hard_scene}", - expression, - composition, - ): - _add(tags, seen, value) - for tag in _camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")): - _add_one(tags, seen, tag) - combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""]) - for tag in _explicit_tags(combined, nude_weight): - _add_one(tags, seen, tag) - return ", ".join(tags) + return sdxl_tag_routes.hard_tags( + sdxl_tag_routes.SDXLPairTagRequest(row, root, nude_weight), + _sdxl_tag_route_dependencies(), + ) def _assemble_prompt( diff --git a/sdxl_tag_routes.py b/sdxl_tag_routes.py new file mode 100644 index 0000000..bac967f --- /dev/null +++ b/sdxl_tag_routes.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + + +@dataclass(frozen=True) +class SDXLRowTagRequest: + row: dict[str, Any] + nude_weight: float + + +@dataclass(frozen=True) +class SDXLPairTagRequest: + row: dict[str, Any] + root: dict[str, Any] + nude_weight: float + + +@dataclass(frozen=True) +class SDXLTagRoute: + tags: list[str] + + def as_text(self) -> str: + return ", ".join(self.tags) + + +@dataclass(frozen=True) +class SDXLTagRouteDependencies: + clean: Callable[[Any], str] + row_value: Callable[[dict[str, Any], str, tuple[str, ...]], str] + tag_key: Callable[[str], str] + add: Callable[[list[str], set[str], Any], None] + add_one: Callable[[list[str], set[str], str], None] + count_tag: Callable[[int, int], list[str]] + infer_counts: Callable[[dict[str, Any]], tuple[int, int]] + normal_character_tags: Callable[[dict[str, Any]], list[str]] + character_tags_from_descriptor: Callable[[Any], list[str]] + metadata_family_tags: Callable[[dict[str, Any]], list[str]] + formatter_hint_tags: Callable[..., list[str]] + camera_tags: Callable[..., list[str]] + explicit_tags: Callable[[str, float], list[str]] + + +def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: + row = request.row + tags: list[str] = [] + seen: set[str] = set() + women, men = deps.infer_counts(row) + for tag in deps.count_tag(women, men): + deps.add_one(tags, seen, tag) + + for tag in deps.normal_character_tags(row): + deps.add_one(tags, seen, tag) + + for tag in deps.metadata_family_tags(row): + deps.add_one(tags, seen, tag) + for tag in deps.formatter_hint_tags(row): + deps.add(tags, seen, tag) + + item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean( + row.get("custom_item") + ) + pose = deps.row_value(row, "pose", ("Sexual pose", "Pose")) + role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph")) + scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene")) + expression = deps.row_value(row, "character_expression_text") or deps.row_value( + row, + "expression", + ("Facial expressions", "Facial expression"), + ) + composition = deps.row_value(row, "composition", ("Composition",)) + for value in ( + item, + pose, + role_graph, + scene and f"in {scene}", + expression, + composition, + ): + deps.add(tags, seen, value) + for tag in deps.camera_tags(row): + deps.add_one(tags, seen, tag) + + combined = " ".join(deps.clean(value) for value in (item, pose, role_graph, row.get("prompt", ""))) + for tag in deps.explicit_tags(combined, request.nude_weight): + deps.add_one(tags, seen, tag) + return SDXLTagRoute(tags) + + +def soft_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: + row = request.row + root = request.root + tags = row_core_tags_result(SDXLRowTagRequest(row, request.nude_weight), deps).tags + seen = {deps.tag_key(tag) for tag in tags} + for tag in deps.formatter_hint_tags(root): + deps.add(tags, seen, tag) + descriptor = deps.clean(root.get("shared_descriptor")) + if descriptor and not any("woman" in deps.tag_key(tag) for tag in tags): + for tag in deps.character_tags_from_descriptor(descriptor): + deps.add_one(tags, seen, tag) + partner = root.get("softcore_partner_styling") + if isinstance(partner, dict): + deps.add(tags, seen, "; ".join(deps.clean(item) for item in partner.get("outfits", []) if deps.clean(item))) + deps.add(tags, seen, partner.get("pose")) + deps.add_one(tags, seen, "sexy") + deps.add_one(tags, seen, "looking at viewer") + return SDXLTagRoute(tags) + + +def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: + row = request.row + root = request.root + tags: list[str] = [] + seen: set[str] = set() + try: + women = int(root.get("hardcore_women_count") or row.get("women_count") or 1) + men = int(root.get("hardcore_men_count") or row.get("men_count") or 1) + except (TypeError, ValueError): + women, men = 1, 1 + for tag in deps.count_tag(women, men): + deps.add_one(tags, seen, tag) + + descriptors = root.get("shared_cast_descriptors") + if isinstance(descriptors, list): + for descriptor in descriptors: + for tag in deps.character_tags_from_descriptor(descriptor): + deps.add_one(tags, seen, tag) + else: + for tag in deps.normal_character_tags(row): + deps.add_one(tags, seen, tag) + + for tag in deps.metadata_family_tags(row): + deps.add_one(tags, seen, tag) + for tag in deps.formatter_hint_tags(row, root): + deps.add(tags, seen, tag) + + hard_scene = deps.clean(row.get("scene_text")) + hard_item = deps.clean(row.get("item")) + hard_role = deps.clean(row.get("source_role_graph") or row.get("role_graph")) + hard_clothing = deps.clean(root.get("hardcore_clothing_state")) + expression = deps.clean(row.get("character_expression_text") or row.get("expression")) + composition = deps.clean(row.get("composition")) + for value in ( + hard_role, + hard_item, + hard_clothing, + hard_scene and f"in {hard_scene}", + expression, + composition, + ): + deps.add(tags, seen, value) + for tag in deps.camera_tags(row, root.get("hardcore_camera_directive"), root.get("hardcore_camera_config")): + deps.add_one(tags, seen, tag) + combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition, root.get("hardcore_prompt", "") or ""]) + for tag in deps.explicit_tags(combined, request.nude_weight): + deps.add_one(tags, seen, tag) + return SDXLTagRoute(tags) + + +def row_core_tags(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> list[str]: + return row_core_tags_result(request, deps).tags + + +def soft_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str: + return soft_tags_result(request, deps).as_text() + + +def hard_tags(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies) -> str: + return hard_tags_result(request, deps).as_text() diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index b0e8b9e..74eeb7b 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -73,6 +73,7 @@ import row_subject_route # noqa: E402 import server_routes # noqa: E402 import sdxl_formatter # noqa: E402 import sdxl_presets # noqa: E402 +import sdxl_tag_routes # noqa: E402 import seed_config # noqa: E402 import krea_pov # noqa: E402 import subject_context # noqa: E402 @@ -2336,6 +2337,67 @@ def smoke_sdxl_presets_policy() -> None: _expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail") +def smoke_sdxl_tag_routes() -> None: + row = _fixture_hardcore_row( + formatter_hints={ + "all": ["shared route anchor"], + "sdxl": ["sdxl route tag"], + } + ) + deps = sdxl_formatter._sdxl_tag_route_dependencies() + typed_row = sdxl_tag_routes.row_core_tags_result( + sdxl_tag_routes.SDXLRowTagRequest(row, 1.29), + deps, + ) + _expect( + typed_row.tags == sdxl_formatter._row_core_tags(row, 1.29), + "Typed SDXL row tag route should match legacy wrapper output", + ) + _expect("sdxl route tag" in typed_row.as_text(), "Typed SDXL row tag route lost route-specific formatter hint") + + pair = pb.build_insta_of_pair( + row_number=1, + start_index=1, + seed=3511, + ethnicity="any", + figure="random", + no_plus_women=False, + no_black=False, + trigger=Trigger, + prepend_trigger_to_prompt=True, + options_json=_insta_options(hardcore_clothing_continuity="partially_removed"), + character_cast=_character_cast(), + hardcore_position_config=_action_filter("penetration_only"), + ) + _expect_pair(pair, "sdxl_tag_routes_pair") + soft_row = pair.get("softcore_row") if isinstance(pair.get("softcore_row"), dict) else {} + hard_row = pair.get("hardcore_row") if isinstance(pair.get("hardcore_row"), dict) else {} + typed_soft = sdxl_tag_routes.soft_tags_result( + sdxl_tag_routes.SDXLPairTagRequest(soft_row, pair, 1.29), + deps, + ) + typed_hard = sdxl_tag_routes.hard_tags_result( + sdxl_tag_routes.SDXLPairTagRequest(hard_row, pair, 1.29), + deps, + ) + _expect( + typed_soft.as_text() == sdxl_formatter._soft_tags(soft_row, pair, 1.29), + "Typed SDXL pair soft tag route should match legacy wrapper output", + ) + _expect( + typed_hard.as_text() == sdxl_formatter._hard_tags(hard_row, pair, 1.29), + "Typed SDXL pair hard tag route should match legacy wrapper output", + ) + formatted = sdxl_formatter.format_sdxl_prompt( + "", + metadata_json=_json(pair), + target="hardcore", + trigger=SdxlTrigger, + prepend_trigger=True, + ) + _expect("sdxl(insta_of_pair)" in formatted.get("method", ""), "SDXL pair formatter route changed method") + + def smoke_hardcore_position_config_policy() -> None: _expect( pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES, @@ -5090,6 +5152,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("formatter_cast_policy", smoke_formatter_cast_policy), ("caption_policy", smoke_caption_policy), ("sdxl_presets_policy", smoke_sdxl_presets_policy), + ("sdxl_tag_routes", smoke_sdxl_tag_routes), ("hardcore_position_config_policy", smoke_hardcore_position_config_policy), ("row_route_metadata_policy", smoke_row_route_metadata_policy), ("category_library_route", smoke_category_library_route),