242 lines
8.7 KiB
Python
242 lines
8.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import category_library as category_policy
|
|
from . import row_expression as row_expression_policy
|
|
from . import row_item as row_item_policy
|
|
from . import row_pools as row_pool_policy
|
|
from . import pov_policy
|
|
from .hardcore_text_cleanup import sanitize_hardcore_environment_anchors
|
|
except ImportError: # Allows local smoke tests from the repository root.
|
|
import category_library as category_policy
|
|
import row_expression as row_expression_policy
|
|
import row_item as row_item_policy
|
|
import row_pools as row_pool_policy
|
|
import pov_policy
|
|
from hardcore_text_cleanup import sanitize_hardcore_environment_anchors
|
|
|
|
|
|
@dataclass(frozen=True)
class PromptAxesRoute:
    """Resolved scene, pose, expression and composition values for one prompt.

    Instances are built by ``resolve_prompt_axes_result``. The dataclass is
    frozen (fields cannot be reassigned), but the dict/list fields are still
    mutable objects, so ``as_dict`` hands out shallow copies of them.
    """

    # Scene slug and text as returned by ``row_item_policy.pair_from``, plus
    # the normalized metadata dict for the chosen scene entry.
    scene_slug: str
    scene: str
    scene_entry: dict[str, Any]
    # Final pose text.
    pose: str
    # Expression placed in the prompt; may be the per-character text.
    expression: str
    # Shared expression before any per-character override.
    shared_expression: str
    # Per-character expression list and its "; "-joined text.
    character_expressions: list[str]
    character_expression_text: str
    # Composition text before and after POV rewriting, plus its metadata dict.
    source_composition: str
    composition: str
    composition_entry: dict[str, Any]

    def as_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of every field, in declaration order.

        ``scene_entry``, ``character_expressions`` and ``composition_entry``
        are shallow-copied, so mutating the result leaves this instance intact.
        """
        scene_entry_copy = dict(self.scene_entry)
        expressions_copy = list(self.character_expressions)
        composition_entry_copy = dict(self.composition_entry)
        return dict(
            scene_slug=self.scene_slug,
            scene=self.scene,
            scene_entry=scene_entry_copy,
            pose=self.pose,
            expression=self.expression,
            shared_expression=self.shared_expression,
            character_expressions=expressions_copy,
            character_expression_text=self.character_expression_text,
            source_composition=self.source_composition,
            composition=self.composition,
            composition_entry=composition_entry_copy,
        )
|
|
|
|
|
|
def _metadata_entry(value: Any, *, slug: str = "", text: str = "") -> dict[str, Any]:
|
|
if isinstance(value, dict):
|
|
entry = dict(value)
|
|
elif isinstance(value, (list, tuple)) and len(value) == 2:
|
|
entry = {"slug": str(value[0]), "prompt": str(value[1])}
|
|
else:
|
|
entry = {"prompt": str(value or "")}
|
|
if slug:
|
|
entry["slug"] = slug
|
|
if text:
|
|
if "prompt" in entry:
|
|
entry["prompt"] = text
|
|
elif "text" in entry:
|
|
entry["text"] = text
|
|
else:
|
|
entry["prompt"] = text
|
|
return entry
|
|
|
|
|
|
def resolve_prompt_axes_result(
    *,
    category: dict[str, Any],
    subcategory: dict[str, Any],
    item: Any,
    subject_type: str,
    context: dict[str, Any],
    poses: str,
    women_count: int,
    men_count: int,
    scene_rng: Any,
    pose_rng: Any,
    expression_rng: Any,
    composition_rng: Any,
    expression_disabled: bool,
    expression_intensity: float,
    character_slots: list[dict[str, Any]] | None = None,
    character_slot_map: dict[str, dict[str, Any]] | None = None,
    expression_phase: str = "",
    source_role_graph: Any = "",
    item_axis_values: dict[str, Any] | None = None,
    is_pose_category: bool = False,
    pov_character_labels: list[str] | None = None,
    location_config: dict[str, Any] | None = None,
    composition_config: dict[str, Any] | None = None,
) -> PromptAxesRoute:
    """Resolve the scene, pose, expression and composition axes for one prompt.

    Each axis draws from its own candidate pool. The pool is filtered by
    ``category_policy.compatible_entries`` against ``women_count`` /
    ``men_count`` and sampled with that axis's RNG (``scene_rng``,
    ``pose_rng``, ``expression_rng``, ``composition_rng``).

    Args:
        category, subcategory, item: Forwarded to the ``row_pool_policy`` pool
            builders and to ``category_policy.merged_field``.
        subject_type: Forwarded to the pool builders. For ``"couple"`` and
            ``"group"``, a second distinct shared expression is appended. For
            ``"configured_cast"``, per-character expressions are resolved when
            ``character_slots`` is non-empty.
        context: Only ``context.get("fallback_pose")`` is read. It is used when
            the row does not define its own ``"pose"`` field.
        poses: Forwarded to ``row_pool_policy.pose_pool``.
        expression_disabled: When true, no expression is drawn and all
            expression fields are empty.
        expression_intensity: Forwarded to the expression filtering helpers.
        character_slots: Only checked for emptiness here.
        character_slot_map, expression_phase: Forwarded to
            ``row_expression_policy.character_expression_entries``.
        source_role_graph, item_axis_values: Forwarded to
            ``row_expression_policy.sanitize_character_expression_text_for_action``.
        is_pose_category: When true, the pose and composition text are passed
            through ``sanitize_hardcore_environment_anchors``.
        pov_character_labels: Forwarded to ``pov_policy.pov_composition_prompt``.
        location_config, composition_config: Forwarded to the scene and
            composition pool builders respectively.

    Returns:
        A ``PromptAxesRoute`` with the following expression and composition
        fields:

        * ``expression``: the per-character text when one was produced,
          otherwise the shared expression.
        * ``shared_expression``: always the shared expression.
        * ``source_composition``: the (possibly sanitized) composition before
          POV rewriting.
        * ``composition``: the composition after POV rewriting.
    """
    character_slots = character_slots or []
    character_slot_map = character_slot_map or {}
    pov_character_labels = pov_character_labels or []

    # --- Scene ---------------------------------------------------------------
    scene_entries = category_policy.compatible_entries(
        row_pool_policy.scene_pool(category, subcategory, item, subject_type, location_config),
        women_count,
        men_count,
    )
    scene_choice = row_item_policy.weighted_choice(scene_rng, scene_entries)
    scene_slug, scene = row_item_policy.pair_from(scene_choice)
    scene_entry = _metadata_entry(scene_choice, slug=scene_slug, text=scene)

    # --- Pose ----------------------------------------------------------------
    # Precedence: the row's own "pose" field, then the caller's fallback, then a
    # random draw. The `or` chain is lazy, so the pose pool is only built when
    # needed. The trailing `or ""` stops an all-empty/None result from being
    # rendered as the literal string "None".
    pose = str(
        category_policy.merged_field(category, subcategory, item, "pose", "")
        or context.get("fallback_pose")
        or row_item_policy.choose_text(
            pose_rng,
            category_policy.compatible_entries(
                row_pool_policy.pose_pool(category, subcategory, item, subject_type, poses),
                women_count,
                men_count,
            ),
        )
        or ""
    )
    if is_pose_category:
        pose = sanitize_hardcore_environment_anchors(pose)

    # --- Shared expression ---------------------------------------------------
    expression_pool = row_pool_policy.expression_pool(category, subcategory, item)
    if expression_disabled:
        expression = ""
    else:
        expression_entries = category_policy.compatible_entries(
            row_expression_policy.expression_entries_for_intensity(expression_pool, expression_intensity),
            women_count,
            men_count,
        )
        expression = row_item_policy.choose_text(expression_rng, expression_entries)
        # Multi-person subjects get a second, distinct expression joined with
        # "; ", unless the drawn text already contains ";". The leading
        # `expression and` guard prevents a dangling "; <secondary>" (or a
        # TypeError on None) when the primary draw came back empty.
        if expression and subject_type in ("couple", "group") and ";" not in expression:
            secondary_expression = row_item_policy.choose_distinct_text(expression_rng, expression_entries, expression)
            if secondary_expression:
                expression = f"{expression}; {secondary_expression}"

    # --- Per-character expressions (configured casts only) -------------------
    shared_expression = expression
    character_expressions: list[str] = []
    character_expression_text = ""
    if not expression_disabled and subject_type == "configured_cast" and character_slots:
        character_expressions = row_expression_policy.character_expression_entries(
            expression_rng,
            expression_pool,
            expression_intensity,
            character_slot_map,
            women_count,
            men_count,
            expression_phase,
        )
        # Sanitize the joined text as a whole, then re-split it so the list and
        # the text stay consistent with each other.
        character_expression_text = row_expression_policy.sanitize_character_expression_text_for_action(
            "; ".join(character_expressions),
            source_role_graph,
            item,
            item_axis_values or {},
        )
        character_expressions = [part.strip() for part in character_expression_text.split(";") if part.strip()]
        if character_expression_text:
            # Per-character text supersedes the shared expression in the prompt.
            expression = character_expression_text

    # --- Composition ---------------------------------------------------------
    composition_entries = category_policy.compatible_entries(
        row_pool_policy.composition_pool(category, subcategory, item, subject_type, composition_config),
        women_count,
        men_count,
    )
    composition_choice = row_item_policy.weighted_choice(composition_rng, composition_entries)
    source_composition = row_item_policy.item_text(composition_choice)
    composition_entry = _metadata_entry(composition_choice, text=source_composition)
    if is_pose_category:
        source_composition = sanitize_hardcore_environment_anchors(source_composition)
        # NOTE(review): this always writes "prompt". For dict entries that only
        # carry a "text" key, _metadata_entry writes to "text" instead, so such
        # entries end up with an unsanitized "text" and a sanitized "prompt".
        # Confirm which key downstream consumers read before unifying.
        composition_entry["prompt"] = source_composition
    composition = pov_policy.pov_composition_prompt(source_composition, pov_character_labels)

    return PromptAxesRoute(
        scene_slug=scene_slug,
        scene=scene,
        scene_entry=scene_entry,
        pose=pose,
        expression=expression,
        shared_expression=shared_expression,
        character_expressions=character_expressions,
        character_expression_text=character_expression_text,
        source_composition=source_composition,
        composition=composition,
        composition_entry=composition_entry,
    )
|
|
|
|
|
|
def resolve_prompt_axes(
    *,
    category: dict[str, Any],
    subcategory: dict[str, Any],
    item: Any,
    subject_type: str,
    context: dict[str, Any],
    poses: str,
    women_count: int,
    men_count: int,
    scene_rng: Any,
    pose_rng: Any,
    expression_rng: Any,
    composition_rng: Any,
    expression_disabled: bool,
    expression_intensity: float,
    character_slots: list[dict[str, Any]] | None = None,
    character_slot_map: dict[str, dict[str, Any]] | None = None,
    expression_phase: str = "",
    source_role_graph: Any = "",
    item_axis_values: dict[str, Any] | None = None,
    is_pose_category: bool = False,
    pov_character_labels: list[str] | None = None,
    location_config: dict[str, Any] | None = None,
    composition_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Dict-returning wrapper around ``resolve_prompt_axes_result``.

    Takes exactly the same keyword-only arguments and forwards all of them
    unchanged. Returns ``PromptAxesRoute.as_dict()`` of the resolved route.
    """
    route = resolve_prompt_axes_result(
        # Row inputs and cast counts.
        category=category,
        subcategory=subcategory,
        item=item,
        subject_type=subject_type,
        women_count=women_count,
        men_count=men_count,
        # Pool sources / configuration.
        context=context,
        poses=poses,
        location_config=location_config,
        composition_config=composition_config,
        # One RNG per axis.
        scene_rng=scene_rng,
        pose_rng=pose_rng,
        expression_rng=expression_rng,
        composition_rng=composition_rng,
        # Expression controls.
        expression_disabled=expression_disabled,
        expression_intensity=expression_intensity,
        expression_phase=expression_phase,
        character_slots=character_slots,
        character_slot_map=character_slot_map,
        source_role_graph=source_role_graph,
        item_axis_values=item_axis_values,
        # Pose-category sanitizing and POV rewriting.
        is_pose_category=is_pose_category,
        pov_character_labels=pov_character_labels,
    )
    return route.as_dict()
|