118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import category_library as category_policy
|
|
from . import generate_prompt_batches as g
|
|
from . import row_item as row_item_policy
|
|
except ImportError: # Allows local smoke tests with top-level imports.
|
|
import category_library as category_policy
|
|
import generate_prompt_batches as g
|
|
import row_item as row_item_policy
|
|
|
|
|
|
BUILTIN_CATEGORIES = [
|
|
"auto_weighted",
|
|
"auto_full",
|
|
"woman",
|
|
"man",
|
|
"couple",
|
|
"group_or_layout",
|
|
"custom_random",
|
|
]
|
|
|
|
_EXTENSIONS_APPLIED = False
|
|
|
|
|
|
def list_from(value: Any) -> list[Any]:
    """Coerce *value* into a list.

    A list is handed back unchanged (the very same object, not a copy),
    ``None`` becomes an empty list, and any other value is wrapped in a
    new single-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
|
|
|
|
|
|
def _dedupe_marker(item: Any) -> str:
    """Return the string key used to recognise duplicates of *item*."""
    try:
        # sort_keys=True gives dicts with equal content the same marker
        # regardless of key insertion order.
        return json.dumps(item, sort_keys=True)
    except (TypeError, ValueError):
        # TypeError: the value is not JSON-serializable.
        # ValueError: json.dumps detected a circular reference.
        # repr() copes with both, so fall back to it.
        return repr(item)


def unique_extend(target: list[Any], additions: list[Any]) -> None:
    """Append to *target*, in place, every addition it does not already hold.

    Two items count as duplicates when their markers match: the
    key-sorted JSON dump of the item, or its ``repr`` when the item cannot
    be dumped as JSON. Duplicates inside *additions* itself are also
    dropped; the order of first occurrence is kept.

    Args:
        target: List that is mutated; existing items are never removed.
        additions: Candidate items, considered in order.

    Returns:
        None. The result is the mutation of *target*.
    """
    seen = {_dedupe_marker(item) for item in target}
    for item in additions:
        marker = _dedupe_marker(item)
        if marker not in seen:
            target.append(item)
            seen.add(marker)
|
|
|
|
|
|
def extension_targets() -> dict[str, tuple[list[Any], bool]]:
    """Return the pools that a ``pool_extensions`` entry may extend.

    Each key is an extension name; each value is ``(pool, expects_pair)``
    where ``pool`` is the list object taken directly from the
    ``generate_prompt_batches`` module (not a copy) and ``expects_pair``
    tells apply_pool_extensions() whether additions are normalized with
    ``pair_from`` (True) or ``item_text`` (False).
    """
    as_text, as_pair = False, True
    # Rows are kept in the original declaration order so the resulting
    # dict iterates identically.
    rows = (
        ("women_clothes", g.WOMEN_CLOTHES, as_text),
        ("women_clothes_minimal", g.WOMEN_CLOTHES_MINIMAL, as_text),
        ("men_clothes", g.MEN_CLOTHES, as_text),
        ("men_clothes_minimal", g.MEN_CLOTHES_MINIMAL, as_text),
        ("couple_outfits", g.COUPLE_OUTFITS, as_text),
        ("couple_outfits_minimal", g.COUPLE_OUTFITS_MINIMAL, as_text),
        ("poses", g.POSES, as_text),
        ("evocative_poses", g.EVOCATIVE_POSES, as_text),
        ("backside_poses", g.BACKSIDE_POSES, as_text),
        ("expressions", g.EXPRESSIONS, as_text),
        ("compositions", g.COMPOSITIONS, as_text),
        ("props", g.PROPS, as_text),
        ("figure_curvy", g.FIGURE_CURVY, as_text),
        ("figure_athletic", g.FIGURE_ATHLETIC, as_text),
        ("figure_bombshell", g.FIGURE_BOMBSHELL, as_text),
        ("scenes", g.SCENES, as_pair),
        ("group_scenes", g.GROUP_SCENES, as_pair),
        ("layouts_full", g.LAYOUTS_FULL, as_pair),
        ("layouts_minimal", g.LAYOUTS_MINIMAL, as_pair),
        ("group_compositions", g.GROUP_COMPOSITIONS, as_text),
        ("group_ages", g.GROUP_AGES, as_text),
    )
    return {name: (pool, expects_pair) for name, pool, expects_pair in rows}
|
|
|
|
|
|
def apply_pool_extensions() -> None:
    """Merge ``pool_extensions`` from every category JSON file into the pools.

    Does nothing if a previous call already completed. Otherwise, for each
    file reported by ``category_policy.category_json_files()``, the
    ``pool_extensions`` mapping is read and each entry's additions are
    normalized (``pair_from`` or ``item_text``, depending on the pool) and
    appended, without duplicates, to the matching pool from
    extension_targets(). Afterwards ``g.EVOCATIVE_ALL`` is rebuilt from the
    evocative and backside pose pools and the module flag is set.

    Raises:
        ValueError: If ``pool_extensions`` is not a dict, or if it names a
            pool that extension_targets() does not provide.
    """
    global _EXTENSIONS_APPLIED
    if _EXTENSIONS_APPLIED:
        return

    pools = extension_targets()
    for path in category_policy.category_json_files():
        payload = category_policy.read_category_json(path)
        requested = payload.get("pool_extensions", {})
        if not isinstance(requested, dict):
            raise ValueError(f"pool_extensions in {path} must be an object")

        for target_name, raw_additions in requested.items():
            if target_name not in pools:
                known = ", ".join(sorted(pools))
                raise ValueError(f"Unknown pool extension '{target_name}' in {path}. Known: {known}")

            pool, expects_pair = pools[target_name]
            entries = list_from(raw_additions)
            if expects_pair:
                cleaned = [row_item_policy.pair_from(entry) for entry in entries]
            else:
                cleaned = [row_item_policy.item_text(entry) for entry in entries]
            unique_extend(pool, cleaned)

    # The combined list is a fresh concatenation, so it has to be rebuilt
    # after its two source pools may have grown.
    g.EVOCATIVE_ALL = g.EVOCATIVE_POSES + g.BACKSIDE_POSES
    _EXTENSIONS_APPLIED = True
|
|
|
|
|
|
def category_choices() -> list[str]:
    """Return the selectable category names as a new list.

    The built-in names come first, followed by the ``"name"`` of each
    library category that is not itself a built-in name. Pool extensions
    are applied before the library is loaded.
    """
    apply_pool_extensions()
    library_names = [entry["name"] for entry in category_policy.load_category_library()]
    choices = list(BUILTIN_CATEGORIES)
    for name in library_names:
        if name in BUILTIN_CATEGORIES:
            continue
        choices.append(name)
    return choices
|
|
|
|
|
|
def subcategory_choices() -> list[str]:
    """Return the selectable subcategory entries.

    The list starts with ``category_policy.RANDOM_SUBCATEGORY`` and then
    holds one ``exact_subcategory_selector`` result per subcategory of
    every library category, in library order. Pool extensions are applied
    first.
    """
    apply_pool_extensions()
    choices = [category_policy.RANDOM_SUBCATEGORY]
    choices += [
        category_policy.exact_subcategory_selector(category, subcategory)
        for category in category_policy.load_category_library()
        for subcategory in category["subcategories"]
    ]
    return choices
|