43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
import random
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
try:
|
|
from . import hardcore_role_graphs
|
|
from . import hardcore_text_cleanup
|
|
from . import pov_policy
|
|
except ImportError: # Allows local smoke tests from the repository root.
|
|
import hardcore_role_graphs
|
|
import hardcore_text_cleanup
|
|
import pov_policy
|
|
|
|
|
|
@dataclass(frozen=True)
class RoleGraphRoute:
    """Immutable result of ``resolve_role_graph_route``.

    Holds two views of the same role graph: the text produced by the
    hardcore role-graph builder, and the POV-specific prompt derived from it.

    Attributes:
        source_role_graph: The built role graph. For pose categories this is
            the text after ``sanitize_hardcore_environment_anchors`` has been
            applied.
        role_graph: The output of ``pov_policy.pov_role_graph_prompt`` applied
            to ``source_role_graph``.
    """

    # Role graph before POV rewriting (possibly sanitized; see class docstring).
    source_role_graph: str
    # Role graph after POV rewriting.
    role_graph: str
|
|
|
|
|
|
def resolve_role_graph_route(
    *,
    rng: random.Random,
    subcategory: dict[str, Any],
    context: dict[str, Any],
    item_axis_values: dict[str, Any],
    pov_character_labels: list[str],
    is_pose_category: bool,
) -> RoleGraphRoute:
    """Build a hardcore role graph and derive its POV-specific prompt.

    Args:
        rng: Random source forwarded to
            ``hardcore_role_graphs.build_hardcore_role_graph``.
        subcategory: Forwarded unchanged to the role-graph builder.
        context: Forwarded unchanged to the role-graph builder.
        item_axis_values: Forwarded unchanged to the role-graph builder.
        pov_character_labels: Forwarded both to the role-graph builder and to
            ``pov_policy.pov_role_graph_prompt``.
        is_pose_category: When truthy, the built graph is passed through
            ``hardcore_text_cleanup.sanitize_hardcore_environment_anchors``
            before the POV prompt is derived.

    Returns:
        A ``RoleGraphRoute`` whose ``source_role_graph`` is the built graph
        (sanitized when ``is_pose_category`` is truthy) and whose
        ``role_graph`` is the POV prompt computed from that same text.
    """
    built_graph = hardcore_role_graphs.build_hardcore_role_graph(
        rng, subcategory, context, item_axis_values, pov_character_labels
    )
    # The sanitized text (not the raw build) is what gets stored as
    # ``source_role_graph`` and what the POV step consumes.
    effective_graph = (
        hardcore_text_cleanup.sanitize_hardcore_environment_anchors(built_graph)
        if is_pose_category
        else built_graph
    )
    return RoleGraphRoute(
        source_role_graph=effective_graph,
        role_graph=pov_policy.pov_role_graph_prompt(effective_graph, pov_character_labels),
    )
|