Extract category extension policy

This commit is contained in:
2026-06-27 09:17:00 +02:00
parent 23bcb1b526
commit 3c1f6784c1
5 changed files with 154 additions and 84 deletions
+117
View File
@@ -0,0 +1,117 @@
from __future__ import annotations
import json
from typing import Any
try:
from . import category_library as category_policy
from . import generate_prompt_batches as g
from . import row_item as row_item_policy
except ImportError: # Allows local smoke tests with top-level imports.
import category_library as category_policy
import generate_prompt_batches as g
import row_item as row_item_policy
BUILTIN_CATEGORIES = [
"auto_weighted",
"auto_full",
"woman",
"man",
"couple",
"group_or_layout",
"custom_random",
]
_EXTENSIONS_APPLIED = False
def list_from(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
return [value]
def unique_extend(target: list[Any], additions: list[Any]) -> None:
seen = set()
for item in target:
try:
seen.add(json.dumps(item, sort_keys=True))
except TypeError:
seen.add(repr(item))
for item in additions:
try:
marker = json.dumps(item, sort_keys=True)
except TypeError:
marker = repr(item)
if marker not in seen:
target.append(item)
seen.add(marker)
def extension_targets() -> dict[str, tuple[list[Any], bool]]:
return {
"women_clothes": (g.WOMEN_CLOTHES, False),
"women_clothes_minimal": (g.WOMEN_CLOTHES_MINIMAL, False),
"men_clothes": (g.MEN_CLOTHES, False),
"men_clothes_minimal": (g.MEN_CLOTHES_MINIMAL, False),
"couple_outfits": (g.COUPLE_OUTFITS, False),
"couple_outfits_minimal": (g.COUPLE_OUTFITS_MINIMAL, False),
"poses": (g.POSES, False),
"evocative_poses": (g.EVOCATIVE_POSES, False),
"backside_poses": (g.BACKSIDE_POSES, False),
"expressions": (g.EXPRESSIONS, False),
"compositions": (g.COMPOSITIONS, False),
"props": (g.PROPS, False),
"figure_curvy": (g.FIGURE_CURVY, False),
"figure_athletic": (g.FIGURE_ATHLETIC, False),
"figure_bombshell": (g.FIGURE_BOMBSHELL, False),
"scenes": (g.SCENES, True),
"group_scenes": (g.GROUP_SCENES, True),
"layouts_full": (g.LAYOUTS_FULL, True),
"layouts_minimal": (g.LAYOUTS_MINIMAL, True),
"group_compositions": (g.GROUP_COMPOSITIONS, False),
"group_ages": (g.GROUP_AGES, False),
}
def apply_pool_extensions() -> None:
global _EXTENSIONS_APPLIED
if _EXTENSIONS_APPLIED:
return
targets = extension_targets()
for path in category_policy.category_json_files():
data = category_policy.read_category_json(path)
extensions = data.get("pool_extensions", {})
if not isinstance(extensions, dict):
raise ValueError(f"pool_extensions in {path} must be an object")
for target_name, additions in extensions.items():
if target_name not in targets:
known = ", ".join(sorted(targets))
raise ValueError(f"Unknown pool extension '{target_name}' in {path}. Known: {known}")
target, expects_pair = targets[target_name]
normalized = (
[row_item_policy.pair_from(item) for item in list_from(additions)]
if expects_pair
else [row_item_policy.item_text(item) for item in list_from(additions)]
)
unique_extend(target, normalized)
g.EVOCATIVE_ALL = g.EVOCATIVE_POSES + g.BACKSIDE_POSES
_EXTENSIONS_APPLIED = True
def category_choices() -> list[str]:
apply_pool_extensions()
custom = [category["name"] for category in category_policy.load_category_library()]
return BUILTIN_CATEGORIES + [name for name in custom if name not in BUILTIN_CATEGORIES]
def subcategory_choices() -> list[str]:
apply_pool_extensions()
choices = [category_policy.RANDOM_SUBCATEGORY]
for category in category_policy.load_category_library():
for subcategory in category["subcategories"]:
choices.append(f"{category['name']} / {subcategory['name']}")
return choices
@@ -123,6 +123,8 @@ Already isolated:
- JSON category loading, subcategory normalization, named scene/expression/
composition pool loading, cast compatibility filtering, exact subcategory
lookup, and inheritance-based pool merging live in `category_library.py`.
- JSON `pool_extensions`, legacy pool patching, built-in category choice lists,
and category/subcategory UI choices live in `category_extensions.py`.
- object-style item-template metadata extraction, action/position family
normalization, position-key normalization, and metadata audit errors live in
`category_template_metadata.py`.
+2 -1
View File
@@ -68,6 +68,7 @@ Core helper ownership:
| Python module | What it owns |
| --- | --- |
| `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. |
| `category_extensions.py` | JSON `pool_extensions`, legacy pool patching, built-in category choice lists, and category/subcategory UI choices. |
| `category_template_metadata.py` | Object-style item-template metadata extraction, action/position family normalization, position-key normalization, key merging, and audit validation errors. |
| `row_item.py` | Row item selection, weighted item/pair choice, item-template axis filling, and oral/outercourse axis compatibility filters. |
| `row_generation.py` | Built-in legacy row generation, auto-weighted/auto-full selection, row mode randomization, ratio clamps, and expression-intensity randomization. |
@@ -246,7 +247,7 @@ Important JSON keys:
- `prompt_template` / `caption_template`: final prompt assembly for that category.
- `inherit_scenes`, `inherit_expressions`, `inherit_compositions`: stop or allow
inheritance from category/subcategory/item levels.
- `pool_extensions`: patch legacy pools from JSON.
- `pool_extensions`: patch legacy pools from JSON through `category_extensions.py`.
Current category/pool files:
+8 -83
View File
@@ -1,6 +1,5 @@
from __future__ import annotations
import json
import random
import re
from pathlib import Path
@@ -9,16 +8,15 @@ from typing import Any
try:
from .category_library import (
category_json_files as _json_files,
compatible_entries as _compatible_entries,
compatible_entry as _compatible_entry,
find_subcategory as _find_subcategory,
load_category_library,
merged_field as _merged_field,
read_category_json as _read_json,
)
from . import camera_config as camera_policy
from . import cast_context as cast_context_policy
from . import category_extensions as category_extensions_policy
from . import category_template_metadata as item_template_policy
from . import character_appearance as character_appearance_policy
from . import character_config as character_policy
@@ -54,16 +52,15 @@ try:
from .hardcore_role_graphs import build_hardcore_role_graph
except ImportError: # Allows local smoke tests with `python -c`.
from category_library import (
category_json_files as _json_files,
compatible_entries as _compatible_entries,
compatible_entry as _compatible_entry,
find_subcategory as _find_subcategory,
load_category_library,
merged_field as _merged_field,
read_category_json as _read_json,
)
import camera_config as camera_policy
import cast_context as cast_context_policy
import category_extensions as category_extensions_policy
import category_template_metadata as item_template_policy
import character_appearance as character_appearance_policy
import character_config as character_policy
@@ -102,15 +99,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
ROOT_DIR = Path(__file__).resolve().parent
PROFILE_DIR = character_profile_policy.PROFILE_DIR
BUILTIN_CATEGORIES = [
"auto_weighted",
"auto_full",
"woman",
"man",
"couple",
"group_or_layout",
"custom_random",
]
BUILTIN_CATEGORIES = category_extensions_policy.BUILTIN_CATEGORIES
RANDOM_SUBCATEGORY = "random"
SEED_AXIS_SALTS = seed_policy.SEED_AXIS_SALTS
SEED_AXIS_ALIASES = seed_policy.SEED_AXIS_ALIASES
@@ -197,9 +186,6 @@ CAMERA_PHONE_PROMPTS = camera_policy.CAMERA_PHONE_PROMPTS
CAMERA_PRIORITY_PROMPTS = camera_policy.CAMERA_PRIORITY_PROMPTS
_EXTENSIONS_APPLIED = False
class SafeFormatDict(dict):
def __missing__(self, key: str) -> str:
return "{" + key + "}"
@@ -226,20 +212,7 @@ def _is_false(value: Any) -> bool:
def _unique_extend(target: list[Any], additions: list[Any]) -> None:
seen = set()
for item in target:
try:
seen.add(json.dumps(item, sort_keys=True))
except TypeError:
seen.add(repr(item))
for item in additions:
try:
marker = json.dumps(item, sort_keys=True)
except TypeError:
marker = repr(item)
if marker not in seen:
target.append(item)
seen.add(marker)
category_extensions_policy.unique_extend(target, additions)
def _pair_from(value: Any) -> tuple[str, str]:
@@ -333,67 +306,19 @@ def _choose_pair(rng: random.Random, items: list[Any]) -> tuple[str, str]:
def _extension_targets() -> dict[str, tuple[list[Any], bool]]:
return {
"women_clothes": (g.WOMEN_CLOTHES, False),
"women_clothes_minimal": (g.WOMEN_CLOTHES_MINIMAL, False),
"men_clothes": (g.MEN_CLOTHES, False),
"men_clothes_minimal": (g.MEN_CLOTHES_MINIMAL, False),
"couple_outfits": (g.COUPLE_OUTFITS, False),
"couple_outfits_minimal": (g.COUPLE_OUTFITS_MINIMAL, False),
"poses": (g.POSES, False),
"evocative_poses": (g.EVOCATIVE_POSES, False),
"backside_poses": (g.BACKSIDE_POSES, False),
"expressions": (g.EXPRESSIONS, False),
"compositions": (g.COMPOSITIONS, False),
"props": (g.PROPS, False),
"figure_curvy": (g.FIGURE_CURVY, False),
"figure_athletic": (g.FIGURE_ATHLETIC, False),
"figure_bombshell": (g.FIGURE_BOMBSHELL, False),
"scenes": (g.SCENES, True),
"group_scenes": (g.GROUP_SCENES, True),
"layouts_full": (g.LAYOUTS_FULL, True),
"layouts_minimal": (g.LAYOUTS_MINIMAL, True),
"group_compositions": (g.GROUP_COMPOSITIONS, False),
"group_ages": (g.GROUP_AGES, False),
}
return category_extensions_policy.extension_targets()
def apply_pool_extensions() -> None:
global _EXTENSIONS_APPLIED
if _EXTENSIONS_APPLIED:
return
targets = _extension_targets()
for path in _json_files():
data = _read_json(path)
extensions = data.get("pool_extensions", {})
if not isinstance(extensions, dict):
raise ValueError(f"pool_extensions in {path} must be an object")
for target_name, additions in extensions.items():
if target_name not in targets:
known = ", ".join(sorted(targets))
raise ValueError(f"Unknown pool extension '{target_name}' in {path}. Known: {known}")
target, expects_pair = targets[target_name]
normalized = [_pair_from(item) for item in _list_from(additions)] if expects_pair else [
_item_text(item) for item in _list_from(additions)
]
_unique_extend(target, normalized)
g.EVOCATIVE_ALL = g.EVOCATIVE_POSES + g.BACKSIDE_POSES
_EXTENSIONS_APPLIED = True
category_extensions_policy.apply_pool_extensions()
def category_choices() -> list[str]:
apply_pool_extensions()
custom = [category["name"] for category in load_category_library()]
return BUILTIN_CATEGORIES + [name for name in custom if name not in BUILTIN_CATEGORIES]
return category_extensions_policy.category_choices()
def subcategory_choices() -> list[str]:
apply_pool_extensions()
choices = [RANDOM_SUBCATEGORY]
for category in load_category_library():
for subcategory in category["subcategories"]:
choices.append(f"{category['name']} / {subcategory['name']}")
return choices
return category_extensions_policy.subcategory_choices()
def seed_mode_choices() -> list[str]:
+25
View File
@@ -27,6 +27,7 @@ if str(ROOT) not in sys.path:
import caption_naturalizer # noqa: E402
import caption_policy # noqa: E402
import cast_context # noqa: E402
import category_extensions # noqa: E402
import category_template_metadata # noqa: E402
import character_appearance # noqa: E402
import character_config # noqa: E402
@@ -882,6 +883,29 @@ def smoke_row_generation_policy() -> None:
)
def smoke_category_extensions_policy() -> None:
_expect(pb.BUILTIN_CATEGORIES is category_extensions.BUILTIN_CATEGORIES, "Prompt builder built-in categories should come from category_extensions")
targets = category_extensions.extension_targets()
_expect(
pb._extension_targets().keys() == targets.keys(),
"Prompt builder extension targets should delegate to category_extensions",
)
_expect(targets.get("group_scenes", (None, None))[1] is True, "Group scene extension target should expect scene pairs")
sample = [{"slug": "a", "prompt": "one"}]
category_extensions.unique_extend(sample, [{"slug": "a", "prompt": "one"}, {"slug": "b", "prompt": "two"}])
_expect(len(sample) == 2 and sample[1]["slug"] == "b", "Category extension unique_extend changed")
pb.apply_pool_extensions()
category_extensions.apply_pool_extensions()
_expect(pb.category_choices() == category_extensions.category_choices(), "Prompt builder category choices should delegate")
_expect(pb.subcategory_choices() == category_extensions.subcategory_choices(), "Prompt builder subcategory choices should delegate")
_expect("Hardcore sexual poses" in pb.category_choices(), "Category choices lost hardcore JSON category")
_expect(
any(slug == "private_suite_group_party" for slug, _prompt in targets["group_scenes"][0]),
"JSON pool_extensions did not reach legacy group scenes",
)
def smoke_category_cast_config_policy() -> None:
_expect(pb.CATEGORY_PRESETS is category_cast_config.CATEGORY_PRESETS, "Prompt builder category presets are not delegated")
_expect(pb.CAST_PRESETS is category_cast_config.CAST_PRESETS, "Prompt builder cast presets are not delegated")
@@ -4083,6 +4107,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
("row_expression_policy", smoke_row_expression_policy),
("row_item_policy", smoke_row_item_policy),
("row_generation_policy", smoke_row_generation_policy),
("category_extensions_policy", smoke_category_extensions_policy),
("category_cast_config_policy", smoke_category_cast_config_policy),
("generation_profile_config_policy", smoke_generation_profile_config_policy),
("filter_config_policy", smoke_filter_config_policy),