Support seeded atlas prompt cue variants

This commit is contained in:
2026-07-01 00:21:31 +02:00
parent 7e41613c1e
commit b8d164a3da
5 changed files with 148 additions and 9 deletions
+32
View File
@@ -27,6 +27,7 @@ try:
from . import generate_prompt_batches as g
from . import generation_profile_config as generation_profile_policy
from . import hardcore_position_config as hardcore_position_policy
from . import krea2_pose_variant_catalog
from . import location_config as location_policy
from . import pair_builder
from . import pair_cast
@@ -76,6 +77,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
import generate_prompt_batches as g
import generation_profile_config as generation_profile_policy
import hardcore_position_config as hardcore_position_policy
import krea2_pose_variant_catalog
import location_config as location_policy
import pair_builder
import pair_cast
@@ -717,6 +719,30 @@ def _axis_values_with_krea2_variant_keys(
return merged
def _axis_values_with_krea2_prompt_variant_indices(
axis_values: dict[str, Any],
*,
seed_config: dict[str, int],
seed: int,
row_number: int,
) -> dict[str, Any]:
variant_keys = axis_values.get("krea2_variant_keys") if isinstance(axis_values, dict) else []
if not isinstance(variant_keys, list) or not variant_keys:
return axis_values
rng = seed_policy.axis_rng(seed_config, "pose", seed, row_number)
indices: dict[str, int] = {}
for key in variant_keys:
variant = krea2_pose_variant_catalog.get_variant(str(key))
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
if len(cue_sets) > 1:
indices[str(key)] = rng.randrange(len(cue_sets))
if not indices:
return axis_values
merged = dict(axis_values)
merged["krea2_prompt_variant_indices"] = indices
return merged
def _axis_values_with_hardcore_route_metadata(
axis_values: dict[str, Any],
*,
@@ -2387,6 +2413,12 @@ def _build_custom_row(
item_axis_values,
parsed_hardcore_position_config,
)
item_axis_values = _axis_values_with_krea2_prompt_variant_indices(
item_axis_values,
seed_config=seed_config,
seed=seed,
row_number=row_number,
)
item_template_metadata = dict(category_route.item_template_metadata)
item_formatter_hints = dict(category_route.formatter_hints)
is_pose_category = category_route.is_pose_category