# Resolves the position/action "route" for an item and exposes it either as a
# frozen dataclass (``*_result`` functions) or as a plain dict.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

try:
    from . import category_template_metadata as template_policy
    from . import hardcore_position_config as hardcore_position_policy
    from .hardcore_action_metadata import source_hardcore_action_family
except ImportError:  # Allows local smoke tests from the repository root.
    import category_template_metadata as template_policy
    import hardcore_position_config as hardcore_position_policy
    from hardcore_action_metadata import source_hardcore_action_family


# Dict form of an empty route: the same keys and values that
# ``empty_action_position_route()`` returns. This is a shared mutable object,
# so callers should copy it before modifying it.
EMPTY_ACTION_POSITION_ROUTE = {
    "position_family": "",
    "position_keys": [],
    "position_key": "",
    "action_family": "",
}


@dataclass(frozen=True)
class ActionPositionRoute:
    """Record of a resolved position/action route.

    ``frozen=True`` prevents attribute reassignment only; the list held in
    ``position_keys`` is still a regular mutable list. ``as_dict`` returns a
    copy of that list rather than the stored object.
    """

    # Position family name; "" when no route was resolved.
    position_family: str
    # All position keys for the route, in the order they were produced.
    position_keys: list[str]
    # First entry of ``position_keys``, or "" when that list is empty.
    position_key: str
    # Action family name; "" when no route was resolved.
    action_family: str

    def as_dict(self) -> dict[str, Any]:
        """Return the route as a plain dict with a fresh ``position_keys`` list."""
        keys_copy = list(self.position_keys)
        return {
            "position_family": self.position_family,
            "position_keys": keys_copy,
            "position_key": self.position_key,
            "action_family": self.action_family,
        }


def empty_action_position_route_result() -> ActionPositionRoute:
    """Build a route whose string fields are "" and whose key list is empty.

    A new list is created on every call, so returned routes never share
    their ``position_keys`` object.
    """
    return ActionPositionRoute("", [], "", "")


def empty_action_position_route() -> dict[str, Any]:
    """Return the dict form of an empty route (a new dict on every call)."""
    blank_route = empty_action_position_route_result()
    return blank_route.as_dict()


def resolve_action_position_route_result(
    *,
    is_pose_category: bool,
    subcategory: dict[str, Any],
    hardcore_position_config: dict[str, Any] | None,
    item_template_metadata: dict[str, Any] | None,
    item_text: Any,
    source_role_graph: Any,
    source_composition: Any,
    pose: Any,
    item_axis_values: dict[str, Any] | None = None,
) -> ActionPositionRoute:
    """Resolve the position family, position keys and action family for an item.

    All arguments are keyword-only.

    Args:
        is_pose_category: When false, an empty route is returned immediately
            and none of the policy helpers are called.
        subcategory: Passed to ``hardcore_source_position_family`` when the
            template metadata yields no position family.
        hardcore_position_config: Passed alongside ``subcategory`` to the
            same fallback helper.
        item_template_metadata: Template metadata for the item; ``None`` (or
            any falsy value) is treated as an empty dict.
        item_text: Forwarded to the position-key and action-family helpers.
        source_role_graph: Forwarded to the position-key and action-family
            helpers.
        source_composition: Forwarded to the position-key and action-family
            helpers.
        pose: Forwarded to ``hardcore_position_keys``.
        item_axis_values: Forwarded as ``axis_values`` to
            ``hardcore_position_keys`` and positionally to
            ``source_hardcore_action_family``.

    Returns:
        The resolved ``ActionPositionRoute``. ``position_key`` is the first
        merged position key, or "" when there are none.
    """
    if not is_pose_category:
        return empty_action_position_route_result()

    template_meta = item_template_metadata or {}

    # The template-declared family takes precedence; the source-derived
    # family is only computed when the template yields a falsy value.
    resolved_family = template_policy.template_position_family(template_meta)
    if not resolved_family:
        resolved_family = hardcore_position_policy.hardcore_source_position_family(
            subcategory,
            hardcore_position_config,
        )

    # Inferred keys are computed before the template keys are read, then both
    # go through merge_position_keys with the template keys as first argument.
    keys_from_source = hardcore_position_policy.hardcore_position_keys(
        item_text,
        source_role_graph,
        source_composition,
        pose,
        axis_values=item_axis_values,
    )
    keys_from_template = template_policy.template_position_keys(template_meta)
    merged_keys = template_policy.merge_position_keys(
        keys_from_template,
        keys_from_source,
    )

    # A template value of "default" (or any falsy value) means "no explicit
    # choice": derive the action family from the source data instead.
    template_action = template_policy.template_action_family(template_meta)
    if template_action == "default" or not template_action:
        resolved_action = source_hardcore_action_family(
            resolved_family,
            source_role_graph,
            item_text,
            source_composition,
            item_axis_values,
        )
    else:
        resolved_action = template_action

    if merged_keys:
        primary_key = merged_keys[0]
    else:
        primary_key = ""

    return ActionPositionRoute(
        position_family=resolved_family,
        position_keys=merged_keys,
        position_key=primary_key,
        action_family=resolved_action,
    )


def resolve_action_position_route(
    *,
    is_pose_category: bool,
    subcategory: dict[str, Any],
    hardcore_position_config: dict[str, Any] | None,
    item_template_metadata: dict[str, Any] | None,
    item_text: Any,
    source_role_graph: Any,
    source_composition: Any,
    pose: Any,
    item_axis_values: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Dict-returning variant of ``resolve_action_position_route_result``.

    Takes the same keyword-only arguments, forwards them unchanged, and
    returns ``ActionPositionRoute.as_dict()`` of the resolved route.
    """
    route = resolve_action_position_route_result(
        is_pose_category=is_pose_category,
        subcategory=subcategory,
        hardcore_position_config=hardcore_position_config,
        item_template_metadata=item_template_metadata,
        item_text=item_text,
        source_role_graph=source_role_graph,
        source_composition=source_composition,
        pose=pose,
        item_axis_values=item_axis_values,
    )
    return route.as_dict()