Files
ComfyUI-Ethanfel-Prompt-Bui…/row_category_route.py
T

342 lines
12 KiB
Python

from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
try:
from . import category_library as category_policy
from . import category_template_metadata as template_policy
from . import hardcore_position_config as hardcore_position_policy
from . import row_item as row_item_policy
from . import seed_config as seed_policy
from .hardcore_text_cleanup import (
sanitize_hardcore_axis_values,
sanitize_hardcore_environment_anchors,
)
except ImportError: # Allows local smoke tests from the repository root.
import category_library as category_policy
import category_template_metadata as template_policy
import hardcore_position_config as hardcore_position_policy
import row_item as row_item_policy
import seed_config as seed_policy
from hardcore_text_cleanup import (
sanitize_hardcore_axis_values,
sanitize_hardcore_environment_anchors,
)
def _list_from(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
return [value]
def is_pose_content_category(category: dict[str, Any], subcategory: dict[str, Any]) -> bool:
    """Return True when the category/subcategory labels mention poses or sex.

    The ``name``, ``slug`` and ``item_label`` of both dicts are lower-cased and
    split into alphanumeric tokens (underscores and punctuation act as
    separators). The result is True if any token is exactly one of ``pose``,
    ``poses``, ``sex`` or ``sexual``.
    """
    label_keys = ("name", "slug", "item_label")
    labels = [str(source.get(key, "")) for source in (category, subcategory) for key in label_keys]
    words = re.findall(r"[a-z0-9]+", " ".join(labels).lower())
    pose_words = {"pose", "poses", "sex", "sexual"}
    return any(word in pose_words for word in words)
def cast_count_adjustment(
    requested_women_count: int,
    requested_men_count: int,
    effective_women_count: int,
    effective_men_count: int,
) -> dict[str, int]:
    """Describe how the cast size changed between request and selection.

    Returns an empty dict when both counts were kept. Otherwise returns all
    four counts (requested and effective) so the change can be reported.
    """
    if requested_women_count != effective_women_count or requested_men_count != effective_men_count:
        return dict(
            requested_women_count=requested_women_count,
            requested_men_count=requested_men_count,
            effective_women_count=effective_women_count,
            effective_men_count=effective_men_count,
        )
    return {}
@dataclass(frozen=True)
class CategoryItemRoute:
    """Immutable record of the category item chosen for one prompt row.

    ``frozen`` only prevents reassigning attributes; the dict-valued fields
    themselves are still mutable objects.
    """

    category: dict[str, Any]  # selected category entry
    subcategory: dict[str, Any]  # selected subcategory entry
    women_count: int  # women count used for the item
    men_count: int  # men count used for the item
    count_adjustment: dict[str, int]  # requested vs. effective counts; empty when unchanged
    content_axis: str  # seed axis name used for the item draw
    item: Any  # raw item entry that was drawn
    item_text: str  # composed prompt text for the item
    item_name: str  # display name for the item
    item_axis_values: dict[str, Any]  # per-axis values attached to the item
    item_template_metadata: dict[str, Any]  # template metadata for the item
    formatter_hints: dict[str, Any]  # hints derived from the template metadata
    is_pose_category: bool  # whether the category/subcategory is a pose/sex category

    def as_dict(self) -> dict[str, Any]:
        """Return every field as a plain dict, in declaration order.

        ``count_adjustment``, ``item_axis_values``, ``item_template_metadata``
        and ``formatter_hints`` are shallow-copied. All other values, including
        ``category`` and ``subcategory``, are the same objects held by the route.
        """
        shallow_copied = {
            "count_adjustment",
            "item_axis_values",
            "item_template_metadata",
            "formatter_hints",
        }
        ordered_fields = (
            "category",
            "subcategory",
            "women_count",
            "men_count",
            "count_adjustment",
            "content_axis",
            "item",
            "item_text",
            "item_name",
            "item_axis_values",
            "item_template_metadata",
            "formatter_hints",
            "is_pose_category",
        )
        result: dict[str, Any] = {}
        for field_name in ordered_fields:
            value = getattr(self, field_name)
            result[field_name] = dict(value) if field_name in shallow_copied else value
        return result
def _unique_texts(values: list[Any]) -> list[str]:
    """Return cleaned, de-duplicated texts for ``values`` in first-seen order.

    Each value goes through ``row_item_policy.entry_text`` and has surrounding
    spaces, periods and semicolons stripped. Blank results are dropped.
    Duplicates are detected case-insensitively, and the first spelling wins.
    """
    seen_keys: set[str] = set()
    unique: list[str] = []
    for raw in values:
        cleaned = row_item_policy.entry_text(raw).strip(" .;")
        key = cleaned.lower()
        if cleaned and key not in seen_keys:
            seen_keys.add(key)
            unique.append(cleaned)
    return unique
def _restore_axis_values_for_context(
    values: list[Any],
    *,
    subcategory_slug: str,
    axis_name: str,
    item_axis_values: dict[str, Any],
    women_count: int,
    men_count: int,
) -> list[Any]:
    """Narrow candidate values for ``axis_name`` to fit the current item.

    The candidates are first passed through
    ``category_policy.compatible_entries`` with the cast counts (presumably
    dropping entries incompatible with the cast; confirm in category_library).
    The ``oral_sex``, ``outercourse_sex`` and ``anal_double_penetration``
    subcategories then apply their own row_item helper, keyed on the item's
    ``position`` (plus ``oral_act`` for oral). Other subcategories return the
    compatibility-filtered list unchanged.
    """
    compatible = category_policy.compatible_entries(values, women_count, men_count)
    if subcategory_slug not in ("oral_sex", "outercourse_sex", "anal_double_penetration"):
        return compatible

    position = str(item_axis_values.get("position") or "")
    if subcategory_slug == "oral_sex":
        oral_act = str(item_axis_values.get("oral_act") or "")
        return row_item_policy.oral_axis_values_for_context(compatible, position, oral_act, axis_name)
    if subcategory_slug == "outercourse_sex":
        return row_item_policy.outercourse_axis_values_for_position(compatible, position, axis_name)
    return row_item_policy.anal_axis_values_for_position(compatible, position, axis_name)
def _fallback_restore_axis_values(
source_categories: list[dict[str, Any]] | None,
axis_name: str,
hardcore_position_config: dict[str, Any],
) -> list[Any]:
if not source_categories:
return []
values: list[Any] = []
for category in source_categories:
if not hardcore_position_policy.is_hardcore_sexual_category(category):
continue
for subcategory in category.get("subcategories", []):
raw_axes = subcategory.get("item_axes")
if isinstance(raw_axes, dict):
values.extend(_list_from(raw_axes.get(axis_name)))
if not values:
return []
return hardcore_position_policy.filter_hardcore_axis(axis_name, values, hardcore_position_config)
def _restored_prompt_axis_values(
    rng: Any,
    subcategory: dict[str, Any],
    item_axis_values: dict[str, Any],
    hardcore_position_config: dict[str, Any],
    women_count: int,
    men_count: int,
    source_categories: list[dict[str, Any]] | None = None,
) -> dict[str, str]:
    """Choose prompt text for each axis listed in ``restore_prompt_axes``.

    The following applies to every axis named by the normalized
    ``restore_prompt_axes`` config entry:

    * If ``item_axis_values`` already holds a value with non-empty text, that
      cleaned text is kept.
    * Otherwise candidates come from the subcategory's ``item_axes``. When the
      subcategory has none for that axis, they come from
      ``_fallback_restore_axis_values`` instead. The candidates are narrowed by
      ``_restore_axis_values_for_context`` and one is drawn with
      ``row_item_policy.weighted_choice(rng, ...)``.

    Values drawn earlier in the loop are included in the context for later
    axes. For example, a freshly restored ``position`` constrains the
    candidates offered for the remaining axes.

    Returns ``{}`` when no axes are configured or the subcategory has no
    ``item_axes`` dict. Axes whose final text is empty are omitted.
    """
    configured_axes = (
        hardcore_position_config.get("restore_prompt_axes")
        if isinstance(hardcore_position_config, dict)
        else []
    )
    restore_axes = hardcore_position_policy.normalize_restore_prompt_axes(configured_axes)
    raw_axes = subcategory.get("item_axes")
    if not restore_axes or not isinstance(raw_axes, dict):
        return {}
    subcategory_slug = str(subcategory.get("slug") or "").lower()
    # Context for the per-subcategory filters is the composed item's axes plus
    # any values drawn earlier in this loop. Axes that already had a value are
    # never overwritten, so their original values still drive the filters.
    context_axis_values = dict(item_axis_values)
    restored: dict[str, str] = {}
    for axis_name in restore_axes:
        current = item_axis_values.get(axis_name)
        existing = row_item_policy.entry_text(current).strip(" .;") if current is not None else ""
        if existing:
            restored[axis_name] = existing
            continue
        values = _list_from(raw_axes.get(axis_name))
        if not values:
            values = _fallback_restore_axis_values(source_categories, axis_name, hardcore_position_config)
        if not values:
            continue
        values = _restore_axis_values_for_context(
            values,
            subcategory_slug=subcategory_slug,
            axis_name=axis_name,
            item_axis_values=context_axis_values,
            women_count=women_count,
            men_count=men_count,
        )
        if not values:
            continue
        chosen = row_item_policy.entry_text(row_item_policy.weighted_choice(rng, values)).strip(" .;")
        restored[axis_name] = chosen
        if chosen:
            context_axis_values[axis_name] = chosen
    return {axis: value for axis, value in restored.items() if value}
def _append_restored_prompt_details(item_text: str, details: list[str]) -> str:
    """Append restored details that ``item_text`` does not already mention.

    Details are de-duplicated with ``_unique_texts``. Any detail that already
    appears (case-insensitively, as a substring) in ``item_text`` is skipped.
    The remaining details are joined with ``"; "``:

    * Nothing new to add: ``item_text`` is returned unchanged.
    * ``item_text`` empty/falsy: the joined details alone are returned.
    * Otherwise: ``"<item_text without trailing ' .'>, with <details>"``.
    """
    haystack = str(item_text or "").lower()
    fresh = [detail for detail in _unique_texts(details) if detail.lower() not in haystack]
    if not fresh:
        return item_text
    joined = "; ".join(fresh)
    if item_text:
        return f"{str(item_text).rstrip(' .')}, with {joined}"
    return joined
def select_category_item_route_result(
    *,
    category_choice: str,
    subcategory_choice: str,
    seed_config: dict[str, int],
    seed: int,
    row_number: int,
    women_count: int,
    men_count: int,
    hardcore_position_config: dict[str, Any] | None = None,
    categories: list[dict[str, Any]] | None = None,
) -> CategoryItemRoute:
    """Resolve the category, subcategory and content item for one prompt row.

    Steps:

    1. Use ``categories`` or load the category library. Filter it through
       ``filter_hardcore_categories_for_position`` using the hardcore position
       config and the requested cast counts.
    2. Pick a category and subcategory with ``find_subcategory``. This may
       return different women/men counts; any change is recorded in
       ``count_adjustment``.
    3. For hardcore sexual categories, apply the position config to the
       subcategory.
    4. Draw an item from the subcategory's ``items``, falling back to its
       ``name`` when there is no ``items`` key. Compose the item's text, name,
       axis values and template metadata from it.
    5. Restore configured prompt axes, store them in the axis values and
       append them to the item text.
    6. For pose categories, sanitize environment anchors in the text and
       the axis values.

    Random draws use per-axis RNGs from ``seed_policy.axis_rng``, keyed by
    ``seed`` and ``row_number``.
    """
    source_categories = category_policy.load_category_library() if categories is None else categories
    parsed_hardcore_position_config = hardcore_position_config or {}
    requested_women_count = women_count
    requested_men_count = men_count
    category_rng = seed_policy.axis_rng(seed_config, "category", seed, row_number)
    subcategory_rng = seed_policy.axis_rng(seed_config, "subcategory", seed, row_number)
    filtered_categories = hardcore_position_policy.filter_hardcore_categories_for_position(
        source_categories,
        parsed_hardcore_position_config,
        women_count,
        men_count,
        category_policy.compatible_entry,
    )
    # find_subcategory may return different cast counts than requested.
    category, subcategory, women_count, men_count = category_policy.find_subcategory(
        filtered_categories,
        category_choice,
        subcategory_choice,
        category_rng,
        subcategory_rng,
        women_count,
        men_count,
    )
    count_adjustment = cast_count_adjustment(
        requested_women_count,
        requested_men_count,
        women_count,
        men_count,
    )
    if hardcore_position_policy.is_hardcore_sexual_category(category):
        subcategory = hardcore_position_policy.apply_hardcore_position_config_to_subcategory(
            subcategory,
            parsed_hardcore_position_config,
        )
    is_pose_category = is_pose_content_category(category, subcategory)
    content_axis = "pose" if is_pose_category else "content"
    content_rng = seed_policy.axis_rng(seed_config, content_axis, seed, row_number)
    # Only look up the name when there is no ``items`` key. A ``.get`` default
    # would be evaluated eagerly and raise KeyError for name-less subcategories
    # even when items exist.
    raw_items = subcategory["items"] if "items" in subcategory else [subcategory["name"]]
    item = row_item_policy.weighted_choice(content_rng, _list_from(raw_items))
    item_text, item_name, item_axis_values, item_template_metadata = row_item_policy.compose_item(
        content_rng,
        category,
        subcategory,
        item,
        women_count,
        men_count,
    )
    restored_axis_values = _restored_prompt_axis_values(
        content_rng,
        subcategory,
        item_axis_values,
        parsed_hardcore_position_config,
        women_count,
        men_count,
        source_categories,
    )
    if restored_axis_values:
        restored_details = _unique_texts(list(restored_axis_values.values()))
        item_axis_values = {
            **item_axis_values,
            **restored_axis_values,
            "restored_prompt_axes": list(restored_axis_values.keys()),
            "restored_prompt_details": restored_details,
        }
        item_text = _append_restored_prompt_details(item_text, restored_details)
    if is_pose_category:
        item_text = sanitize_hardcore_environment_anchors(item_text)
        item_axis_values = sanitize_hardcore_axis_values(item_axis_values)
    return CategoryItemRoute(
        category=category,
        subcategory=subcategory,
        women_count=women_count,
        men_count=men_count,
        count_adjustment=count_adjustment,
        content_axis=content_axis,
        item=item,
        item_text=item_text,
        item_name=item_name,
        item_axis_values=item_axis_values,
        item_template_metadata=item_template_metadata,
        formatter_hints=template_policy.formatter_hints(item_template_metadata),
        is_pose_category=is_pose_category,
    )
def select_category_item_route(
    *,
    category_choice: str,
    subcategory_choice: str,
    seed_config: dict[str, int],
    seed: int,
    row_number: int,
    women_count: int,
    men_count: int,
    hardcore_position_config: dict[str, Any] | None = None,
    categories: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Dict-returning wrapper around ``select_category_item_route_result``.

    Takes the same keyword arguments and returns the resulting
    ``CategoryItemRoute`` converted with ``as_dict()``.
    """
    route = select_category_item_route_result(
        category_choice=category_choice,
        subcategory_choice=subcategory_choice,
        seed_config=seed_config,
        seed=seed,
        row_number=row_number,
        women_count=women_count,
        men_count=men_count,
        hardcore_position_config=hardcore_position_config,
        categories=categories,
    )
    return route.as_dict()