Preserve location route metadata

This commit is contained in:
2026-06-27 13:21:51 +02:00
parent 63e8489fb2
commit 75a71a2df6
9 changed files with 215 additions and 24 deletions
+40 -14
View File
@@ -23,6 +23,7 @@ except ImportError: # Allows local smoke tests from the repository root.
class PromptAxesRoute:
scene_slug: str
scene: str
scene_entry: dict[str, Any]
pose: str
expression: str
shared_expression: str
@@ -30,11 +31,13 @@ class PromptAxesRoute:
character_expression_text: str
source_composition: str
composition: str
composition_entry: dict[str, Any]
def as_dict(self) -> dict[str, Any]:
return {
"scene_slug": self.scene_slug,
"scene": self.scene,
"scene_entry": dict(self.scene_entry),
"pose": self.pose,
"expression": self.expression,
"shared_expression": self.shared_expression,
@@ -42,9 +45,29 @@ class PromptAxesRoute:
"character_expression_text": self.character_expression_text,
"source_composition": self.source_composition,
"composition": self.composition,
"composition_entry": dict(self.composition_entry),
}
def _metadata_entry(value: Any, *, slug: str = "", text: str = "") -> dict[str, Any]:
if isinstance(value, dict):
entry = dict(value)
elif isinstance(value, (list, tuple)) and len(value) == 2:
entry = {"slug": str(value[0]), "prompt": str(value[1])}
else:
entry = {"prompt": str(value or "")}
if slug:
entry["slug"] = slug
if text:
if "prompt" in entry:
entry["prompt"] = text
elif "text" in entry:
entry["text"] = text
else:
entry["prompt"] = text
return entry
def resolve_prompt_axes_result(
*,
category: dict[str, Any],
@@ -75,14 +98,14 @@ def resolve_prompt_axes_result(
character_slot_map = character_slot_map or {}
pov_character_labels = pov_character_labels or []
scene_slug, scene = row_item_policy.choose_pair(
scene_rng,
category_policy.compatible_entries(
row_pool_policy.scene_pool(category, subcategory, item, subject_type, location_config),
women_count,
men_count,
),
scene_entries = category_policy.compatible_entries(
row_pool_policy.scene_pool(category, subcategory, item, subject_type, location_config),
women_count,
men_count,
)
scene_choice = row_item_policy.weighted_choice(scene_rng, scene_entries)
scene_slug, scene = row_item_policy.pair_from(scene_choice)
scene_entry = _metadata_entry(scene_choice, slug=scene_slug, text=scene)
pose = str(
category_policy.merged_field(category, subcategory, item, "pose", "")
or context.get("fallback_pose")
@@ -137,21 +160,23 @@ def resolve_prompt_axes_result(
if character_expression_text:
expression = character_expression_text
source_composition = row_item_policy.choose_text(
composition_rng,
category_policy.compatible_entries(
row_pool_policy.composition_pool(category, subcategory, item, subject_type, composition_config),
women_count,
men_count,
),
composition_entries = category_policy.compatible_entries(
row_pool_policy.composition_pool(category, subcategory, item, subject_type, composition_config),
women_count,
men_count,
)
composition_choice = row_item_policy.weighted_choice(composition_rng, composition_entries)
source_composition = row_item_policy.item_text(composition_choice)
composition_entry = _metadata_entry(composition_choice, text=source_composition)
if is_pose_category:
source_composition = sanitize_hardcore_environment_anchors(source_composition)
composition_entry["prompt"] = source_composition
composition = pov_policy.pov_composition_prompt(source_composition, pov_character_labels)
return PromptAxesRoute(
scene_slug=scene_slug,
scene=scene,
scene_entry=scene_entry,
pose=pose,
expression=expression,
shared_expression=shared_expression,
@@ -159,6 +184,7 @@ def resolve_prompt_axes_result(
character_expression_text=character_expression_text,
source_composition=source_composition,
composition=composition,
composition_entry=composition_entry,
)