Support seeded atlas prompt cue variants
This commit is contained in:
@@ -27,6 +27,7 @@ try:
|
||||
from . import generate_prompt_batches as g
|
||||
from . import generation_profile_config as generation_profile_policy
|
||||
from . import hardcore_position_config as hardcore_position_policy
|
||||
from . import krea2_pose_variant_catalog
|
||||
from . import location_config as location_policy
|
||||
from . import pair_builder
|
||||
from . import pair_cast
|
||||
@@ -76,6 +77,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import generate_prompt_batches as g
|
||||
import generation_profile_config as generation_profile_policy
|
||||
import hardcore_position_config as hardcore_position_policy
|
||||
import krea2_pose_variant_catalog
|
||||
import location_config as location_policy
|
||||
import pair_builder
|
||||
import pair_cast
|
||||
@@ -717,6 +719,30 @@ def _axis_values_with_krea2_variant_keys(
|
||||
return merged
|
||||
|
||||
|
||||
def _axis_values_with_krea2_prompt_variant_indices(
|
||||
axis_values: dict[str, Any],
|
||||
*,
|
||||
seed_config: dict[str, int],
|
||||
seed: int,
|
||||
row_number: int,
|
||||
) -> dict[str, Any]:
|
||||
variant_keys = axis_values.get("krea2_variant_keys") if isinstance(axis_values, dict) else []
|
||||
if not isinstance(variant_keys, list) or not variant_keys:
|
||||
return axis_values
|
||||
rng = seed_policy.axis_rng(seed_config, "pose", seed, row_number)
|
||||
indices: dict[str, int] = {}
|
||||
for key in variant_keys:
|
||||
variant = krea2_pose_variant_catalog.get_variant(str(key))
|
||||
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||
if len(cue_sets) > 1:
|
||||
indices[str(key)] = rng.randrange(len(cue_sets))
|
||||
if not indices:
|
||||
return axis_values
|
||||
merged = dict(axis_values)
|
||||
merged["krea2_prompt_variant_indices"] = indices
|
||||
return merged
|
||||
|
||||
|
||||
def _axis_values_with_hardcore_route_metadata(
|
||||
axis_values: dict[str, Any],
|
||||
*,
|
||||
@@ -2387,6 +2413,12 @@ def _build_custom_row(
|
||||
item_axis_values,
|
||||
parsed_hardcore_position_config,
|
||||
)
|
||||
item_axis_values = _axis_values_with_krea2_prompt_variant_indices(
|
||||
item_axis_values,
|
||||
seed_config=seed_config,
|
||||
seed=seed,
|
||||
row_number=row_number,
|
||||
)
|
||||
item_template_metadata = dict(category_route.item_template_metadata)
|
||||
item_formatter_hints = dict(category_route.formatter_hints)
|
||||
is_pose_category = category_route.is_pose_category
|
||||
|
||||
Reference in New Issue
Block a user