diff --git a/category_extensions.py b/category_extensions.py new file mode 100644 index 0000000..db13530 --- /dev/null +++ b/category_extensions.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import json +from typing import Any + +try: + from . import category_library as category_policy + from . import generate_prompt_batches as g + from . import row_item as row_item_policy +except ImportError: # Allows local smoke tests with top-level imports. + import category_library as category_policy + import generate_prompt_batches as g + import row_item as row_item_policy + + +BUILTIN_CATEGORIES = [ + "auto_weighted", + "auto_full", + "woman", + "man", + "couple", + "group_or_layout", + "custom_random", +] + +_EXTENSIONS_APPLIED = False + + +def list_from(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def unique_extend(target: list[Any], additions: list[Any]) -> None: + seen = set() + for item in target: + try: + seen.add(json.dumps(item, sort_keys=True)) + except TypeError: + seen.add(repr(item)) + for item in additions: + try: + marker = json.dumps(item, sort_keys=True) + except TypeError: + marker = repr(item) + if marker not in seen: + target.append(item) + seen.add(marker) + + +def extension_targets() -> dict[str, tuple[list[Any], bool]]: + return { + "women_clothes": (g.WOMEN_CLOTHES, False), + "women_clothes_minimal": (g.WOMEN_CLOTHES_MINIMAL, False), + "men_clothes": (g.MEN_CLOTHES, False), + "men_clothes_minimal": (g.MEN_CLOTHES_MINIMAL, False), + "couple_outfits": (g.COUPLE_OUTFITS, False), + "couple_outfits_minimal": (g.COUPLE_OUTFITS_MINIMAL, False), + "poses": (g.POSES, False), + "evocative_poses": (g.EVOCATIVE_POSES, False), + "backside_poses": (g.BACKSIDE_POSES, False), + "expressions": (g.EXPRESSIONS, False), + "compositions": (g.COMPOSITIONS, False), + "props": (g.PROPS, False), + "figure_curvy": (g.FIGURE_CURVY, False), + "figure_athletic": (g.FIGURE_ATHLETIC, False), + "figure_bombshell": (g.FIGURE_BOMBSHELL, False), + "scenes": (g.SCENES, True), + "group_scenes": (g.GROUP_SCENES, True), + "layouts_full": (g.LAYOUTS_FULL, True), + "layouts_minimal": (g.LAYOUTS_MINIMAL, True), + "group_compositions": (g.GROUP_COMPOSITIONS, False), + "group_ages": (g.GROUP_AGES, False), + } + + +def apply_pool_extensions() -> None: + global _EXTENSIONS_APPLIED + if _EXTENSIONS_APPLIED: + return + targets = extension_targets() + for path in category_policy.category_json_files(): + data = category_policy.read_category_json(path) + extensions = data.get("pool_extensions", {}) + if not isinstance(extensions, dict): + raise ValueError(f"pool_extensions in {path} must be an object") + for target_name, additions in extensions.items(): + if target_name not in targets: + known = ", ".join(sorted(targets)) + raise ValueError(f"Unknown pool extension '{target_name}' in {path}. Known: {known}") + target, expects_pair = targets[target_name] + normalized = ( + [row_item_policy.pair_from(item) for item in list_from(additions)] + if expects_pair + else [row_item_policy.item_text(item) for item in list_from(additions)] + ) + unique_extend(target, normalized) + g.EVOCATIVE_ALL = g.EVOCATIVE_POSES + g.BACKSIDE_POSES + _EXTENSIONS_APPLIED = True + + +def category_choices() -> list[str]: + apply_pool_extensions() + custom = [category["name"] for category in category_policy.load_category_library()] + return BUILTIN_CATEGORIES + [name for name in custom if name not in BUILTIN_CATEGORIES] + + +def subcategory_choices() -> list[str]: + apply_pool_extensions() + choices = [category_policy.RANDOM_SUBCATEGORY] + for category in category_policy.load_category_library(): + for subcategory in category["subcategories"]: + choices.append(f"{category['name']} / {subcategory['name']}") + return choices diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 9dbdb4a..dbf582f 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -123,6 +123,8 @@ Already isolated: - JSON category loading, subcategory normalization, named scene/expression/ composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging live in `category_library.py`. +- JSON `pool_extensions`, legacy pool patching, built-in category choice lists, + and category/subcategory UI choices live in `category_extensions.py`. - object-style item-template metadata extraction, action/position family normalization, position-key normalization, and metadata audit errors live in `category_template_metadata.py`. diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index a12c0bd..5dbc45d 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -68,6 +68,7 @@ Core helper ownership: | Python module | What it owns | | --- | --- | | `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. | +| `category_extensions.py` | JSON `pool_extensions`, legacy pool patching, built-in category choice lists, and category/subcategory UI choices. | | `category_template_metadata.py` | Object-style item-template metadata extraction, action/position family normalization, position-key normalization, key merging, and audit validation errors. | | `row_item.py` | Row item selection, weighted item/pair choice, item-template axis filling, and oral/outercourse axis compatibility filters. | | `row_generation.py` | Built-in legacy row generation, auto-weighted/auto-full selection, row mode randomization, ratio clamps, and expression-intensity randomization. | @@ -246,7 +247,7 @@ Important JSON keys: - `prompt_template` / `caption_template`: final prompt assembly for that category. - `inherit_scenes`, `inherit_expressions`, `inherit_compositions`: stop or allow inheritance from category/subcategory/item levels. -- `pool_extensions`: patch legacy pools from JSON. +- `pool_extensions`: patch legacy pools from JSON through `category_extensions.py`. Current category/pool files: diff --git a/prompt_builder.py b/prompt_builder.py index 1525b1a..89f3224 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import random import re from pathlib import Path @@ -9,16 +8,15 @@ from typing import Any try: from .category_library import ( - category_json_files as _json_files, compatible_entries as _compatible_entries, compatible_entry as _compatible_entry, find_subcategory as _find_subcategory, load_category_library, merged_field as _merged_field, - read_category_json as _read_json, ) from . import camera_config as camera_policy from . import cast_context as cast_context_policy + from . import category_extensions as category_extensions_policy from . import category_template_metadata as item_template_policy from . import character_appearance as character_appearance_policy from . import character_config as character_policy @@ -54,16 +52,15 @@ try: from .hardcore_role_graphs import build_hardcore_role_graph except ImportError: # Allows local smoke tests with `python -c`. from category_library import ( - category_json_files as _json_files, compatible_entries as _compatible_entries, compatible_entry as _compatible_entry, find_subcategory as _find_subcategory, load_category_library, merged_field as _merged_field, - read_category_json as _read_json, ) import camera_config as camera_policy import cast_context as cast_context_policy + import category_extensions as category_extensions_policy import category_template_metadata as item_template_policy import character_appearance as character_appearance_policy import character_config as character_policy @@ -102,15 +99,7 @@ except ImportError: # Allows local smoke tests with `python -c`. ROOT_DIR = Path(__file__).resolve().parent PROFILE_DIR = character_profile_policy.PROFILE_DIR -BUILTIN_CATEGORIES = [ - "auto_weighted", - "auto_full", - "woman", - "man", - "couple", - "group_or_layout", - "custom_random", -] +BUILTIN_CATEGORIES = category_extensions_policy.BUILTIN_CATEGORIES RANDOM_SUBCATEGORY = "random" SEED_AXIS_SALTS = seed_policy.SEED_AXIS_SALTS SEED_AXIS_ALIASES = seed_policy.SEED_AXIS_ALIASES @@ -197,9 +186,6 @@ CAMERA_PHONE_PROMPTS = camera_policy.CAMERA_PHONE_PROMPTS CAMERA_PRIORITY_PROMPTS = camera_policy.CAMERA_PRIORITY_PROMPTS -_EXTENSIONS_APPLIED = False - - class SafeFormatDict(dict): def __missing__(self, key: str) -> str: return "{" + key + "}" @@ -226,20 +212,7 @@ def _is_false(value: Any) -> bool: def _unique_extend(target: list[Any], additions: list[Any]) -> None: - seen = set() - for item in target: - try: - seen.add(json.dumps(item, sort_keys=True)) - except TypeError: - seen.add(repr(item)) - for item in additions: - try: - marker = json.dumps(item, sort_keys=True) - except TypeError: - marker = repr(item) - if marker not in seen: - target.append(item) - seen.add(marker) + category_extensions_policy.unique_extend(target, additions) def _pair_from(value: Any) -> tuple[str, str]: @@ -333,67 +306,19 @@ def _choose_pair(rng: random.Random, items: list[Any]) -> tuple[str, str]: def _extension_targets() -> dict[str, tuple[list[Any], bool]]: - return { - "women_clothes": (g.WOMEN_CLOTHES, False), - "women_clothes_minimal": (g.WOMEN_CLOTHES_MINIMAL, False), - "men_clothes": (g.MEN_CLOTHES, False), - "men_clothes_minimal": (g.MEN_CLOTHES_MINIMAL, False), - "couple_outfits": (g.COUPLE_OUTFITS, False), - "couple_outfits_minimal": (g.COUPLE_OUTFITS_MINIMAL, False), - "poses": (g.POSES, False), - "evocative_poses": (g.EVOCATIVE_POSES, False), - "backside_poses": (g.BACKSIDE_POSES, False), - "expressions": (g.EXPRESSIONS, False), - "compositions": (g.COMPOSITIONS, False), - "props": (g.PROPS, False), - "figure_curvy": (g.FIGURE_CURVY, False), - "figure_athletic": (g.FIGURE_ATHLETIC, False), - "figure_bombshell": (g.FIGURE_BOMBSHELL, False), - "scenes": (g.SCENES, True), - "group_scenes": (g.GROUP_SCENES, True), - "layouts_full": (g.LAYOUTS_FULL, True), - "layouts_minimal": (g.LAYOUTS_MINIMAL, True), - "group_compositions": (g.GROUP_COMPOSITIONS, False), - "group_ages": (g.GROUP_AGES, False), - } + return category_extensions_policy.extension_targets() def apply_pool_extensions() -> None: - global _EXTENSIONS_APPLIED - if _EXTENSIONS_APPLIED: - return - targets = _extension_targets() - for path in _json_files(): - data = _read_json(path) - extensions = data.get("pool_extensions", {}) - if not isinstance(extensions, dict): - raise ValueError(f"pool_extensions in {path} must be an object") - for target_name, additions in extensions.items(): - if target_name not in targets: - known = ", ".join(sorted(targets)) - raise ValueError(f"Unknown pool extension '{target_name}' in {path}. Known: {known}") - target, expects_pair = targets[target_name] - normalized = [_pair_from(item) for item in _list_from(additions)] if expects_pair else [ - _item_text(item) for item in _list_from(additions) - ] - _unique_extend(target, normalized) - g.EVOCATIVE_ALL = g.EVOCATIVE_POSES + g.BACKSIDE_POSES - _EXTENSIONS_APPLIED = True + category_extensions_policy.apply_pool_extensions() def category_choices() -> list[str]: - apply_pool_extensions() - custom = [category["name"] for category in load_category_library()] - return BUILTIN_CATEGORIES + [name for name in custom if name not in BUILTIN_CATEGORIES] + return category_extensions_policy.category_choices() def subcategory_choices() -> list[str]: - apply_pool_extensions() - choices = [RANDOM_SUBCATEGORY] - for category in load_category_library(): - for subcategory in category["subcategories"]: - choices.append(f"{category['name']} / {subcategory['name']}") - return choices + return category_extensions_policy.subcategory_choices() def seed_mode_choices() -> list[str]: diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index d43d871..c42f9a3 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -27,6 +27,7 @@ if str(ROOT) not in sys.path: import caption_naturalizer # noqa: E402 import caption_policy # noqa: E402 import cast_context # noqa: E402 +import category_extensions # noqa: E402 import category_template_metadata # noqa: E402 import character_appearance # noqa: E402 import character_config # noqa: E402 @@ -882,6 +883,29 @@ def smoke_row_generation_policy() -> None: ) +def smoke_category_extensions_policy() -> None: + _expect(pb.BUILTIN_CATEGORIES is category_extensions.BUILTIN_CATEGORIES, "Prompt builder built-in categories should come from category_extensions") + targets = category_extensions.extension_targets() + _expect( + pb._extension_targets().keys() == targets.keys(), + "Prompt builder extension targets should delegate to category_extensions", + ) + _expect(targets.get("group_scenes", (None, None))[1] is True, "Group scene extension target should expect scene pairs") + sample = [{"slug": "a", "prompt": "one"}] + category_extensions.unique_extend(sample, [{"slug": "a", "prompt": "one"}, {"slug": "b", "prompt": "two"}]) + _expect(len(sample) == 2 and sample[1]["slug"] == "b", "Category extension unique_extend changed") + + pb.apply_pool_extensions() + category_extensions.apply_pool_extensions() + _expect(pb.category_choices() == category_extensions.category_choices(), "Prompt builder category choices should delegate") + _expect(pb.subcategory_choices() == category_extensions.subcategory_choices(), "Prompt builder subcategory choices should delegate") + _expect("Hardcore sexual poses" in pb.category_choices(), "Category choices lost hardcore JSON category") + _expect( + any(slug == "private_suite_group_party" for slug, _prompt in targets["group_scenes"][0]), + "JSON pool_extensions did not reach legacy group scenes", + ) + + def smoke_category_cast_config_policy() -> None: _expect(pb.CATEGORY_PRESETS is category_cast_config.CATEGORY_PRESETS, "Prompt builder category presets are not delegated") _expect(pb.CAST_PRESETS is category_cast_config.CAST_PRESETS, "Prompt builder cast presets are not delegated") @@ -4083,6 +4107,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("row_expression_policy", smoke_row_expression_policy), ("row_item_policy", smoke_row_item_policy), ("row_generation_policy", smoke_row_generation_policy), + ("category_extensions_policy", smoke_category_extensions_policy), ("category_cast_config_policy", smoke_category_cast_config_policy), ("generation_profile_config_policy", smoke_generation_profile_config_policy), ("filter_config_policy", smoke_filter_config_policy),