"""JSON category library: loading, normalization, cast filtering, pool merging.

This module reads every ``categories/*.json`` file next to it and exposes
helpers to:

* normalize the ``categories`` / ``subcategories`` structures,
* load named pools (``scene_pools``, ``expression_pools``,
  ``composition_pools``),
* filter entries by cast size (explicit constraints plus a text heuristic),
* pick a category/subcategory with weighted randomness,
* merge fields, axes and pools along the category -> subcategory -> item chain.
"""
from __future__ import annotations

import json
import random
import re
from pathlib import Path
from typing import Any


ROOT_DIR = Path(__file__).resolve().parent
CATEGORY_DIR = ROOT_DIR / "categories"
RANDOM_SUBCATEGORY = "random"


def category_json_files() -> list[Path]:
    """Return the ``*.json`` files directly inside ``CATEGORY_DIR``, sorted.

    Returns an empty list when the directory does not exist.
    """
    if not CATEGORY_DIR.exists():
        return []
    return sorted(path for path in CATEGORY_DIR.glob("*.json") if path.is_file())


def read_category_json(path: Path) -> dict[str, Any]:
    """Read *path* as UTF-8 JSON and return the top-level object.

    Raises:
        ValueError: if the file is not valid UTF-8, is not valid JSON, or its
            top-level value is not a JSON object. The message always names
            the offending file.
    """
    try:
        raw_text = path.read_text(encoding="utf-8")
    except UnicodeDecodeError as exc:
        # UnicodeDecodeError is already a ValueError, but its message does not
        # say which file was bad; re-raise with the path for easier debugging.
        raise ValueError(f"{path} is not valid UTF-8: {exc}") from exc
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in {path}: {exc}") from exc
    if not isinstance(data, dict):
        raise ValueError(f"{path} must contain a JSON object")
    return data


def _slug(value: str) -> str:
    """Lower-case *value*, collapse non ``[a-z0-9]`` runs to ``_``, cap at 48 chars.

    Falls back to ``"custom"`` when nothing usable remains.
    """
    text = str(value or "").lower()
    text = re.sub(r"[^a-z0-9]+", "_", text)
    return text.strip("_")[:48] or "custom"


def _list_from(value: Any) -> list[Any]:
    """Coerce *value* to a list: ``None`` -> ``[]``, list -> itself, else ``[value]``."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _is_false(value: Any) -> bool:
    """Return True only when *value* explicitly spells "false".

    Recognized: the boolean ``False``, numeric zero, and the strings
    ``"false"``, ``"0"``, ``"no"``, ``"off"`` (case/whitespace-insensitive).
    Anything else, including ``None`` and a missing key, is "not false".
    """
    if isinstance(value, bool):
        return value is False
    if isinstance(value, (int, float)):
        # JSON ``0`` should behave like the string "0", which is accepted below.
        return value == 0
    if isinstance(value, str):
        return value.strip().lower() in ("false", "0", "no", "off")
    return False


def _entry_text(item: Any) -> str:
    """Return the descriptive text of an entry.

    For dicts, the first truthy value among ``template``, ``prompt``, ``text``,
    ``description`` and ``name`` is used; other values are stringified.
    """
    if isinstance(item, dict):
        return str(
            item.get("template")
            or item.get("prompt")
            or item.get("text")
            or item.get("description")
            or item.get("name")
            or ""
        ).strip()
    return str(item).strip()


def _dedupe_marker(item: Any) -> str:
    """Return a string key identifying *item* for de-duplication.

    Uses canonical JSON (sorted keys) and falls back to ``repr`` for values
    that ``json.dumps`` rejects with ``TypeError``.
    """
    try:
        return json.dumps(item, sort_keys=True)
    except TypeError:
        return repr(item)


def _unique_extend(target: list[Any], additions: list[Any]) -> None:
    """Append to *target* (in place) each addition not already present.

    Equality is judged by :func:`_dedupe_marker`, so dict entries with the
    same content in a different key order count as duplicates.
    """
    seen = {_dedupe_marker(item) for item in target}
    for item in additions:
        marker = _dedupe_marker(item)
        if marker not in seen:
            target.append(item)
            seen.add(marker)


def _weighted_choice(rng: random.Random, items: list[Any]) -> Any:
    """Pick one element of *items*, honouring an optional ``weight`` key.

    Dict items may carry ``weight``; missing weights and weights that cannot
    be converted to float count as 1.0, negative weights as 0. When every
    weight is zero the choice is uniform.

    Exactly one call is made on *rng* (``random()`` in the weighted case,
    ``randrange()`` in the all-zero case), so seeded sequences stay stable.

    Raises:
        ValueError: if *items* is empty.
    """
    if not items:
        raise ValueError("Cannot choose from an empty list")
    weights: list[float] = []
    for item in items:
        weight = item.get("weight", 1.0) if isinstance(item, dict) else 1.0
        try:
            weights.append(max(0.0, float(weight)))
        except (TypeError, ValueError):
            weights.append(1.0)
    total = sum(weights)
    if total <= 0:
        return items[rng.randrange(len(items))]
    pick = rng.random() * total
    running = 0.0
    fallback = items[-1]
    for item, weight in zip(items, weights):
        if weight <= 0:
            # A zero-weight entry must never win, not even when pick == 0.0
            # (random() may return exactly 0.0).
            continue
        running += weight
        fallback = item
        if pick <= running:
            return item
    # Only reachable through float rounding between sum() and the running
    # total; return the last entry that actually has weight.
    return fallback


def template_list(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str) -> list[Any]:
    """Return *key* as a list from the most specific level that defines it.

    Lookup order is item (when it is a dict), then subcategory, then category.
    The first level containing the key wins even if its value is ``None``
    (which yields ``[]``). Returns ``[]`` when no level has the key.
    """
    if isinstance(item, dict) and key in item:
        return _list_from(item[key])
    if key in subcategory:
        return _list_from(subcategory[key])
    if key in category:
        return _list_from(category[key])
    return []


def _constraint_int(entry: dict[str, Any], key: str) -> int | None:
    """Return ``int(entry[key])``, or ``None`` if missing or not convertible."""
    if key not in entry:
        return None
    try:
        return int(entry[key])
    except (TypeError, ValueError):
        return None


def _cast_requirement_matches(requirement: str, women_count: int, men_count: int) -> bool:
    """Check one named cast requirement against the given cast counts.

    Known names: ``any`` (or empty), ``women_only``, ``men_only``, ``mixed``,
    ``has_women``, ``has_men``, ``solo``, ``couple``, ``threesome``, ``group``
    (four or more). Unknown names are treated as satisfied.
    """
    total = women_count + men_count
    requirement = requirement.strip().lower()
    if requirement in ("", "any"):
        return True
    if requirement == "women_only":
        return women_count > 0 and men_count == 0
    if requirement == "men_only":
        return men_count > 0 and women_count == 0
    if requirement == "mixed":
        return women_count > 0 and men_count > 0
    if requirement == "has_women":
        return women_count > 0
    if requirement == "has_men":
        return men_count > 0
    if requirement == "solo":
        return total == 1
    if requirement == "couple":
        return total == 2
    if requirement == "threesome":
        return total == 3
    if requirement == "group":
        return total >= 4
    # Unknown requirement names do not block the entry.
    return True


def _is_toy_assisted_double_couple_text(text: str) -> bool:
    """Return True when *text* mentions a toy together with a double-act phrase."""
    text = text.lower()
    if "toy" not in text:
        return False
    return any(
        token in text
        for token in (
            "double penetration",
            "double-penetration",
            "vaginal and anal penetration",
            "second penetration point",
            "second point of contact",
            "second contact",
        )
    )


def _heuristic_cast_compatible(text: str, women_count: int, men_count: int) -> bool:
    """Guess from free text whether an entry fits the given cast.

    Empty text is always compatible. Otherwise a series of keyword rules
    rejects text that implies more (or different) participants than the cast
    provides.

    NOTE(review): every rule is a plain substring test on the lower-cased
    text, so short terms also match inside longer words (for example "cum"
    inside "document", "ass" inside "glass", "men" inside "women"). The term
    lists are kept exactly as authored.
    """
    text = text.lower()
    if not text:
        return True
    total = women_count + men_count
    # One woman + one man: reject double-act templates, toy-assisted or not.
    if total == 2 and women_count == 1 and men_count == 1:
        if "{double_act}" in text:
            return False
        if _is_toy_assisted_double_couple_text(text):
            return False
    if total == 1:
        solo_blocked_terms = (
            "partner",
            "partners",
            "two bodies",
            "three bodies",
            "bodies still pressed",
            "bodies pressed",
            "bodies tangled",
            "wet bodies",
            "chests heaving together",
            "straddling a partner",
            "shared climax",
            "between two",
            "from both sides",
            "front-and-back",
            "body contact",
        )
        if any(term in text for term in solo_blocked_terms):
            return False
        solo_toy_terms = ("toy", "dildo", "finger", "fingers", "self")
        if "penetration" in text and not any(term in text for term in solo_toy_terms):
            return False
    if total < 3 and "threesome" in text:
        return False
    if total != 3 and ("centered threesome" in text or "three-way" in text):
        return False
    if total < 3 and ("three bodies" in text or "center partner" in text or "center body" in text):
        return False
    if total < 4 and ("orgy" in text or "group sex" in text or "group-sex" in text or "group pile" in text):
        return False
    double_terms = (
        "double penetration",
        "two partners penetrating",
        "front-and-back penetration",
        "one penis in pussy and one penis in ass",
        "pussy and ass filled",
        "vaginal and anal penetration at the same time",
        "front-and-back double penetration",
        "hardcore double penetration",
        "kneeling double penetration",
        "standing supported double penetration",
        "deep double penetration",
        "between two partners",
        "from both sides",
    )
    if total < 3 and any(term in text for term in double_terms):
        # With fewer than three people this only works when a toy/finger is named.
        toy_terms = ("strap-on", "strap on", "dildo", "toy", "finger")
        if not any(term in text for term in toy_terms):
            return False
    if men_count == 0:
        toy_terms = ("strap-on", "strap on", "dildo", "toy", "finger", "fingers")
        penetration_terms = (
            "vaginal penetration",
            "deep vaginal sex",
            "penetrative sex",
            "pussy penetration",
            "pussy stretched",
            "vaginal thrusting",
            "full-body penetrative",
            "close-contact vaginal",
            "penetration clearly visible",
            "explicit penetrative contact",
        )
        if any(term in text for term in penetration_terms) and not any(term in text for term in toy_terms):
            return False
        male_terms = (
            " penis",
            "penis ",
            "penises",
            "cum",
            "creampie",
            "facial",
            "blowjob",
            "fellatio",
            "deepthroat",
            "ejaculation",
            "semen",
        )
        if any(term in text for term in male_terms) and not any(term in text for term in toy_terms):
            return False
    elif men_count < 2 and "penises" in text:
        return False
    if women_count == 0:
        if "penetrative sex" in text and not any(term in text for term in ("anal", "ass", "male/male", "men")):
            return False
        female_terms = (
            "pussy",
            "vaginal",
            "vagina",
            "cunnilingus",
            "clit",
            "clitoris",
            "breasts",
            "breast ",
            "nipples",
            "nipple",
            "underboob",
        )
        if any(term in text for term in female_terms):
            return False
    return True


def compatible_entry(entry: Any, women_count: int, men_count: int) -> bool:
    """Return True when *entry* can be used with the given cast.

    Non-dict entries are judged by the text heuristic only. Dict entries are
    checked, in order, against ``min_*`` / ``max_*`` constraints (``women``,
    ``men``, ``people``), then the ``cast`` / ``requires`` requirement names.
    Entries that carry ``subcategories``, ``item_templates`` or ``item_axes``
    skip the text heuristic; all others must also pass it.
    """
    if not isinstance(entry, dict):
        return _heuristic_cast_compatible(_entry_text(entry), women_count, men_count)
    total = women_count + men_count
    for key, value in (
        ("min_women", women_count),
        ("min_men", men_count),
        ("min_people", total),
    ):
        minimum = _constraint_int(entry, key)
        if minimum is not None and value < minimum:
            return False
    for key, value in (
        ("max_women", women_count),
        ("max_men", men_count),
        ("max_people", total),
    ):
        maximum = _constraint_int(entry, key)
        if maximum is not None and value > maximum:
            return False
    requirements = _list_from(entry.get("cast", [])) + _list_from(entry.get("requires", []))
    if requirements and not all(_cast_requirement_matches(str(req), women_count, men_count) for req in requirements):
        return False
    if any(key in entry for key in ("subcategories", "item_templates", "item_axes")):
        return True
    return _heuristic_cast_compatible(_entry_text(entry), women_count, men_count)


def compatible_entries(entries: list[Any], women_count: int, men_count: int) -> list[Any]:
    """Return the entries compatible with the cast.

    When nothing is compatible the original *entries* list is returned
    unchanged, so callers always have something to choose from.
    """
    filtered = [entry for entry in entries if compatible_entry(entry, women_count, men_count)]
    return filtered or entries


def merged_axes(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> dict[str, list[Any]]:
    """Merge ``item_axes`` from category, subcategory and item (later wins per axis).

    Each axis value is coerced to a list. A level whose ``item_axes`` is
    ``None`` or missing contributes nothing.

    Raises:
        ValueError: if a level's ``item_axes`` is present but not an object.
    """
    axes: dict[str, list[Any]] = {}
    for source in (category, subcategory, item if isinstance(item, dict) else None):
        if not isinstance(source, dict):
            continue
        raw_axes = source.get("item_axes", {})
        if raw_axes is None:
            continue
        if not isinstance(raw_axes, dict):
            raise ValueError("item_axes must be a JSON object")
        for key, values in raw_axes.items():
            axes[str(key)] = _list_from(values)
    return axes


def _normalize_subcategories(category: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the category's subcategories as a list of dicts.

    Accepts a mapping (name -> object or items), a list of objects/strings,
    or a single value. Every result has ``name``, ``slug`` and ``items``
    (``prompts`` is accepted as an alias for ``items``). A category without
    subcategories gets one ``General`` subcategory whose only item is the
    category name.

    Raises:
        ValueError: if an entry is neither an object nor a string.
    """
    raw = category.get("subcategories", [])
    if isinstance(raw, dict):
        raw = [
            {"name": name, **(value if isinstance(value, dict) else {"items": value})}
            for name, value in raw.items()
        ]
    subcategories: list[dict[str, Any]] = []
    for entry in _list_from(raw):
        if isinstance(entry, str):
            sub = {"name": entry, "items": [entry]}
        elif isinstance(entry, dict):
            sub = dict(entry)  # shallow copy so the loaded JSON is not mutated
        else:
            raise ValueError(f"Subcategory must be an object or string: {entry!r}")
        name = str(sub.get("name") or sub.get("slug") or "General").strip()
        sub["name"] = name
        sub["slug"] = str(sub.get("slug") or _slug(name))
        if "items" not in sub and "prompts" in sub:
            sub["items"] = sub["prompts"]
        if "items" not in sub:
            sub["items"] = [name]
        subcategories.append(sub)
    if not subcategories:
        name = str(category.get("name") or "General")
        subcategories.append({"name": "General", "slug": "general", "items": [name]})
    return subcategories


def _normalize_categories(raw_categories: Any) -> list[dict[str, Any]]:
    """Return categories as a list of dicts with ``name``, ``slug``, ``subcategories``.

    Accepts a mapping (name -> object or subcategories), a list, or a single
    object.

    Raises:
        ValueError: if a category entry is not an object.
    """
    if isinstance(raw_categories, dict):
        iterable = [
            {"name": name, **(value if isinstance(value, dict) else {"subcategories": value})}
            for name, value in raw_categories.items()
        ]
    else:
        iterable = _list_from(raw_categories)

    categories: list[dict[str, Any]] = []
    for entry in iterable:
        if not isinstance(entry, dict):
            raise ValueError(f"Category must be an object: {entry!r}")
        category = dict(entry)
        name = str(category.get("name") or category.get("slug") or "Custom").strip()
        category["name"] = name
        category["slug"] = str(category.get("slug") or _slug(name))
        category["subcategories"] = _normalize_subcategories(category)
        categories.append(category)
    return categories


def load_category_library() -> list[dict[str, Any]]:
    """Load and normalize the ``categories`` of every category JSON file."""
    categories: list[dict[str, Any]] = []
    for path in category_json_files():
        data = read_category_json(path)
        categories.extend(_normalize_categories(data.get("categories", [])))
    return categories


def load_named_pool_library(key: str) -> dict[str, list[Any]]:
    """Collect the pools stored under top-level *key* across all category files.

    Pools with the same name in several files are merged without duplicates;
    pools with a blank name are ignored.

    Raises:
        ValueError: if *key* is present and truthy in a file but not an object.
    """
    pools: dict[str, list[Any]] = {}
    for path in category_json_files():
        data = read_category_json(path)
        raw_pools = data.get(key, {})
        if not raw_pools:
            continue
        if not isinstance(raw_pools, dict):
            raise ValueError(f"{key} in {path} must be an object")
        for name, entries in raw_pools.items():
            pool_name = str(name).strip()
            if not pool_name:
                continue
            pools.setdefault(pool_name, [])
            _unique_extend(pools[pool_name], _list_from(entries))
    return pools


def load_scene_pool_library() -> dict[str, list[Any]]:
    """Load the ``scene_pools`` library."""
    return load_named_pool_library("scene_pools")


def load_expression_pool_library() -> dict[str, list[Any]]:
    """Load the ``expression_pools`` library."""
    return load_named_pool_library("expression_pools")


def load_composition_pool_library() -> dict[str, list[Any]]:
    """Load the ``composition_pools`` library."""
    return load_named_pool_library("composition_pools")


def find_category(categories: list[dict[str, Any]], name_or_slug: str) -> dict[str, Any] | None:
    """Return the first category whose name or slug matches, case-insensitively.

    Returns ``None`` when nothing matches, including when *name_or_slug* is
    ``None`` or empty.
    """
    wanted = str(name_or_slug or "").strip().lower()
    for category in categories:
        if category["name"].lower() == wanted or category["slug"].lower() == wanted:
            return category
    return None


def _base_cast_counts(women_count: int, men_count: int) -> tuple[int, int]:
    """Clamp both counts to >= 0; an empty cast becomes one woman."""
    women_count = max(0, int(women_count))
    men_count = max(0, int(men_count))
    if women_count + men_count == 0:
        women_count = 1
    return women_count, men_count


def _counts_for_exact_subcategory(
    subcategory: dict[str, Any],
    women_count: int,
    men_count: int,
) -> tuple[int, int]:
    """Raise the cast counts so they satisfy the subcategory's ``min_*`` values.

    ``min_women`` and ``min_men`` are applied first. Any remaining shortfall
    against ``min_people`` is added to the women, unless the cast has men and
    no women, in which case it is added to the men. ``max_*`` constraints are
    not applied here.
    """
    women_count, men_count = _base_cast_counts(women_count, men_count)

    min_women = _constraint_int(subcategory, "min_women")
    if min_women is not None and women_count < min_women:
        women_count = min_women
    min_men = _constraint_int(subcategory, "min_men")
    if min_men is not None and men_count < min_men:
        men_count = min_men

    min_people = _constraint_int(subcategory, "min_people")
    if min_people is not None:
        missing = min_people - (women_count + men_count)
        if missing > 0:
            if women_count > 0 or men_count == 0:
                women_count += missing
            else:
                men_count += missing
    return women_count, men_count


def find_subcategory(
    categories: list[dict[str, Any]],
    category_choice: str,
    subcategory_choice: str,
    category_rng: random.Random,
    subcategory_rng: random.Random,
    women_count: int = 1,
    men_count: int = 1,
    random_subcategory: str = RANDOM_SUBCATEGORY,
) -> tuple[dict[str, Any], dict[str, Any], int, int]:
    """Resolve the category and subcategory to use, plus the effective cast.

    Two modes:

    * Exact: *subcategory_choice* has the form ``"Category / Subcategory"``
      (and is not *random_subcategory*). *category_choice* is ignored, and the
      cast counts are raised to meet the subcategory's minimums.
    * Random: the category is *category_choice* (or a weighted pick when it is
      ``"custom_random"``), and the subcategory is a weighted pick among the
      cast-compatible subcategories.

    Returns:
        ``(category, subcategory, women_count, men_count)``.

    Raises:
        ValueError: for an unknown category or subcategory, an exact
            subcategory that stays incompatible with the adjusted cast, or
            ``"custom_random"`` with no categories loaded.
    """
    women_count, men_count = _base_cast_counts(women_count, men_count)
    if subcategory_choice and subcategory_choice != random_subcategory and " / " in subcategory_choice:
        category_name, subcategory_name = subcategory_choice.split(" / ", 1)
        category = find_category(categories, category_name)
        if not category:
            raise ValueError(f"Unknown category in subcategory picker: {category_name}")
        wanted = subcategory_name.strip().lower()
        for subcategory in category["subcategories"]:
            if subcategory["name"].lower() == wanted or subcategory["slug"].lower() == wanted:
                adjusted_women_count, adjusted_men_count = _counts_for_exact_subcategory(
                    subcategory,
                    women_count,
                    men_count,
                )
                if not compatible_entry(subcategory, adjusted_women_count, adjusted_men_count):
                    raise ValueError(
                        f"Subcategory '{subcategory['name']}' is not compatible with "
                        f"women_count={women_count}, men_count={men_count}"
                    )
                return category, subcategory, adjusted_women_count, adjusted_men_count
        raise ValueError(f"Unknown subcategory '{subcategory_name}' for category '{category_name}'")

    if category_choice == "custom_random":
        if not categories:
            raise ValueError("No custom categories found in categories/*.json")
        category = _weighted_choice(category_rng, categories)
    else:
        category = find_category(categories, category_choice)
        if not category:
            raise ValueError(f"Unknown custom category: {category_choice}")
    subcategories = compatible_entries(category["subcategories"], women_count, men_count)
    subcategory = _weighted_choice(subcategory_rng, subcategories)
    return category, subcategory, women_count, men_count


def merged_field(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str, default: Any = None) -> Any:
    """Return *key* from the most specific level defining it, else *default*.

    Lookup order is item (when it is a dict), then subcategory, then category.
    The raw value is returned as stored.
    """
    if isinstance(item, dict) and key in item:
        return item[key]
    if key in subcategory:
        return subcategory[key]
    if key in category:
        return category[key]
    return default


def _sources_with_inheritance(
    category: dict[str, Any],
    subcategory: dict[str, Any],
    item: Any,
    inherit_key: str,
) -> tuple[Any, ...]:
    """Return the levels that contribute to a pool, least specific first.

    An item whose *inherit_key* is false contributes alone; a subcategory
    whose *inherit_key* is false cuts off the category. The item slot is
    ``None`` when *item* is not a dict.
    """
    item_source = item if isinstance(item, dict) else None
    if item_source is not None and _is_false(item_source.get(inherit_key)):
        return (item_source,)
    if _is_false(subcategory.get(inherit_key)):
        return (subcategory, item_source)
    return (category, subcategory, item_source)


def configured_pool(
    category: dict[str, Any],
    subcategory: dict[str, Any],
    item: Any,
    direct_key: str,
    pool_key: str,
    pool_library: dict[str, list[Any]],
    inherit_key: str,
) -> list[Any]:
    """Build the de-duplicated pool of entries configured for an item.

    For each contributing level (see :func:`_sources_with_inheritance`) the
    entries listed under *direct_key* are added first, then the entries of
    every library pool referenced under *pool_key* or its singular form
    (*pool_key* without a trailing ``s``).

    Raises:
        ValueError: if a referenced pool name is not in *pool_library*.
    """
    entries: list[Any] = []
    singular_pool_key = pool_key[:-1] if pool_key.endswith("s") else pool_key
    # Read the singular key only when it is actually a different key.
    ref_keys = (singular_pool_key, pool_key) if singular_pool_key != pool_key else (pool_key,)
    for source in _sources_with_inheritance(category, subcategory, item, inherit_key):
        if not isinstance(source, dict):
            continue
        if direct_key in source:
            _unique_extend(entries, _list_from(source[direct_key]))
        for ref_key in ref_keys:
            for ref in _list_from(source.get(ref_key)):
                ref_name = str(ref).strip()
                if ref_name not in pool_library:
                    raise ValueError(f"Unknown {singular_pool_key} '{ref_name}'")
                _unique_extend(entries, pool_library[ref_name])
    return entries
diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index dc810f8..fbe29ce 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -64,6 +64,7 @@ Core helper ownership: | Python module | What it owns | | --- | --- | +| `category_library.py` | JSON category loading, subcategory normalization, named scene/expression/composition pool loading, cast compatibility filtering, exact subcategory lookup, and inheritance-based pool merging. | | `hardcore_role_graphs.py` | Source role graph construction for hardcore configured-cast rows, including POV-aware interaction geometry. | | `hardcore_role_fallback.py` | Solo, same-sex, mixed group fallback, and support-partner role graph wording for configured casts. | | `hardcore_role_interaction.py` | Foreplay, manual stimulation, body worship, clothing transition, dominant guidance, camera performance, aftercare, and group coordination role graph wording. | @@ -161,7 +162,7 @@ There are two category systems. | Source | Files/functions | Notes | | --- | --- | --- | | Built-in legacy generator | `generate_prompt_batches.py`, `_build_direct_builtin_row`, `_build_auto_weighted_row` | Handles legacy `woman`, `man`, `couple`, `group_or_layout`, `auto_weighted`, and `auto_full`. | -| JSON category library | `categories/*.json`, `load_category_library`, `_build_custom_row` | Handles expandable categories such as casual clothes, erotic clothes, and hardcore sexual poses. | +| JSON category library | `categories/*.json`, `category_library.load_category_library`, `_build_custom_row` | Handles expandable categories such as casual clothes, erotic clothes, and hardcore sexual poses. | JSON categories are the scalable system. Add new main categories or subcategories there unless the behavior needs Python logic. 
@@ -768,7 +769,7 @@ pair metadata through the core Python APIs, then verifies: | Symptom | First file/function to inspect | | --- | --- | -| Wrong main category/subcategory frequency | Category node config, `load_category_library`, category JSON weights. | +| Wrong main category/subcategory frequency | Category node config, `category_library.load_category_library`, category JSON weights. | | Wrong outfit/clothing item | Relevant category JSON, `INSTA_OF_SOFTCORE_OUTFITS`, `SxCP Character Clothing`. | | Nude/clothing state confusing Krea2 | `build_insta_of_pair` clothing state helpers, then `krea_clothing.natural_clothing_state`. | | Wrong location | `categories/location_pools.json`, category `scene_pool`, `_scene_pool`. | diff --git a/prompt_builder.py b/prompt_builder.py index 40b0714..14df261 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -9,6 +9,21 @@ from string import Formatter from typing import Any, Callable try: + from .category_library import ( + category_json_files as _json_files, + compatible_entries as _compatible_entries, + compatible_entry as _compatible_entry, + configured_pool as _configured_pool, + find_subcategory as _find_subcategory, + load_category_library, + load_composition_pool_library, + load_expression_pool_library, + load_scene_pool_library, + merged_axes as _merged_axes, + merged_field as _merged_field, + read_category_json as _read_json, + template_list as _template_list, + ) from . import generate_prompt_batches as g from . import scene_camera_adapters from .hardcore_text_cleanup import ( @@ -23,6 +38,21 @@ try: sanitize_prompt_text, ) except ImportError: # Allows local smoke tests with `python -c`. 
+ from category_library import ( + category_json_files as _json_files, + compatible_entries as _compatible_entries, + compatible_entry as _compatible_entry, + configured_pool as _configured_pool, + find_subcategory as _find_subcategory, + load_category_library, + load_composition_pool_library, + load_expression_pool_library, + load_scene_pool_library, + merged_axes as _merged_axes, + merged_field as _merged_field, + read_category_json as _read_json, + template_list as _template_list, + ) import generate_prompt_batches as g import scene_camera_adapters from hardcore_text_cleanup import ( @@ -39,7 +69,6 @@ except ImportError: # Allows local smoke tests with `python -c`. ROOT_DIR = Path(__file__).resolve().parent -CATEGORY_DIR = ROOT_DIR / "categories" PROFILE_DIR = ROOT_DIR / "profiles" BUILTIN_CATEGORIES = [ @@ -726,22 +755,6 @@ class SafeFormatDict(dict): return "{" + key + "}" -def _json_files() -> list[Path]: - if not CATEGORY_DIR.exists(): - return [] - return sorted(path for path in CATEGORY_DIR.glob("*.json") if path.is_file()) - - -def _read_json(path: Path) -> dict[str, Any]: - try: - data = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as exc: - raise ValueError(f"Invalid JSON in {path}: {exc}") from exc - if not isinstance(data, dict): - raise ValueError(f"{path} must contain a JSON object") - return data - - def _slug(value: str) -> str: return g.slugify(value) or "custom" @@ -845,229 +858,6 @@ def _item_name(item: Any) -> str: return _item_text(item) -def _template_list(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str) -> list[Any]: - if isinstance(item, dict) and key in item: - return _list_from(item[key]) - if key in subcategory: - return _list_from(subcategory[key]) - if key in category: - return _list_from(category[key]) - return [] - - -def _constraint_int(entry: dict[str, Any], key: str) -> int | None: - if key not in entry: - return None - try: - return int(entry[key]) - except (TypeError, 
ValueError): - return None - - -def _cast_requirement_matches(requirement: str, women_count: int, men_count: int) -> bool: - total = women_count + men_count - requirement = requirement.strip().lower() - if requirement in ("", "any"): - return True - if requirement == "women_only": - return women_count > 0 and men_count == 0 - if requirement == "men_only": - return men_count > 0 and women_count == 0 - if requirement == "mixed": - return women_count > 0 and men_count > 0 - if requirement == "has_women": - return women_count > 0 - if requirement == "has_men": - return men_count > 0 - if requirement == "solo": - return total == 1 - if requirement == "couple": - return total == 2 - if requirement == "threesome": - return total == 3 - if requirement == "group": - return total >= 4 - return True - - -def _is_toy_assisted_double_couple_text(text: str) -> bool: - text = text.lower() - if "toy" not in text: - return False - return any( - token in text - for token in ( - "double penetration", - "double-penetration", - "vaginal and anal penetration", - "second penetration point", - "second point of contact", - "second contact", - ) - ) - - -def _heuristic_cast_compatible(text: str, women_count: int, men_count: int) -> bool: - text = text.lower() - if not text: - return True - total = women_count + men_count - if total == 2 and women_count == 1 and men_count == 1: - if "{double_act}" in text: - return False - if _is_toy_assisted_double_couple_text(text): - return False - if total == 1: - solo_blocked_terms = ( - "partner", - "partners", - "two bodies", - "three bodies", - "bodies still pressed", - "bodies pressed", - "bodies tangled", - "wet bodies", - "chests heaving together", - "straddling a partner", - "shared climax", - "between two", - "from both sides", - "front-and-back", - "body contact", - ) - if any(term in text for term in solo_blocked_terms): - return False - solo_toy_terms = ("toy", "dildo", "finger", "fingers", "self") - if "penetration" in text and not any(term 
in text for term in solo_toy_terms): - return False - if total < 3 and "threesome" in text: - return False - if total != 3 and ("centered threesome" in text or "three-way" in text): - return False - if total < 3 and ("three bodies" in text or "center partner" in text or "center body" in text): - return False - if total < 4 and ("orgy" in text or "group sex" in text or "group-sex" in text or "group pile" in text): - return False - if total < 3 and ( - "double penetration" in text - or "two partners penetrating" in text - or "front-and-back penetration" in text - or "one penis in pussy and one penis in ass" in text - or "pussy and ass filled" in text - or "vaginal and anal penetration at the same time" in text - or "front-and-back double penetration" in text - or "hardcore double penetration" in text - or "kneeling double penetration" in text - or "standing supported double penetration" in text - or "deep double penetration" in text - or "between two partners" in text - or "from both sides" in text - ): - toy_terms = ("strap-on", "strap on", "dildo", "toy", "finger") - if not any(term in text for term in toy_terms): - return False - if men_count == 0: - toy_terms = ("strap-on", "strap on", "dildo", "toy", "finger", "fingers") - penetration_terms = ( - "vaginal penetration", - "deep vaginal sex", - "penetrative sex", - "pussy penetration", - "pussy stretched", - "vaginal thrusting", - "full-body penetrative", - "close-contact vaginal", - "penetration clearly visible", - "explicit penetrative contact", - ) - if any(term in text for term in penetration_terms) and not any(term in text for term in toy_terms): - return False - male_terms = ( - " penis", - "penis ", - "penises", - "cum", - "creampie", - "facial", - "blowjob", - "fellatio", - "deepthroat", - "ejaculation", - "semen", - ) - if any(term in text for term in male_terms) and not any(term in text for term in toy_terms): - return False - elif men_count < 2 and "penises" in text: - return False - if women_count == 
0: - if "penetrative sex" in text and not any(term in text for term in ("anal", "ass", "male/male", "men")): - return False - female_terms = ( - "pussy", - "vaginal", - "vagina", - "cunnilingus", - "clit", - "clitoris", - "breasts", - "breast ", - "nipples", - "nipple", - "underboob", - ) - if any(term in text for term in female_terms): - return False - return True - - -def _compatible_entry(entry: Any, women_count: int, men_count: int) -> bool: - if not isinstance(entry, dict): - return _heuristic_cast_compatible(_entry_text(entry), women_count, men_count) - total = women_count + men_count - for key, value in ( - ("min_women", women_count), - ("min_men", men_count), - ("min_people", total), - ): - minimum = _constraint_int(entry, key) - if minimum is not None and value < minimum: - return False - for key, value in ( - ("max_women", women_count), - ("max_men", men_count), - ("max_people", total), - ): - maximum = _constraint_int(entry, key) - if maximum is not None and value > maximum: - return False - requirements = _list_from(entry.get("cast", [])) + _list_from(entry.get("requires", [])) - if requirements and not all(_cast_requirement_matches(str(req), women_count, men_count) for req in requirements): - return False - if any(key in entry for key in ("subcategories", "item_templates", "item_axes")): - return True - return _heuristic_cast_compatible(_entry_text(entry), women_count, men_count) - - -def _compatible_entries(entries: list[Any], women_count: int, men_count: int) -> list[Any]: - filtered = [entry for entry in entries if _compatible_entry(entry, women_count, men_count)] - return filtered or entries - - -def _merged_axes(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> dict[str, list[Any]]: - axes: dict[str, list[Any]] = {} - for source in (category, subcategory, item if isinstance(item, dict) else None): - if not isinstance(source, dict): - continue - raw_axes = source.get("item_axes", {}) - if raw_axes is None: - continue - if not 
isinstance(raw_axes, dict): - raise ValueError("item_axes must be a JSON object") - for key, values in raw_axes.items(): - axes[str(key)] = _list_from(values) - return axes - - def _oral_acts_for_position(values: list[Any], position: str) -> list[Any]: position_text = str(position or "").lower() if not position_text: @@ -1324,87 +1114,6 @@ def _choose_pair(rng: random.Random, items: list[Any]) -> tuple[str, str]: return _pair_from(_weighted_choice(rng, items)) -def _normalize_subcategories(category: dict[str, Any]) -> list[dict[str, Any]]: - raw = category.get("subcategories", []) - if isinstance(raw, dict): - raw = [ - {"name": name, **(value if isinstance(value, dict) else {"items": value})} - for name, value in raw.items() - ] - subcategories: list[dict[str, Any]] = [] - for entry in _list_from(raw): - if isinstance(entry, str): - sub = {"name": entry, "items": [entry]} - elif isinstance(entry, dict): - sub = dict(entry) - else: - raise ValueError(f"Subcategory must be an object or string: {entry!r}") - name = str(sub.get("name") or sub.get("slug") or "General").strip() - sub["name"] = name - sub["slug"] = str(sub.get("slug") or _slug(name)) - if "items" not in sub and "prompts" in sub: - sub["items"] = sub["prompts"] - if "items" not in sub: - sub["items"] = [name] - subcategories.append(sub) - if not subcategories: - name = str(category.get("name") or "General") - subcategories.append({"name": "General", "slug": "general", "items": [name]}) - return subcategories - - -def _normalize_categories(raw_categories: Any) -> list[dict[str, Any]]: - if isinstance(raw_categories, dict): - iterable = [ - {"name": name, **(value if isinstance(value, dict) else {"subcategories": value})} - for name, value in raw_categories.items() - ] - else: - iterable = _list_from(raw_categories) - - categories: list[dict[str, Any]] = [] - for entry in iterable: - if not isinstance(entry, dict): - raise ValueError(f"Category must be an object: {entry!r}") - category = dict(entry) - name 
= str(category.get("name") or category.get("slug") or "Custom").strip() - category["name"] = name - category["slug"] = str(category.get("slug") or _slug(name)) - category["subcategories"] = _normalize_subcategories(category) - categories.append(category) - return categories - - -def load_category_library() -> list[dict[str, Any]]: - categories: list[dict[str, Any]] = [] - for path in _json_files(): - data = _read_json(path) - categories.extend(_normalize_categories(data.get("categories", []))) - return categories - - -def _load_named_pool_library(key: str) -> dict[str, list[Any]]: - pools: dict[str, list[Any]] = {} - for path in _json_files(): - data = _read_json(path) - raw_pools = data.get(key, {}) - if not raw_pools: - continue - if not isinstance(raw_pools, dict): - raise ValueError(f"{key} in {path} must be an object") - for name, entries in raw_pools.items(): - pool_name = str(name).strip() - if not pool_name: - continue - pools.setdefault(pool_name, []) - _unique_extend(pools[pool_name], _list_from(entries)) - return pools - - -def load_scene_pool_library() -> dict[str, list[Any]]: - return _load_named_pool_library("scene_pools") - - LOCATION_POOL_PRESETS = { "custom_only": (), "all_json_locations": ("*",), @@ -1435,14 +1144,6 @@ def location_pool_preset_choices() -> list[str]: return list(LOCATION_POOL_PRESETS) + pool_choices -def load_expression_pool_library() -> dict[str, list[Any]]: - return _load_named_pool_library("expression_pools") - - -def load_composition_pool_library() -> dict[str, list[Any]]: - return _load_named_pool_library("composition_pools") - - COMPOSITION_POOL_PRESETS = { "custom_only": (), "all_json_compositions": ("*",), @@ -3913,101 +3614,6 @@ def _auto_full_choice(seed_config: dict[str, int], seed: int, row_number: int) - return str(choice.get("category") or "auto_weighted") -def _find_category(categories: list[dict[str, Any]], name_or_slug: str) -> dict[str, Any] | None: - wanted = name_or_slug.strip().lower() - for category in 
categories: - if category["name"].lower() == wanted or category["slug"].lower() == wanted: - return category - return None - - -def _base_cast_counts(women_count: int, men_count: int) -> tuple[int, int]: - women_count = max(0, int(women_count)) - men_count = max(0, int(men_count)) - if women_count + men_count == 0: - women_count = 1 - return women_count, men_count - - -def _counts_for_exact_subcategory( - subcategory: dict[str, Any], - women_count: int, - men_count: int, -) -> tuple[int, int]: - women_count, men_count = _base_cast_counts(women_count, men_count) - - min_women = _constraint_int(subcategory, "min_women") - if min_women is not None and women_count < min_women: - women_count = min_women - min_men = _constraint_int(subcategory, "min_men") - if min_men is not None and men_count < min_men: - men_count = min_men - - min_people = _constraint_int(subcategory, "min_people") - if min_people is not None: - missing = min_people - (women_count + men_count) - if missing > 0: - if women_count > 0 or men_count == 0: - women_count += missing - else: - men_count += missing - return women_count, men_count - - -def _find_subcategory( - categories: list[dict[str, Any]], - category_choice: str, - subcategory_choice: str, - category_rng: random.Random, - subcategory_rng: random.Random, - women_count: int = 1, - men_count: int = 1, -) -> tuple[dict[str, Any], dict[str, Any], int, int]: - women_count, men_count = _base_cast_counts(women_count, men_count) - if subcategory_choice and subcategory_choice != RANDOM_SUBCATEGORY and " / " in subcategory_choice: - category_name, subcategory_name = subcategory_choice.split(" / ", 1) - category = _find_category(categories, category_name) - if not category: - raise ValueError(f"Unknown category in subcategory picker: {category_name}") - wanted = subcategory_name.strip().lower() - for subcategory in category["subcategories"]: - if subcategory["name"].lower() == wanted or subcategory["slug"].lower() == wanted: - adjusted_women_count, 
adjusted_men_count = _counts_for_exact_subcategory( - subcategory, - women_count, - men_count, - ) - if not _compatible_entry(subcategory, adjusted_women_count, adjusted_men_count): - raise ValueError( - f"Subcategory '{subcategory['name']}' is not compatible with " - f"women_count={women_count}, men_count={men_count}" - ) - return category, subcategory, adjusted_women_count, adjusted_men_count - raise ValueError(f"Unknown subcategory '{subcategory_name}' for category '{category_name}'") - - if category_choice == "custom_random": - if not categories: - raise ValueError("No custom categories found in categories/*.json") - category = _weighted_choice(category_rng, categories) - else: - category = _find_category(categories, category_choice) - if not category: - raise ValueError(f"Unknown custom category: {category_choice}") - subcategories = _compatible_entries(category["subcategories"], women_count, men_count) - subcategory = _weighted_choice(subcategory_rng, subcategories) - return category, subcategory, women_count, men_count - - -def _merged_field(category: dict[str, Any], subcategory: dict[str, Any], item: Any, key: str, default: Any = None) -> Any: - if isinstance(item, dict) and key in item: - return item[key] - if key in subcategory: - return subcategory[key] - if key in category: - return category[key] - return default - - def _body_phrase(body: Any, figure_note: Any = "") -> str: body = str(body or "").strip() figure_note = str(figure_note or "").strip() @@ -6112,45 +5718,6 @@ def _apply_composition_config_to_legacy_row( return row -def _sources_with_inheritance( - category: dict[str, Any], - subcategory: dict[str, Any], - item: Any, - inherit_key: str, -) -> tuple[Any, ...]: - item_source = item if isinstance(item, dict) else None - if item_source is not None and _is_false(item_source.get(inherit_key)): - return (item_source,) - if _is_false(subcategory.get(inherit_key)): - return (subcategory, item_source) - return (category, subcategory, item_source) - - 
-def _configured_pool( - category: dict[str, Any], - subcategory: dict[str, Any], - item: Any, - direct_key: str, - pool_key: str, - pool_library: dict[str, list[Any]], - inherit_key: str, -) -> list[Any]: - entries: list[Any] = [] - singular_pool_key = pool_key[:-1] if pool_key.endswith("s") else pool_key - for source in _sources_with_inheritance(category, subcategory, item, inherit_key): - if not isinstance(source, dict): - continue - if direct_key in source: - _unique_extend(entries, _list_from(source[direct_key])) - refs = _list_from(source.get(singular_pool_key)) + _list_from(source.get(pool_key)) - for ref in refs: - ref_name = str(ref).strip() - if ref_name not in pool_library: - raise ValueError(f"Unknown {singular_pool_key} '{ref_name}'") - _unique_extend(entries, pool_library[ref_name]) - return entries - - def _expression_pool(category: dict[str, Any], subcategory: dict[str, Any], item: Any) -> list[Any]: return _configured_pool( category, diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 6ace139..00f2bbd 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -11,6 +11,7 @@ from __future__ import annotations import argparse import json +import random import re import sys from dataclasses import dataclass, field @@ -23,6 +24,7 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import caption_naturalizer # noqa: E402 +import category_library # noqa: E402 import krea_formatter # noqa: E402 import prompt_builder as pb # noqa: E402 import sdxl_formatter # noqa: E402 @@ -494,6 +496,56 @@ def smoke_config_route_location_theme() -> None: _expect_formatter_outputs(row, "config_route_location_theme", target="single") +def smoke_category_library_route() -> None: + categories = category_library.load_category_library() + _expect(len(categories) >= 3, "category library should load JSON categories") + category, subcategory, women_count, men_count = category_library.find_subcategory( + categories, + "custom_random", + "Hardcore 
sexual poses / Oral sex", + random.Random(101), + random.Random(102), + women_count=1, + men_count=1, + ) + _expect(category.get("slug") == "hardcore_sexual_poses", "exact category lookup selected wrong category") + _expect(subcategory.get("slug") == "oral_sex", "exact subcategory lookup selected wrong subcategory") + _expect((women_count, men_count) == (1, 1), "exact subcategory lookup changed compatible cast counts") + + item = category_library.compatible_entries(list(subcategory.get("items") or []), women_count, men_count)[0] + scenes = category_library.configured_pool( + category, + subcategory, + item, + "scenes", + "scene_pools", + category_library.load_scene_pool_library(), + "inherit_scenes", + ) + expressions = category_library.configured_pool( + category, + subcategory, + item, + "expressions", + "expression_pools", + category_library.load_expression_pool_library(), + "inherit_expressions", + ) + compositions = category_library.configured_pool( + category, + subcategory, + item, + "compositions", + "composition_pools", + category_library.load_composition_pool_library(), + "inherit_compositions", + ) + _expect(scenes, "category inheritance did not resolve scenes") + _expect(expressions, "category inheritance did not resolve expressions") + _expect(compositions, "category inheritance did not resolve compositions") + _expect(any("oral" in _clean_key(entry.get("prompt") if isinstance(entry, dict) else entry) for entry in scenes), "oral scene pool did not contribute") + + def smoke_hardcore_category_routes() -> None: cast = _character_cast() cases = [ @@ -1593,6 +1645,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("builtin_single_woman", smoke_builtin_single), ("camera_scene_single", smoke_camera_scene_single), ("config_route_location_theme", smoke_config_route_location_theme), + ("category_library_route", smoke_category_library_route), ("hardcore_category_routes", smoke_hardcore_category_routes), ("krea_close_foreplay_route", 
smoke_krea_close_foreplay_route), ("insta_pair_same_cast", smoke_insta_pair),