Support seeded atlas prompt cue variants
This commit is contained in:
+76
-7
@@ -6762,12 +6762,17 @@ def smoke_krea2_pov_pose_variant_catalog() -> None:
|
||||
refs = variant.get("reference_images")
|
||||
_expect(isinstance(prompt_cues, list) and prompt_cues, f"{key} has no prompt cues")
|
||||
_expect(isinstance(avoid_cues, list) and avoid_cues, f"{key} has no avoid cues")
|
||||
for cue in prompt_cues:
|
||||
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8)
|
||||
_expect(
|
||||
not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE),
|
||||
f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}",
|
||||
)
|
||||
prompt_variant_cues = variant.get("prompt_variant_cues", [])
|
||||
_expect(isinstance(prompt_variant_cues, list), f"{key} prompt variant cue sets should be a list when present")
|
||||
all_cue_sets = [prompt_cues, *prompt_variant_cues]
|
||||
for cue_set_index, cue_set in enumerate(all_cue_sets):
|
||||
_expect(isinstance(cue_set, list) and cue_set, f"{key} prompt cue set {cue_set_index} is empty")
|
||||
for cue in cue_set:
|
||||
cue_text = _expect_text(f"{key}.prompt_cue[{cue_set_index}]", cue, 8)
|
||||
_expect(
|
||||
not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE),
|
||||
f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}",
|
||||
)
|
||||
_expect(isinstance(refs, list) and refs, f"{key} has no reference images")
|
||||
hook = variant.get("generator_hook") or {}
|
||||
_expect_text(f"{key}.generator_hook.module", hook.get("module"), 6)
|
||||
@@ -6846,7 +6851,17 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None:
|
||||
_expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}")
|
||||
krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore")
|
||||
prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower()
|
||||
for cue in variant.get("prompt_cues") or []:
|
||||
hard_axis_values = hard_row.get("item_axis_values") if isinstance(hard_row.get("item_axis_values"), dict) else {}
|
||||
variant_indices = hard_axis_values.get("krea2_prompt_variant_indices") if isinstance(hard_axis_values, dict) else {}
|
||||
selected_index = 0
|
||||
if isinstance(variant_indices, dict):
|
||||
try:
|
||||
selected_index = int(variant_indices.get(key, 0))
|
||||
except (TypeError, ValueError):
|
||||
selected_index = 0
|
||||
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||
_expect(0 <= selected_index < len(cue_sets), f"{key} selected invalid prompt variant index {selected_index}")
|
||||
for cue in cue_sets[selected_index]:
|
||||
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower()
|
||||
_expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}")
|
||||
atlas_action_prompt = prompt.split(" camera is ", 1)[0]
|
||||
@@ -6911,6 +6926,60 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None:
|
||||
def smoke_krea2_pose_variant_catalog_policy() -> None:
|
||||
catalog = krea2_pose_variant_catalog.load_catalog()
|
||||
_expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog")
|
||||
synthetic_variant = {
|
||||
"key": "pov_synthetic_seeded_variant",
|
||||
"prompt_cues": ["synthetic baseline atlas cue", "synthetic baseline contact anchor"],
|
||||
"prompt_variant_cues": [
|
||||
["synthetic alternate atlas cue", "synthetic alternate contact anchor"],
|
||||
{"prompt_cues": ["synthetic second alternate cue", "synthetic second contact anchor"]},
|
||||
],
|
||||
}
|
||||
_expect(
|
||||
krea2_pose_variant_catalog.prompt_cue_sets(synthetic_variant) == [
|
||||
["synthetic baseline atlas cue", "synthetic baseline contact anchor"],
|
||||
["synthetic alternate atlas cue", "synthetic alternate contact anchor"],
|
||||
["synthetic second alternate cue", "synthetic second contact anchor"],
|
||||
],
|
||||
"Krea2 pose-variant catalog should expose baseline and optional prompt variant cue sets",
|
||||
)
|
||||
original_get_variant = krea2_pose_variant_catalog.get_variant
|
||||
try:
|
||||
def fake_get_variant(key: str, **kwargs):
|
||||
if key == "pov_synthetic_seeded_variant":
|
||||
return synthetic_variant
|
||||
return original_get_variant(key, **kwargs)
|
||||
|
||||
krea2_pose_variant_catalog.get_variant = fake_get_variant
|
||||
baseline_sentence = krea_pov_actions._krea2_atlas_variant_sentence(
|
||||
{
|
||||
"krea2_variant_keys": ["pov_synthetic_seeded_variant"],
|
||||
"krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 0},
|
||||
}
|
||||
)
|
||||
alternate_sentence = krea_pov_actions._krea2_atlas_variant_sentence(
|
||||
{
|
||||
"krea2_variant_keys": ["pov_synthetic_seeded_variant"],
|
||||
"krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 1},
|
||||
}
|
||||
)
|
||||
_expect("synthetic baseline atlas cue" in baseline_sentence, "Krea2 atlas sentence lost baseline prompt cue set")
|
||||
_expect("synthetic alternate atlas cue" in alternate_sentence, "Krea2 atlas sentence lost selected prompt variant cue set")
|
||||
selected_indices: set[int] = set()
|
||||
for seed in range(4510, 4560):
|
||||
axis_values = pb._axis_values_with_krea2_prompt_variant_indices(
|
||||
{"krea2_variant_keys": ["pov_synthetic_seeded_variant"]},
|
||||
seed_config={},
|
||||
seed=seed,
|
||||
row_number=1,
|
||||
)
|
||||
selected_indices.add(int(axis_values.get("krea2_prompt_variant_indices", {}).get("pov_synthetic_seeded_variant", 0)))
|
||||
_expect(
|
||||
len(selected_indices) >= 2,
|
||||
f"Krea2 prompt variant selector should vary across pose seeds, got {sorted(selected_indices)}",
|
||||
)
|
||||
finally:
|
||||
krea2_pose_variant_catalog.get_variant = original_get_variant
|
||||
|
||||
proven = krea2_pose_variant_catalog.variant_keys(status="proven")
|
||||
_expect(
|
||||
proven == [
|
||||
|
||||
Reference in New Issue
Block a user