From 00139d0cd905016e5e7a2f6d51d1fd0a8bc0fef1 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 10:32:38 +0200 Subject: [PATCH] Add typed prompt axes route --- docs/prompt-architecture-improvement-plan.md | 7 +- docs/prompt-pool-routing-map.md | 2 +- prompt_builder.py | 73 +++++++++++-- row_prompt_axes.py | 106 ++++++++++++++++--- tools/prompt_smoke.py | 17 +++ 5 files changed, 178 insertions(+), 27 deletions(-) diff --git a/docs/prompt-architecture-improvement-plan.md b/docs/prompt-architecture-improvement-plan.md index 951c7e2..835b268 100644 --- a/docs/prompt-architecture-improvement-plan.md +++ b/docs/prompt-architecture-improvement-plan.md @@ -193,10 +193,11 @@ Already isolated: runtime location/composition pool overrides, and generator fallback pool selection live in `row_pools.py`; `prompt_builder.py` keeps public delegate wrappers. -- row scene/pose/expression/composition axis selection, compatible-entry +- row scene/pose/expression/composition axis selection lives in + `row_prompt_axes.py` behind `PromptAxesRoute`, covering compatible-entry filtering, expression-disabled handling, per-character expression promotion, - POV composition adaptation, and pose-category environment sanitizing live in - `row_prompt_axes.py`; `prompt_builder.py` keeps a public delegate wrapper. + legacy dict compatibility, POV composition adaptation, and pose-category + environment sanitizing; `prompt_builder.py` keeps public delegate wrappers. - row prompt/caption text-field resolution, prompt/caption template selection, safe formatting, configured-cast descriptor insertion, and POV directive insertion live in `row_rendering.py`; `prompt_builder.py` keeps public diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index 9c6952c..376373a 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -93,7 +93,7 @@ Core helper ownership: | `row_location.py` | Built-in row location/composition config application, deterministic scene/composition choice, source metadata, and legacy prompt/caption rewrites. | | `row_expression.py` | Row expression cleanup, expression route resolution, expression intensity weighting, character-slot/cast expression override resolution, per-character expression selection, and action-aware character-expression sanitizing. | | `row_pools.py` | Row scene/expression/pose/composition pool routing, category inheritance handling, runtime location/composition pool overrides, and generator fallback pools. | -| `row_prompt_axes.py` | Row scene/pose/expression/composition axis selection, compatible-entry filtering, expression-disabled handling, per-character expression promotion, POV composition adaptation, and pose-category environment sanitizing. | +| `row_prompt_axes.py` | Row scene/pose/expression/composition axis selection behind `PromptAxesRoute`, compatible-entry filtering, expression-disabled handling, per-character expression promotion, legacy dict compatibility, POV composition adaptation, and pose-category environment sanitizing. | | `hardcore_position_config.py` | Hardcore position/action-filter choices, selected-position normalization, config JSON builders/parsers, focus-policy toggles, subcategory allow-list policy, position-key detection, and category/template/axis filtering. | | `pair_options.py` | Insta/OF option schema/defaults, softcore category/outfit/pose pools, partner outfit pools, clothing-continuity labels, negatives, hardcore cast count policy, and hardcore detail-density directives. | | `pair_rows.py` | Insta/OF soft/hard row creation, softcore expression override resolution, Woman A slot context application, soft outfit/pose overrides, and POV row fields. | diff --git a/prompt_builder.py b/prompt_builder.py index 388aee6..fc92f94 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -2146,6 +2146,59 @@ def _prompt_axes_route( ) +def _prompt_axes_route_result( + *, + category: dict[str, Any], + subcategory: dict[str, Any], + item: Any, + subject_type: str, + context: dict[str, Any], + poses: str, + women_count: int, + men_count: int, + scene_rng: random.Random, + pose_rng: random.Random, + expression_rng: random.Random, + composition_rng: random.Random, + expression_disabled: bool, + expression_intensity: float, + character_slots: list[dict[str, Any]] | None = None, + character_slot_map: dict[str, dict[str, Any]] | None = None, + expression_phase: str = "", + source_role_graph: Any = "", + item_axis_values: dict[str, Any] | None = None, + is_pose_category: bool = False, + pov_character_labels: list[str] | None = None, + location_config: dict[str, Any] | None = None, + composition_config: dict[str, Any] | None = None, +) -> row_prompt_axes_policy.PromptAxesRoute: + return row_prompt_axes_policy.resolve_prompt_axes_result( + category=category, + subcategory=subcategory, + item=item, + subject_type=subject_type, + context=context, + poses=poses, + women_count=women_count, + men_count=men_count, + scene_rng=scene_rng, + pose_rng=pose_rng, + expression_rng=expression_rng, + composition_rng=composition_rng, + expression_disabled=expression_disabled, + expression_intensity=expression_intensity, + character_slots=character_slots, + character_slot_map=character_slot_map, + expression_phase=expression_phase, + source_role_graph=source_role_graph, + item_axis_values=item_axis_values, + is_pose_category=is_pose_category, + pov_character_labels=pov_character_labels, + location_config=location_config, + composition_config=composition_config, + ) + + def _role_graph_route( *, rng: random.Random, @@ -2277,7 +2330,7 @@ def _build_custom_row( expression_intensity = expression_route.expression_intensity expression_intensity_source = expression_route.expression_intensity_source - prompt_axes = _prompt_axes_route( + prompt_axes = _prompt_axes_route_result( category=category, subcategory=subcategory, item=item, @@ -2302,15 +2355,15 @@ def _build_custom_row( location_config=parsed_location_config, composition_config=parsed_composition_config, ) - scene_slug = str(prompt_axes.get("scene_slug") or "") - scene = str(prompt_axes.get("scene") or "") - pose = str(prompt_axes.get("pose") or "") - expression = str(prompt_axes.get("expression") or "") - shared_expression = str(prompt_axes.get("shared_expression") or "") - character_expressions = list(prompt_axes.get("character_expressions") or []) - character_expression_text = str(prompt_axes.get("character_expression_text") or "") - source_composition = str(prompt_axes.get("source_composition") or "") - composition = str(prompt_axes.get("composition") or "") + scene_slug = prompt_axes.scene_slug + scene = prompt_axes.scene + pose = prompt_axes.pose + expression = prompt_axes.expression + shared_expression = prompt_axes.shared_expression + character_expressions = list(prompt_axes.character_expressions) + character_expression_text = prompt_axes.character_expression_text + source_composition = prompt_axes.source_composition + composition = prompt_axes.composition action_route = _action_position_route( is_pose_category=is_pose_category, subcategory=subcategory, diff --git a/row_prompt_axes.py b/row_prompt_axes.py index 769c22c..e6964be 100644 --- a/row_prompt_axes.py +++ b/row_prompt_axes.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Any try: @@ -18,7 +19,33 @@ except ImportError: # Allows local smoke tests from the repository root. from hardcore_text_cleanup import sanitize_hardcore_environment_anchors -def resolve_prompt_axes( +@dataclass(frozen=True) +class PromptAxesRoute: + scene_slug: str + scene: str + pose: str + expression: str + shared_expression: str + character_expressions: list[str] + character_expression_text: str + source_composition: str + composition: str + + def as_dict(self) -> dict[str, Any]: + return { + "scene_slug": self.scene_slug, + "scene": self.scene, + "pose": self.pose, + "expression": self.expression, + "shared_expression": self.shared_expression, + "character_expressions": list(self.character_expressions), + "character_expression_text": self.character_expression_text, + "source_composition": self.source_composition, + "composition": self.composition, + } + + +def resolve_prompt_axes_result( *, category: dict[str, Any], subcategory: dict[str, Any], @@ -43,7 +70,7 @@ def resolve_prompt_axes( pov_character_labels: list[str] | None = None, location_config: dict[str, Any] | None = None, composition_config: dict[str, Any] | None = None, -) -> dict[str, Any]: +) -> PromptAxesRoute: character_slots = character_slots or [] character_slot_map = character_slot_map or {} pov_character_labels = pov_character_labels or [] @@ -122,14 +149,67 @@ def resolve_prompt_axes( source_composition = sanitize_hardcore_environment_anchors(source_composition) composition = pov_policy.pov_composition_prompt(source_composition, pov_character_labels) - return { - "scene_slug": scene_slug, - "scene": scene, - "pose": pose, - "expression": expression, - "shared_expression": shared_expression, - "character_expressions": character_expressions, - "character_expression_text": character_expression_text, - "source_composition": source_composition, - "composition": composition, - } + return PromptAxesRoute( + scene_slug=scene_slug, + scene=scene, + pose=pose, + expression=expression, + shared_expression=shared_expression, + character_expressions=character_expressions, + character_expression_text=character_expression_text, + source_composition=source_composition, + composition=composition, + ) + + +def resolve_prompt_axes( + *, + category: dict[str, Any], + subcategory: dict[str, Any], + item: Any, + subject_type: str, + context: dict[str, Any], + poses: str, + women_count: int, + men_count: int, + scene_rng: Any, + pose_rng: Any, + expression_rng: Any, + composition_rng: Any, + expression_disabled: bool, + expression_intensity: float, + character_slots: list[dict[str, Any]] | None = None, + character_slot_map: dict[str, dict[str, Any]] | None = None, + expression_phase: str = "", + source_role_graph: Any = "", + item_axis_values: dict[str, Any] | None = None, + is_pose_category: bool = False, + pov_character_labels: list[str] | None = None, + location_config: dict[str, Any] | None = None, + composition_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + return resolve_prompt_axes_result( + category=category, + subcategory=subcategory, + item=item, + subject_type=subject_type, + context=context, + poses=poses, + women_count=women_count, + men_count=men_count, + scene_rng=scene_rng, + pose_rng=pose_rng, + expression_rng=expression_rng, + composition_rng=composition_rng, + expression_disabled=expression_disabled, + expression_intensity=expression_intensity, + character_slots=character_slots, + character_slot_map=character_slot_map, + expression_phase=expression_phase, + source_role_graph=source_role_graph, + item_axis_values=item_axis_values, + is_pose_category=is_pose_category, + pov_character_labels=pov_character_labels, + location_config=location_config, + composition_config=composition_config, + ).as_dict() diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index be6eacf..4427fb6 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -868,6 +868,13 @@ def smoke_row_prompt_axes_policy() -> None: expression_rng=random.Random(3), composition_rng=random.Random(4), ) + route_result = row_prompt_axes.resolve_prompt_axes_result( + **base_kwargs, + scene_rng=random.Random(1), + pose_rng=random.Random(2), + expression_rng=random.Random(3), + composition_rng=random.Random(4), + ) delegated = pb._prompt_axes_route( **base_kwargs, scene_rng=random.Random(1), @@ -875,7 +882,17 @@ def smoke_row_prompt_axes_policy() -> None: expression_rng=random.Random(3), composition_rng=random.Random(4), ) + typed_delegated = pb._prompt_axes_route_result( + **base_kwargs, + scene_rng=random.Random(1), + pose_rng=random.Random(2), + expression_rng=random.Random(3), + composition_rng=random.Random(4), + ) _expect(delegated == route, "Prompt builder prompt-axes route should delegate to row_prompt_axes") + _expect(route_result.as_dict() == route, "Typed prompt axes route should match legacy dict route") + _expect(typed_delegated == route_result, "Prompt builder typed prompt-axes route should delegate") + _expect(route_result.scene_slug == "studio", "Typed prompt axes route lost selected scene slug") _expect(route["scene_slug"] == "studio", "Prompt axes route lost selected scene slug") _expect(route["scene"] == "quiet studio with repeatable anchors", "Prompt axes route lost selected scene text") _expect(route["pose"] == "standing fallback pose", "Prompt axes route lost selected fallback pose")