Extract role graph route policy
This commit is contained in:
+31
-6
@@ -44,6 +44,7 @@ try:
|
||||
from . import row_prompt_axes as row_prompt_axes_policy
|
||||
from . import row_pools as row_pool_policy
|
||||
from . import row_rendering as row_rendering_policy
|
||||
from . import row_role_graph as row_role_graph_policy
|
||||
from . import row_route_metadata as row_route_policy
|
||||
from . import row_subject_route as row_subject_route_policy
|
||||
from . import seed_config as seed_policy
|
||||
@@ -52,7 +53,6 @@ try:
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors,
|
||||
)
|
||||
from .hardcore_role_graphs import build_hardcore_role_graph
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
from category_library import (
|
||||
compatible_entries as _compatible_entries,
|
||||
@@ -93,6 +93,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import row_prompt_axes as row_prompt_axes_policy
|
||||
import row_pools as row_pool_policy
|
||||
import row_rendering as row_rendering_policy
|
||||
import row_role_graph as row_role_graph_policy
|
||||
import row_route_metadata as row_route_policy
|
||||
import row_subject_route as row_subject_route_policy
|
||||
import seed_config as seed_policy
|
||||
@@ -101,7 +102,6 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
sanitize_hardcore_axis_values as _sanitize_hardcore_axis_values,
|
||||
sanitize_hardcore_environment_anchors as _sanitize_hardcore_environment_anchors,
|
||||
)
|
||||
from hardcore_role_graphs import build_hardcore_role_graph
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parent
|
||||
@@ -2121,6 +2121,25 @@ def _prompt_axes_route(
|
||||
)
|
||||
|
||||
|
||||
def _role_graph_route(
|
||||
*,
|
||||
rng: random.Random,
|
||||
subcategory: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
item_axis_values: dict[str, Any],
|
||||
pov_character_labels: list[str],
|
||||
is_pose_category: bool,
|
||||
) -> row_role_graph_policy.RoleGraphRoute:
|
||||
return row_role_graph_policy.resolve_role_graph_route(
|
||||
rng=rng,
|
||||
subcategory=subcategory,
|
||||
context=context,
|
||||
item_axis_values=item_axis_values,
|
||||
pov_character_labels=pov_character_labels,
|
||||
is_pose_category=is_pose_category,
|
||||
)
|
||||
|
||||
|
||||
def _assemble_custom_row(request: row_assembly_policy.CustomRowAssemblyRequest) -> dict[str, Any]:
|
||||
return row_assembly_policy.assemble_custom_row(request)
|
||||
|
||||
@@ -2207,10 +2226,16 @@ def _build_custom_row(
|
||||
pov_character_labels = list(subject_route.get("pov_character_labels") or [])
|
||||
cast_descriptors = list(subject_route.get("cast_descriptors") or [])
|
||||
cast_descriptor_text = str(subject_route.get("cast_descriptor_text") or "")
|
||||
source_role_graph = build_hardcore_role_graph(role_rng, subcategory, context, item_axis_values, pov_character_labels)
|
||||
if is_pose_category:
|
||||
source_role_graph = _sanitize_hardcore_environment_anchors(source_role_graph)
|
||||
role_graph = _pov_role_graph_prompt(source_role_graph, pov_character_labels)
|
||||
role_graph_route = _role_graph_route(
|
||||
rng=role_rng,
|
||||
subcategory=subcategory,
|
||||
context=context,
|
||||
item_axis_values=item_axis_values,
|
||||
pov_character_labels=pov_character_labels,
|
||||
is_pose_category=is_pose_category,
|
||||
)
|
||||
source_role_graph = role_graph_route.source_role_graph
|
||||
role_graph = role_graph_route.role_graph
|
||||
expression_route = _resolve_expression_route(
|
||||
expression_enabled=expression_enabled,
|
||||
expression_intensity=expression_intensity,
|
||||
|
||||
Reference in New Issue
Block a user