Preserve location route metadata
This commit is contained in:
+40
-14
@@ -23,6 +23,7 @@ except ImportError: # Allows local smoke tests from the repository root.
|
||||
class PromptAxesRoute:
|
||||
scene_slug: str
|
||||
scene: str
|
||||
scene_entry: dict[str, Any]
|
||||
pose: str
|
||||
expression: str
|
||||
shared_expression: str
|
||||
@@ -30,11 +31,13 @@ class PromptAxesRoute:
|
||||
character_expression_text: str
|
||||
source_composition: str
|
||||
composition: str
|
||||
composition_entry: dict[str, Any]
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"scene_slug": self.scene_slug,
|
||||
"scene": self.scene,
|
||||
"scene_entry": dict(self.scene_entry),
|
||||
"pose": self.pose,
|
||||
"expression": self.expression,
|
||||
"shared_expression": self.shared_expression,
|
||||
@@ -42,9 +45,29 @@ class PromptAxesRoute:
|
||||
"character_expression_text": self.character_expression_text,
|
||||
"source_composition": self.source_composition,
|
||||
"composition": self.composition,
|
||||
"composition_entry": dict(self.composition_entry),
|
||||
}
|
||||
|
||||
|
||||
def _metadata_entry(value: Any, *, slug: str = "", text: str = "") -> dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
entry = dict(value)
|
||||
elif isinstance(value, (list, tuple)) and len(value) == 2:
|
||||
entry = {"slug": str(value[0]), "prompt": str(value[1])}
|
||||
else:
|
||||
entry = {"prompt": str(value or "")}
|
||||
if slug:
|
||||
entry["slug"] = slug
|
||||
if text:
|
||||
if "prompt" in entry:
|
||||
entry["prompt"] = text
|
||||
elif "text" in entry:
|
||||
entry["text"] = text
|
||||
else:
|
||||
entry["prompt"] = text
|
||||
return entry
|
||||
|
||||
|
||||
def resolve_prompt_axes_result(
|
||||
*,
|
||||
category: dict[str, Any],
|
||||
@@ -75,14 +98,14 @@ def resolve_prompt_axes_result(
|
||||
character_slot_map = character_slot_map or {}
|
||||
pov_character_labels = pov_character_labels or []
|
||||
|
||||
scene_slug, scene = row_item_policy.choose_pair(
|
||||
scene_rng,
|
||||
category_policy.compatible_entries(
|
||||
row_pool_policy.scene_pool(category, subcategory, item, subject_type, location_config),
|
||||
women_count,
|
||||
men_count,
|
||||
),
|
||||
scene_entries = category_policy.compatible_entries(
|
||||
row_pool_policy.scene_pool(category, subcategory, item, subject_type, location_config),
|
||||
women_count,
|
||||
men_count,
|
||||
)
|
||||
scene_choice = row_item_policy.weighted_choice(scene_rng, scene_entries)
|
||||
scene_slug, scene = row_item_policy.pair_from(scene_choice)
|
||||
scene_entry = _metadata_entry(scene_choice, slug=scene_slug, text=scene)
|
||||
pose = str(
|
||||
category_policy.merged_field(category, subcategory, item, "pose", "")
|
||||
or context.get("fallback_pose")
|
||||
@@ -137,21 +160,23 @@ def resolve_prompt_axes_result(
|
||||
if character_expression_text:
|
||||
expression = character_expression_text
|
||||
|
||||
source_composition = row_item_policy.choose_text(
|
||||
composition_rng,
|
||||
category_policy.compatible_entries(
|
||||
row_pool_policy.composition_pool(category, subcategory, item, subject_type, composition_config),
|
||||
women_count,
|
||||
men_count,
|
||||
),
|
||||
composition_entries = category_policy.compatible_entries(
|
||||
row_pool_policy.composition_pool(category, subcategory, item, subject_type, composition_config),
|
||||
women_count,
|
||||
men_count,
|
||||
)
|
||||
composition_choice = row_item_policy.weighted_choice(composition_rng, composition_entries)
|
||||
source_composition = row_item_policy.item_text(composition_choice)
|
||||
composition_entry = _metadata_entry(composition_choice, text=source_composition)
|
||||
if is_pose_category:
|
||||
source_composition = sanitize_hardcore_environment_anchors(source_composition)
|
||||
composition_entry["prompt"] = source_composition
|
||||
composition = pov_policy.pov_composition_prompt(source_composition, pov_character_labels)
|
||||
|
||||
return PromptAxesRoute(
|
||||
scene_slug=scene_slug,
|
||||
scene=scene,
|
||||
scene_entry=scene_entry,
|
||||
pose=pose,
|
||||
expression=expression,
|
||||
shared_expression=shared_expression,
|
||||
@@ -159,6 +184,7 @@ def resolve_prompt_axes_result(
|
||||
character_expression_text=character_expression_text,
|
||||
source_composition=source_composition,
|
||||
composition=composition,
|
||||
composition_entry=composition_entry,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user