Support seeded atlas prompt cue variants
This commit is contained in:
@@ -12,6 +12,7 @@ METADATA_AXIS_KEYS = {
|
|||||||
"position_key",
|
"position_key",
|
||||||
"position_keys",
|
"position_keys",
|
||||||
"krea2_variant_keys",
|
"krea2_variant_keys",
|
||||||
|
"krea2_prompt_variant_indices",
|
||||||
"restored_prompt_axes",
|
"restored_prompt_axes",
|
||||||
}
|
}
|
||||||
ACTION_CONTEXT_PRIORITY = (
|
ACTION_CONTEXT_PRIORITY = (
|
||||||
|
|||||||
@@ -76,6 +76,33 @@ def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _cue_list(value: Any) -> list[str]:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = value.get("prompt_cues") or value.get("cues")
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return []
|
||||||
|
return [str(cue) for cue in value if str(cue).strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_cue_sets(variant_or_key: dict[str, Any] | str) -> list[list[str]]:
|
||||||
|
variant = get_variant(variant_or_key) if isinstance(variant_or_key, str) else dict(variant_or_key or {})
|
||||||
|
if not variant:
|
||||||
|
return []
|
||||||
|
cue_sets: list[list[str]] = []
|
||||||
|
baseline = _cue_list(variant.get("prompt_cues"))
|
||||||
|
if baseline:
|
||||||
|
cue_sets.append(baseline)
|
||||||
|
for cue_set in variant.get("prompt_variant_cues") or []:
|
||||||
|
cues = _cue_list(cue_set)
|
||||||
|
if cues:
|
||||||
|
cue_sets.append(cues)
|
||||||
|
if not cue_sets:
|
||||||
|
fallback = str(variant.get("canonical_geometry") or "").strip()
|
||||||
|
if fallback:
|
||||||
|
cue_sets.append([fallback])
|
||||||
|
return cue_sets
|
||||||
|
|
||||||
|
|
||||||
def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
|
def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
|
||||||
catalog = load_catalog(path)
|
catalog = load_catalog(path)
|
||||||
atlas_root = Path(str(catalog.get("atlas_root") or ""))
|
atlas_root = Path(str(catalog.get("atlas_root") or ""))
|
||||||
@@ -90,4 +117,3 @@ def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]:
|
|||||||
continue
|
continue
|
||||||
paths.append(atlas_root / ref_path)
|
paths.append(atlas_root / ref_path)
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
|||||||
+12
-1
@@ -109,7 +109,18 @@ def _krea2_atlas_variant_sentence(axis_values: Any) -> str:
|
|||||||
variant = _selected_krea2_atlas_variant(axis_values)
|
variant = _selected_krea2_atlas_variant(axis_values)
|
||||||
if not variant:
|
if not variant:
|
||||||
return ""
|
return ""
|
||||||
cues = _unique_texts(list(variant.get("prompt_cues") or []) or [variant.get("canonical_geometry")])
|
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||||
|
selected_index = 0
|
||||||
|
if isinstance(axis_values, dict):
|
||||||
|
indices = axis_values.get("krea2_prompt_variant_indices")
|
||||||
|
if isinstance(indices, dict):
|
||||||
|
try:
|
||||||
|
selected_index = int(indices.get(str(variant.get("key") or ""), 0))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
selected_index = 0
|
||||||
|
if selected_index < 0 or selected_index >= len(cue_sets):
|
||||||
|
selected_index = 0
|
||||||
|
cues = _unique_texts(cue_sets[selected_index] if cue_sets else [variant.get("canonical_geometry")])
|
||||||
sentence = _clean(". ".join(cues)).rstrip(".")
|
sentence = _clean(". ".join(cues)).rstrip(".")
|
||||||
if isinstance(axis_values, dict):
|
if isinstance(axis_values, dict):
|
||||||
restored_details = _unique_texts(_list_values(axis_values.get("restored_prompt_details")))
|
restored_details = _unique_texts(_list_values(axis_values.get("restored_prompt_details")))
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ try:
|
|||||||
from . import generate_prompt_batches as g
|
from . import generate_prompt_batches as g
|
||||||
from . import generation_profile_config as generation_profile_policy
|
from . import generation_profile_config as generation_profile_policy
|
||||||
from . import hardcore_position_config as hardcore_position_policy
|
from . import hardcore_position_config as hardcore_position_policy
|
||||||
|
from . import krea2_pose_variant_catalog
|
||||||
from . import location_config as location_policy
|
from . import location_config as location_policy
|
||||||
from . import pair_builder
|
from . import pair_builder
|
||||||
from . import pair_cast
|
from . import pair_cast
|
||||||
@@ -76,6 +77,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
|||||||
import generate_prompt_batches as g
|
import generate_prompt_batches as g
|
||||||
import generation_profile_config as generation_profile_policy
|
import generation_profile_config as generation_profile_policy
|
||||||
import hardcore_position_config as hardcore_position_policy
|
import hardcore_position_config as hardcore_position_policy
|
||||||
|
import krea2_pose_variant_catalog
|
||||||
import location_config as location_policy
|
import location_config as location_policy
|
||||||
import pair_builder
|
import pair_builder
|
||||||
import pair_cast
|
import pair_cast
|
||||||
@@ -717,6 +719,30 @@ def _axis_values_with_krea2_variant_keys(
|
|||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _axis_values_with_krea2_prompt_variant_indices(
|
||||||
|
axis_values: dict[str, Any],
|
||||||
|
*,
|
||||||
|
seed_config: dict[str, int],
|
||||||
|
seed: int,
|
||||||
|
row_number: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
variant_keys = axis_values.get("krea2_variant_keys") if isinstance(axis_values, dict) else []
|
||||||
|
if not isinstance(variant_keys, list) or not variant_keys:
|
||||||
|
return axis_values
|
||||||
|
rng = seed_policy.axis_rng(seed_config, "pose", seed, row_number)
|
||||||
|
indices: dict[str, int] = {}
|
||||||
|
for key in variant_keys:
|
||||||
|
variant = krea2_pose_variant_catalog.get_variant(str(key))
|
||||||
|
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||||
|
if len(cue_sets) > 1:
|
||||||
|
indices[str(key)] = rng.randrange(len(cue_sets))
|
||||||
|
if not indices:
|
||||||
|
return axis_values
|
||||||
|
merged = dict(axis_values)
|
||||||
|
merged["krea2_prompt_variant_indices"] = indices
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def _axis_values_with_hardcore_route_metadata(
|
def _axis_values_with_hardcore_route_metadata(
|
||||||
axis_values: dict[str, Any],
|
axis_values: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
@@ -2387,6 +2413,12 @@ def _build_custom_row(
|
|||||||
item_axis_values,
|
item_axis_values,
|
||||||
parsed_hardcore_position_config,
|
parsed_hardcore_position_config,
|
||||||
)
|
)
|
||||||
|
item_axis_values = _axis_values_with_krea2_prompt_variant_indices(
|
||||||
|
item_axis_values,
|
||||||
|
seed_config=seed_config,
|
||||||
|
seed=seed,
|
||||||
|
row_number=row_number,
|
||||||
|
)
|
||||||
item_template_metadata = dict(category_route.item_template_metadata)
|
item_template_metadata = dict(category_route.item_template_metadata)
|
||||||
item_formatter_hints = dict(category_route.formatter_hints)
|
item_formatter_hints = dict(category_route.formatter_hints)
|
||||||
is_pose_category = category_route.is_pose_category
|
is_pose_category = category_route.is_pose_category
|
||||||
|
|||||||
+72
-3
@@ -6762,8 +6762,13 @@ def smoke_krea2_pov_pose_variant_catalog() -> None:
|
|||||||
refs = variant.get("reference_images")
|
refs = variant.get("reference_images")
|
||||||
_expect(isinstance(prompt_cues, list) and prompt_cues, f"{key} has no prompt cues")
|
_expect(isinstance(prompt_cues, list) and prompt_cues, f"{key} has no prompt cues")
|
||||||
_expect(isinstance(avoid_cues, list) and avoid_cues, f"{key} has no avoid cues")
|
_expect(isinstance(avoid_cues, list) and avoid_cues, f"{key} has no avoid cues")
|
||||||
for cue in prompt_cues:
|
prompt_variant_cues = variant.get("prompt_variant_cues", [])
|
||||||
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8)
|
_expect(isinstance(prompt_variant_cues, list), f"{key} prompt variant cue sets should be a list when present")
|
||||||
|
all_cue_sets = [prompt_cues, *prompt_variant_cues]
|
||||||
|
for cue_set_index, cue_set in enumerate(all_cue_sets):
|
||||||
|
_expect(isinstance(cue_set, list) and cue_set, f"{key} prompt cue set {cue_set_index} is empty")
|
||||||
|
for cue in cue_set:
|
||||||
|
cue_text = _expect_text(f"{key}.prompt_cue[{cue_set_index}]", cue, 8)
|
||||||
_expect(
|
_expect(
|
||||||
not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE),
|
not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE),
|
||||||
f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}",
|
f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}",
|
||||||
@@ -6846,7 +6851,17 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None:
|
|||||||
_expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}")
|
_expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}")
|
||||||
krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore")
|
krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore")
|
||||||
prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower()
|
prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower()
|
||||||
for cue in variant.get("prompt_cues") or []:
|
hard_axis_values = hard_row.get("item_axis_values") if isinstance(hard_row.get("item_axis_values"), dict) else {}
|
||||||
|
variant_indices = hard_axis_values.get("krea2_prompt_variant_indices") if isinstance(hard_axis_values, dict) else {}
|
||||||
|
selected_index = 0
|
||||||
|
if isinstance(variant_indices, dict):
|
||||||
|
try:
|
||||||
|
selected_index = int(variant_indices.get(key, 0))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
selected_index = 0
|
||||||
|
cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant)
|
||||||
|
_expect(0 <= selected_index < len(cue_sets), f"{key} selected invalid prompt variant index {selected_index}")
|
||||||
|
for cue in cue_sets[selected_index]:
|
||||||
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower()
|
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower()
|
||||||
_expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}")
|
_expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}")
|
||||||
atlas_action_prompt = prompt.split(" camera is ", 1)[0]
|
atlas_action_prompt = prompt.split(" camera is ", 1)[0]
|
||||||
@@ -6911,6 +6926,60 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None:
|
|||||||
def smoke_krea2_pose_variant_catalog_policy() -> None:
|
def smoke_krea2_pose_variant_catalog_policy() -> None:
|
||||||
catalog = krea2_pose_variant_catalog.load_catalog()
|
catalog = krea2_pose_variant_catalog.load_catalog()
|
||||||
_expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog")
|
_expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog")
|
||||||
|
synthetic_variant = {
|
||||||
|
"key": "pov_synthetic_seeded_variant",
|
||||||
|
"prompt_cues": ["synthetic baseline atlas cue", "synthetic baseline contact anchor"],
|
||||||
|
"prompt_variant_cues": [
|
||||||
|
["synthetic alternate atlas cue", "synthetic alternate contact anchor"],
|
||||||
|
{"prompt_cues": ["synthetic second alternate cue", "synthetic second contact anchor"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
_expect(
|
||||||
|
krea2_pose_variant_catalog.prompt_cue_sets(synthetic_variant) == [
|
||||||
|
["synthetic baseline atlas cue", "synthetic baseline contact anchor"],
|
||||||
|
["synthetic alternate atlas cue", "synthetic alternate contact anchor"],
|
||||||
|
["synthetic second alternate cue", "synthetic second contact anchor"],
|
||||||
|
],
|
||||||
|
"Krea2 pose-variant catalog should expose baseline and optional prompt variant cue sets",
|
||||||
|
)
|
||||||
|
original_get_variant = krea2_pose_variant_catalog.get_variant
|
||||||
|
try:
|
||||||
|
def fake_get_variant(key: str, **kwargs):
|
||||||
|
if key == "pov_synthetic_seeded_variant":
|
||||||
|
return synthetic_variant
|
||||||
|
return original_get_variant(key, **kwargs)
|
||||||
|
|
||||||
|
krea2_pose_variant_catalog.get_variant = fake_get_variant
|
||||||
|
baseline_sentence = krea_pov_actions._krea2_atlas_variant_sentence(
|
||||||
|
{
|
||||||
|
"krea2_variant_keys": ["pov_synthetic_seeded_variant"],
|
||||||
|
"krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 0},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
alternate_sentence = krea_pov_actions._krea2_atlas_variant_sentence(
|
||||||
|
{
|
||||||
|
"krea2_variant_keys": ["pov_synthetic_seeded_variant"],
|
||||||
|
"krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 1},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_expect("synthetic baseline atlas cue" in baseline_sentence, "Krea2 atlas sentence lost baseline prompt cue set")
|
||||||
|
_expect("synthetic alternate atlas cue" in alternate_sentence, "Krea2 atlas sentence lost selected prompt variant cue set")
|
||||||
|
selected_indices: set[int] = set()
|
||||||
|
for seed in range(4510, 4560):
|
||||||
|
axis_values = pb._axis_values_with_krea2_prompt_variant_indices(
|
||||||
|
{"krea2_variant_keys": ["pov_synthetic_seeded_variant"]},
|
||||||
|
seed_config={},
|
||||||
|
seed=seed,
|
||||||
|
row_number=1,
|
||||||
|
)
|
||||||
|
selected_indices.add(int(axis_values.get("krea2_prompt_variant_indices", {}).get("pov_synthetic_seeded_variant", 0)))
|
||||||
|
_expect(
|
||||||
|
len(selected_indices) >= 2,
|
||||||
|
f"Krea2 prompt variant selector should vary across pose seeds, got {sorted(selected_indices)}",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
krea2_pose_variant_catalog.get_variant = original_get_variant
|
||||||
|
|
||||||
proven = krea2_pose_variant_catalog.variant_keys(status="proven")
|
proven = krea2_pose_variant_catalog.variant_keys(status="proven")
|
||||||
_expect(
|
_expect(
|
||||||
proven == [
|
proven == [
|
||||||
|
|||||||
Reference in New Issue
Block a user