Extract role graph route policy
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import hardcore_role_graphs
|
||||
from . import hardcore_text_cleanup
|
||||
from . import pov_policy
|
||||
except ImportError: # Allows local smoke tests from the repository root.
|
||||
import hardcore_role_graphs
|
||||
import hardcore_text_cleanup
|
||||
import pov_policy
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RoleGraphRoute:
|
||||
source_role_graph: str
|
||||
role_graph: str
|
||||
|
||||
|
||||
def resolve_role_graph_route(
|
||||
*,
|
||||
rng: random.Random,
|
||||
subcategory: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
item_axis_values: dict[str, Any],
|
||||
pov_character_labels: list[str],
|
||||
is_pose_category: bool,
|
||||
) -> RoleGraphRoute:
|
||||
source_role_graph = hardcore_role_graphs.build_hardcore_role_graph(
|
||||
rng,
|
||||
subcategory,
|
||||
context,
|
||||
item_axis_values,
|
||||
pov_character_labels,
|
||||
)
|
||||
if is_pose_category:
|
||||
source_role_graph = hardcore_text_cleanup.sanitize_hardcore_environment_anchors(source_role_graph)
|
||||
role_graph = pov_policy.pov_role_graph_prompt(source_role_graph, pov_character_labels)
|
||||
return RoleGraphRoute(source_role_graph=source_role_graph, role_graph=role_graph)
|
||||
Reference in New Issue
Block a user