from __future__ import annotations import random from dataclasses import dataclass from typing import Any try: from . import hardcore_role_graphs from . import hardcore_text_cleanup from . import pov_policy except ImportError: # Allows local smoke tests from the repository root. import hardcore_role_graphs import hardcore_text_cleanup import pov_policy @dataclass(frozen=True) class RoleGraphRoute: source_role_graph: str role_graph: str def resolve_role_graph_route( *, rng: random.Random, subcategory: dict[str, Any], context: dict[str, Any], item_axis_values: dict[str, Any], pov_character_labels: list[str], is_pose_category: bool, ) -> RoleGraphRoute: source_role_graph = hardcore_role_graphs.build_hardcore_role_graph( rng, subcategory, context, item_axis_values, pov_character_labels, ) if is_pose_category: source_role_graph = hardcore_text_cleanup.sanitize_hardcore_environment_anchors(source_role_graph) role_graph = pov_policy.pov_role_graph_prompt(source_role_graph, pov_character_labels) return RoleGraphRoute(source_role_graph=source_role_graph, role_graph=role_graph)