From d31d513ec3b3091c9981820d1f2e494349eaf124 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 09:42:16 +0200 Subject: [PATCH] Extract row category route policy --- docs/prompt-architecture-improvement-plan.md | 4 + docs/prompt-pool-routing-map.md | 15 +- prompt_builder.py | 108 +++++++------- row_category_route.py | 143 +++++++++++++++++++ tools/prompt_smoke.py | 53 +++++++ 5 files changed, 258 insertions(+), 65 deletions(-) create mode 100644 row_category_route.py diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 11d1d1d..8483da7 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -131,6 +131,10 @@ Already isolated: - row item selection, weighted item/pair choice, item-template axis filling, and oral/outercourse axis compatibility filters live in `row_item.py`; `prompt_builder.py` keeps public delegate wrappers. +- row category/subcategory/item route resolution, hardcore position-category + filtering, cast-count adjustment, pose-vs-content seed-axis choice, item + metadata collection, and pose-category item sanitizing live in + `row_category_route.py`; `prompt_builder.py` keeps public delegate wrappers. - row prompt/caption template selection, safe formatting, default prompt templates, configured-cast descriptor insertion, and POV directive insertion live in `row_rendering.py`; `prompt_builder.py` keeps compatibility aliases. diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index cc3e309..498c0d6 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -71,6 +71,7 @@ Core helper ownership: | `category_extensions.py` | JSON `pool_extensions`, legacy pool patching, built-in category choice lists, and category/subcategory UI choices. | | `category_template_metadata.py` | Object-style item-template metadata extraction, action/position family normalization, position-key normalization, key merging, and audit validation errors. | | `row_item.py` | Row item selection, weighted item/pair choice, item-template axis filling, and oral/outercourse axis compatibility filters. | +| `row_category_route.py` | Row category/subcategory/item route resolution, hardcore position-category filtering, cast-count adjustment, pose-vs-content seed-axis choice, item metadata collection, and pose-category item sanitizing. | | `row_rendering.py` | Row prompt/caption template selection, safe formatting, default prompt templates, configured-cast descriptor insertion, and POV directive insertion. | | `row_route_metadata.py` | Row action/position route metadata resolution, template metadata precedence, inferred position-key merging, and source action-family fallback. | | `row_generation.py` | Built-in legacy row generation, auto-weighted/auto-full selection, row mode randomization, ratio clamps, and expression-intensity randomization. | @@ -469,13 +470,13 @@ plain prompt text. When debugging, inspect these fields before editing pools. | Field | Owner | Consumed by | Meaning | | --- | --- | --- | --- | | `source` | `build_prompt` / row builder | All formatters | Usually `json_category` or `built_in_generator`; tells which route created the row. | -| `main_category`, `subcategory` | Category selection | All formatters and debug | Human-readable selected category route. | -| `category_slug`, `subcategory_slug` | JSON category normalization | Debug/filtering | Stable-ish machine labels for selected category route. | -| `content_seed_axis` | `_build_custom_row` | Debug | Shows whether the item/action was driven by `content` or `pose`. Critical for hardcore pose categories. | -| `item` | `row_item.compose_item` or Insta override | Krea/SDXL/Naturalizer | Clothing item, category item, or sexual scene/action text. | -| `item_axis_values` | `row_item.compose_item` | Krea hardcore rewrite, SDXL tags | Filled template axes such as position/action/detail values. | -| `item_template_metadata` | `row_item.compose_item` | Debug, Krea/SDXL/Naturalizer route metadata | Optional metadata from object-style item templates; currently used to prefer explicit action/position families and keys before inference. | -| `formatter_hints` | `category_template_metadata.formatter_hints` | Krea/SDXL/Naturalizer route specialization, debug | Normalized route-specific hints from object-style item templates, keyed by `all`, `krea`, `sdxl`, or `caption`; each formatter consumes `all` plus its own route only. | +| `main_category`, `subcategory` | `row_category_route.select_category_item_route` | All formatters and debug | Human-readable selected category route. | +| `category_slug`, `subcategory_slug` | `row_category_route.select_category_item_route` | Debug/filtering | Stable-ish machine labels for selected category route. | +| `content_seed_axis` | `row_category_route.select_category_item_route` | Debug | Shows whether the item/action was driven by `content` or `pose`. Critical for hardcore pose categories. | +| `item` | `row_category_route.select_category_item_route` or Insta override | Krea/SDXL/Naturalizer | Clothing item, category item, or sexual scene/action text. | +| `item_axis_values` | `row_category_route.select_category_item_route` | Krea hardcore rewrite, SDXL tags | Filled template axes such as position/action/detail values. | +| `item_template_metadata` | `row_category_route.select_category_item_route` | Debug, Krea/SDXL/Naturalizer route metadata | Optional metadata from object-style item templates; currently used to prefer explicit action/position families and keys before inference. | +| `formatter_hints` | `row_category_route.select_category_item_route` | Krea/SDXL/Naturalizer route specialization, debug | Normalized route-specific hints from object-style item templates, keyed by `all`, `krea`, `sdxl`, or `caption`; each formatter consumes `all` plus its own route only. | | `action_family` | `row_route_metadata.resolve_action_position_route` | Krea hardcore rewrite, SDXL tags, natural captions, debug | Source-aware formatter semantic family such as `foreplay`, `outercourse`, `oral`, `penetration`, `toy_double`, or `climax`. | | `position_family` | `row_route_metadata.resolve_action_position_route` | Debug/filtering | Source/UI hardcore family selected by template metadata or subcategory, such as `manual`, `interaction`, `oral`, `anal`, or `climax`. | | `position_key`, `position_keys` | `row_route_metadata.resolve_action_position_route` | Debug/future filters | Concrete position tokens from object-template metadata and inferred axes/role text, such as `kneeling`, `doggy`, `boobjob`, or `open_thighs`. | diff --git a/prompt_builder.py b/prompt_builder.py index 63a33e9..8082dea 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -35,6 +35,7 @@ try: from . import pov_policy from . import row_normalization as row_policy from . import row_camera as row_camera_policy + from . import row_category_route as row_category_route_policy from . import row_expression as row_expression_policy from . import row_generation as row_generation_policy from . import row_item as row_item_policy @@ -80,6 +81,7 @@ except ImportError: # Allows local smoke tests with `python -c`. import pov_policy import row_normalization as row_policy import row_camera as row_camera_policy + import row_category_route as row_category_route_policy import row_expression as row_expression_policy import row_generation as row_generation_policy import row_item as row_item_policy @@ -772,18 +774,32 @@ def _axis_rng(seed_config: dict[str, int], axis: str, base_seed: int, row_number def _is_pose_content_category(category: dict[str, Any], subcategory: dict[str, Any]) -> bool: - haystack = " ".join( - str(value) - for value in ( - category.get("name", ""), - category.get("slug", ""), - category.get("item_label", ""), - subcategory.get("name", ""), - subcategory.get("slug", ""), - subcategory.get("item_label", ""), - ) - ).lower() - return "pose" in haystack or "sex" in haystack + return row_category_route_policy.is_pose_content_category(category, subcategory) + + +def _select_category_item_route( + *, + category_choice: str, + subcategory_choice: str, + seed_config: dict[str, int], + seed: int, + row_number: int, + women_count: int, + men_count: int, + hardcore_position_config: dict[str, Any] | None = None, + categories: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + return row_category_route_policy.select_category_item_route( + category_choice=category_choice, + subcategory_choice=subcategory_choice, + seed_config=seed_config, + seed=seed, + row_number=row_number, + women_count=women_count, + men_count=men_count, + hardcore_position_config=hardcore_position_config, + categories=categories, + ) def _format(template: str, context: dict[str, Any]) -> str: @@ -2004,9 +2020,6 @@ def _build_custom_row( location_config: str | dict[str, Any] | None = None, composition_config: str | dict[str, Any] | None = None, ) -> dict[str, Any]: - categories = load_category_library() - category_rng = _axis_rng(seed_config, "category", seed, row_number) - subcategory_rng = _axis_rng(seed_config, "subcategory", seed, row_number) person_rng = _axis_rng(seed_config, "person", seed, row_number) scene_rng = _axis_rng(seed_config, "scene", seed, row_number) pose_rng = _axis_rng(seed_config, "pose", seed, row_number) @@ -2017,50 +2030,29 @@ def _build_custom_row( parsed_location_config = _parse_location_config(location_config) parsed_composition_config = _parse_composition_config(composition_config) - requested_women_count = women_count - requested_men_count = men_count - categories = _filter_hardcore_categories_for_position( - categories, - parsed_hardcore_position_config, - women_count, - men_count, + category_route = _select_category_item_route( + category_choice=category_choice, + subcategory_choice=subcategory_choice, + seed_config=seed_config, + seed=seed, + row_number=row_number, + women_count=women_count, + men_count=men_count, + hardcore_position_config=parsed_hardcore_position_config, ) - category, subcategory, women_count, men_count = _find_subcategory( - categories, - category_choice, - subcategory_choice, - category_rng, - subcategory_rng, - women_count, - men_count, - ) - count_adjustment = {} - if women_count != requested_women_count or men_count != requested_men_count: - count_adjustment = { - "requested_women_count": requested_women_count, - "requested_men_count": requested_men_count, - "effective_women_count": women_count, - "effective_men_count": men_count, - } - if _is_hardcore_sexual_category(category): - subcategory = _apply_hardcore_position_config_to_subcategory(subcategory, parsed_hardcore_position_config) - content_axis = "pose" if _is_pose_content_category(category, subcategory) else "content" - content_rng = _axis_rng(seed_config, content_axis, seed, row_number) - items = _list_from(subcategory.get("items", [subcategory["name"]])) - item = _weighted_choice(content_rng, items) - item_text, item_name, item_axis_values, item_template_metadata = _compose_item( - content_rng, - category, - subcategory, - item, - women_count, - men_count, - ) - is_pose_category = _is_pose_content_category(category, subcategory) - if is_pose_category: - item_text = _sanitize_hardcore_environment_anchors(item_text) - item_axis_values = _sanitize_hardcore_axis_values(item_axis_values) - item_formatter_hints = _template_formatter_hints(item_template_metadata) + category = category_route["category"] + subcategory = category_route["subcategory"] + women_count = int(category_route["women_count"]) + men_count = int(category_route["men_count"]) + count_adjustment = dict(category_route.get("count_adjustment") or {}) + content_axis = str(category_route.get("content_axis") or "content") + item = category_route["item"] + item_text = str(category_route.get("item_text") or "") + item_name = str(category_route.get("item_name") or "") + item_axis_values = dict(category_route.get("item_axis_values") or {}) + item_template_metadata = dict(category_route.get("item_template_metadata") or {}) + item_formatter_hints = dict(category_route.get("formatter_hints") or {}) + is_pose_category = bool(category_route.get("is_pose_category")) subject_type = str(_merged_field(category, subcategory, item, "subject_type", "single_any")) context = _subject_context(person_rng, subject_type, ethnicity, figure, no_plus_women, no_black, women_count, men_count) character_slots = _parse_character_cast(character_cast) diff --git a/row_category_route.py b/row_category_route.py new file mode 100644 index 0000000..6985a19 --- /dev/null +++ b/row_category_route.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import Any + +try: + from . import category_library as category_policy + from . import category_template_metadata as template_policy + from . import hardcore_position_config as hardcore_position_policy + from . import row_item as row_item_policy + from . import seed_config as seed_policy + from .hardcore_text_cleanup import ( + sanitize_hardcore_axis_values, + sanitize_hardcore_environment_anchors, + ) +except ImportError: # Allows local smoke tests from the repository root. + import category_library as category_policy + import category_template_metadata as template_policy + import hardcore_position_config as hardcore_position_policy + import row_item as row_item_policy + import seed_config as seed_policy + from hardcore_text_cleanup import ( + sanitize_hardcore_axis_values, + sanitize_hardcore_environment_anchors, + ) + + +def _list_from(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def is_pose_content_category(category: dict[str, Any], subcategory: dict[str, Any]) -> bool: + haystack = " ".join( + str(value) + for value in ( + category.get("name", ""), + category.get("slug", ""), + category.get("item_label", ""), + subcategory.get("name", ""), + subcategory.get("slug", ""), + subcategory.get("item_label", ""), + ) + ).lower() + return "pose" in haystack or "sex" in haystack + + +def cast_count_adjustment( + requested_women_count: int, + requested_men_count: int, + effective_women_count: int, + effective_men_count: int, +) -> dict[str, int]: + if requested_women_count == effective_women_count and requested_men_count == effective_men_count: + return {} + return { + "requested_women_count": requested_women_count, + "requested_men_count": requested_men_count, + "effective_women_count": effective_women_count, + "effective_men_count": effective_men_count, + } + + +def select_category_item_route( + *, + category_choice: str, + subcategory_choice: str, + seed_config: dict[str, int], + seed: int, + row_number: int, + women_count: int, + men_count: int, + hardcore_position_config: dict[str, Any] | None = None, + categories: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + source_categories = category_policy.load_category_library() if categories is None else categories + parsed_hardcore_position_config = hardcore_position_config or {} + requested_women_count = women_count + requested_men_count = men_count + + category_rng = seed_policy.axis_rng(seed_config, "category", seed, row_number) + subcategory_rng = seed_policy.axis_rng(seed_config, "subcategory", seed, row_number) + filtered_categories = hardcore_position_policy.filter_hardcore_categories_for_position( + source_categories, + parsed_hardcore_position_config, + women_count, + men_count, + category_policy.compatible_entry, + ) + category, subcategory, women_count, men_count = category_policy.find_subcategory( + filtered_categories, + category_choice, + subcategory_choice, + category_rng, + subcategory_rng, + women_count, + men_count, + ) + count_adjustment = cast_count_adjustment( + requested_women_count, + requested_men_count, + women_count, + men_count, + ) + if hardcore_position_policy.is_hardcore_sexual_category(category): + subcategory = hardcore_position_policy.apply_hardcore_position_config_to_subcategory( + subcategory, + parsed_hardcore_position_config, + ) + + is_pose_category = is_pose_content_category(category, subcategory) + content_axis = "pose" if is_pose_category else "content" + content_rng = seed_policy.axis_rng(seed_config, content_axis, seed, row_number) + item = row_item_policy.weighted_choice(content_rng, _list_from(subcategory.get("items", [subcategory["name"]]))) + item_text, item_name, item_axis_values, item_template_metadata = row_item_policy.compose_item( + content_rng, + category, + subcategory, + item, + women_count, + men_count, + ) + if is_pose_category: + item_text = sanitize_hardcore_environment_anchors(item_text) + item_axis_values = sanitize_hardcore_axis_values(item_axis_values) + + return { + "category": category, + "subcategory": subcategory, + "women_count": women_count, + "men_count": men_count, + "count_adjustment": count_adjustment, + "content_axis": content_axis, + "item": item, + "item_text": item_text, + "item_name": item_name, + "item_axis_values": item_axis_values, + "item_template_metadata": item_template_metadata, + "formatter_hints": template_policy.formatter_hints(item_template_metadata), + "is_pose_category": is_pose_category, + } diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 813c7f0..352d4e8 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -52,6 +52,7 @@ import pov_policy # noqa: E402 import row_normalization # noqa: E402 import route_metadata # noqa: E402 import row_camera # noqa: E402 +import row_category_route # noqa: E402 import row_expression # noqa: E402 import row_generation # noqa: E402 import row_item # noqa: E402 @@ -833,6 +834,57 @@ def smoke_row_item_policy() -> None: _expect(metadata.get("action_family") == "oral", "Row item compose lost template metadata") +def smoke_row_category_route_policy() -> None: + hard_config = hardcore_position_config.parse_hardcore_position_config(_position_filter("oral_only", "oral", ["kneeling"])) + seed_cfg = seed_config.parse_seed_config({}) + route = row_category_route.select_category_item_route( + category_choice="custom_random", + subcategory_choice="Hardcore sexual poses / Oral sex", + seed_config=seed_cfg, + seed=2301, + row_number=1, + women_count=1, + men_count=1, + hardcore_position_config=hard_config, + ) + delegated = pb._select_category_item_route( + category_choice="custom_random", + subcategory_choice="Hardcore sexual poses / Oral sex", + seed_config=seed_cfg, + seed=2301, + row_number=1, + women_count=1, + men_count=1, + hardcore_position_config=hard_config, + ) + _expect(delegated == route, "Prompt builder category/item route should delegate to row_category_route") + _expect(route["category"]["slug"] == "hardcore_sexual_poses", "Row category route selected wrong hardcore category") + _expect(route["subcategory"]["slug"] == "oral_sex", "Row category route selected wrong hardcore subcategory") + _expect(route["content_axis"] == "pose", "Hardcore pose category should use pose seed axis") + _expect(route["is_pose_category"] is True, "Hardcore pose category should be marked as pose content") + _expect(isinstance(route["item_axis_values"], dict), "Row category route lost item axis metadata") + _expect(isinstance(route["formatter_hints"], dict), "Row category route lost formatter hint metadata") + _expect( + pb._is_pose_content_category(route["category"], route["subcategory"]) + == row_category_route.is_pose_content_category(route["category"], route["subcategory"]), + "Prompt builder pose-content wrapper should delegate", + ) + + casual_route = row_category_route.select_category_item_route( + category_choice="custom_random", + subcategory_choice="Casual clothes / Streetwear", + seed_config=seed_cfg, + seed=2301, + row_number=1, + women_count=1, + men_count=0, + hardcore_position_config={}, + ) + _expect(casual_route["category"]["slug"] == "casual_clothes", "Row category route selected wrong casual category") + _expect(casual_route["content_axis"] == "content", "Non-pose category should use content seed axis") + _expect(casual_route["is_pose_category"] is False, "Non-pose category should not be marked as pose content") + + def smoke_row_generation_policy() -> None: _expect(pb._ratio_or_none(-1) is None, "Prompt builder ratio helper should treat negative as unset") _expect(pb._ratio_or_none(1.5) == row_generation.ratio_or_none(1.5) == 1.0, "Row generation ratio clamp changed") @@ -4257,6 +4309,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("row_location_policy", smoke_row_location_policy), ("row_expression_policy", smoke_row_expression_policy), ("row_item_policy", smoke_row_item_policy), + ("row_category_route_policy", smoke_row_category_route_policy), ("row_generation_policy", smoke_row_generation_policy), ("category_extensions_policy", smoke_category_extensions_policy), ("category_cast_config_policy", smoke_category_cast_config_policy),