diff --git a/item_axis_policy.py b/item_axis_policy.py index 503736e..98e3c07 100644 --- a/item_axis_policy.py +++ b/item_axis_policy.py @@ -12,6 +12,7 @@ METADATA_AXIS_KEYS = { "position_key", "position_keys", "krea2_variant_keys", + "krea2_prompt_variant_indices", "restored_prompt_axes", } ACTION_CONTEXT_PRIORITY = ( diff --git a/krea2_pose_variant_catalog.py b/krea2_pose_variant_catalog.py index 403cc9a..216c298 100644 --- a/krea2_pose_variant_catalog.py +++ b/krea2_pose_variant_catalog.py @@ -76,6 +76,33 @@ def get_variant(key: str, *, path: str | Path | None = None) -> dict[str, Any]: return {} +def _cue_list(value: Any) -> list[str]: + if isinstance(value, dict): + value = value.get("prompt_cues") or value.get("cues") + if not isinstance(value, list): + return [] + return [str(cue) for cue in value if str(cue).strip()] + + +def prompt_cue_sets(variant_or_key: dict[str, Any] | str) -> list[list[str]]: + variant = get_variant(variant_or_key) if isinstance(variant_or_key, str) else dict(variant_or_key or {}) + if not variant: + return [] + cue_sets: list[list[str]] = [] + baseline = _cue_list(variant.get("prompt_cues")) + if baseline: + cue_sets.append(baseline) + for cue_set in variant.get("prompt_variant_cues") or []: + cues = _cue_list(cue_set) + if cues: + cue_sets.append(cues) + if not cue_sets: + fallback = str(variant.get("canonical_geometry") or "").strip() + if fallback: + cue_sets.append([fallback]) + return cue_sets + + def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]: catalog = load_catalog(path) atlas_root = Path(str(catalog.get("atlas_root") or "")) @@ -90,4 +117,3 @@ def reference_paths(key: str, *, path: str | Path | None = None) -> list[Path]: continue paths.append(atlas_root / ref_path) return paths - diff --git a/krea_pov_actions.py b/krea_pov_actions.py index 65f12b6..e4b4891 100644 --- a/krea_pov_actions.py +++ b/krea_pov_actions.py @@ -109,7 +109,18 @@ def _krea2_atlas_variant_sentence(axis_values: Any) -> str: variant = _selected_krea2_atlas_variant(axis_values) if not variant: return "" - cues = _unique_texts(list(variant.get("prompt_cues") or []) or [variant.get("canonical_geometry")]) + cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant) + selected_index = 0 + if isinstance(axis_values, dict): + indices = axis_values.get("krea2_prompt_variant_indices") + if isinstance(indices, dict): + try: + selected_index = int(indices.get(str(variant.get("key") or ""), 0)) + except (TypeError, ValueError): + selected_index = 0 + if selected_index < 0 or selected_index >= len(cue_sets): + selected_index = 0 + cues = _unique_texts(cue_sets[selected_index] if cue_sets else [variant.get("canonical_geometry")]) sentence = _clean(". ".join(cues)).rstrip(".") if isinstance(axis_values, dict): restored_details = _unique_texts(_list_values(axis_values.get("restored_prompt_details"))) diff --git a/prompt_builder.py b/prompt_builder.py index bae9243..fc53382 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -27,6 +27,7 @@ try: from . import generate_prompt_batches as g from . import generation_profile_config as generation_profile_policy from . import hardcore_position_config as hardcore_position_policy + from . import krea2_pose_variant_catalog from . import location_config as location_policy from . import pair_builder from . import pair_cast @@ -76,6 +77,7 @@ except ImportError: # Allows local smoke tests with `python -c`. import generate_prompt_batches as g import generation_profile_config as generation_profile_policy import hardcore_position_config as hardcore_position_policy + import krea2_pose_variant_catalog import location_config as location_policy import pair_builder import pair_cast @@ -717,6 +719,30 @@ def _axis_values_with_krea2_variant_keys( return merged +def _axis_values_with_krea2_prompt_variant_indices( + axis_values: dict[str, Any], + *, + seed_config: dict[str, int], + seed: int, + row_number: int, +) -> dict[str, Any]: + variant_keys = axis_values.get("krea2_variant_keys") if isinstance(axis_values, dict) else [] + if not isinstance(variant_keys, list) or not variant_keys: + return axis_values + rng = seed_policy.axis_rng(seed_config, "pose", seed, row_number) + indices: dict[str, int] = {} + for key in variant_keys: + variant = krea2_pose_variant_catalog.get_variant(str(key)) + cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant) + if len(cue_sets) > 1: + indices[str(key)] = rng.randrange(len(cue_sets)) + if not indices: + return axis_values + merged = dict(axis_values) + merged["krea2_prompt_variant_indices"] = indices + return merged + + def _axis_values_with_hardcore_route_metadata( axis_values: dict[str, Any], *, @@ -2387,6 +2413,12 @@ def _build_custom_row( item_axis_values, parsed_hardcore_position_config, ) + item_axis_values = _axis_values_with_krea2_prompt_variant_indices( + item_axis_values, + seed_config=seed_config, + seed=seed, + row_number=row_number, + ) item_template_metadata = dict(category_route.item_template_metadata) item_formatter_hints = dict(category_route.formatter_hints) is_pose_category = category_route.is_pose_category diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index a518653..092045a 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -6762,12 +6762,17 @@ def smoke_krea2_pov_pose_variant_catalog() -> None: refs = variant.get("reference_images") _expect(isinstance(prompt_cues, list) and prompt_cues, f"{key} has no prompt cues") _expect(isinstance(avoid_cues, list) and avoid_cues, f"{key} has no avoid cues") - for cue in prompt_cues: - cue_text = _expect_text(f"{key}.prompt_cue", cue, 8) - _expect( - not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE), - f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}", - ) + prompt_variant_cues = variant.get("prompt_variant_cues", []) + _expect(isinstance(prompt_variant_cues, list), f"{key} prompt variant cue sets should be a list when present") + all_cue_sets = [prompt_cues, *prompt_variant_cues] + for cue_set_index, cue_set in enumerate(all_cue_sets): + _expect(isinstance(cue_set, list) and cue_set, f"{key} prompt cue set {cue_set_index} is empty") + for cue in cue_set: + cue_text = _expect_text(f"{key}.prompt_cue[{cue_set_index}]", cue, 8) + _expect( + not re.search(r"\b(?:may|optionally|either|or)\b", cue_text, flags=re.IGNORECASE), + f"{key} prompt cue should be a direct model instruction, not an option list: {cue_text!r}", + ) _expect(isinstance(refs, list) and refs, f"{key} has no reference images") hook = variant.get("generator_hook") or {} _expect_text(f"{key}.generator_hook.module", hook.get("module"), 6) @@ -6846,7 +6851,17 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None: _expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}") krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore") prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower() - for cue in variant.get("prompt_cues") or []: + hard_axis_values = hard_row.get("item_axis_values") if isinstance(hard_row.get("item_axis_values"), dict) else {} + variant_indices = hard_axis_values.get("krea2_prompt_variant_indices") if isinstance(hard_axis_values, dict) else {} + selected_index = 0 + if isinstance(variant_indices, dict): + try: + selected_index = int(variant_indices.get(key, 0)) + except (TypeError, ValueError): + selected_index = 0 + cue_sets = krea2_pose_variant_catalog.prompt_cue_sets(variant) + _expect(0 <= selected_index < len(cue_sets), f"{key} selected invalid prompt variant index {selected_index}") + for cue in cue_sets[selected_index]: cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower() _expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}") atlas_action_prompt = prompt.split(" camera is ", 1)[0] @@ -6911,6 +6926,60 @@ def smoke_krea2_pov_atlas_variant_prompt_routes() -> None: def smoke_krea2_pose_variant_catalog_policy() -> None: catalog = krea2_pose_variant_catalog.load_catalog() _expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog") + synthetic_variant = { + "key": "pov_synthetic_seeded_variant", + "prompt_cues": ["synthetic baseline atlas cue", "synthetic baseline contact anchor"], + "prompt_variant_cues": [ + ["synthetic alternate atlas cue", "synthetic alternate contact anchor"], + {"prompt_cues": ["synthetic second alternate cue", "synthetic second contact anchor"]}, + ], + } + _expect( + krea2_pose_variant_catalog.prompt_cue_sets(synthetic_variant) == [ + ["synthetic baseline atlas cue", "synthetic baseline contact anchor"], + ["synthetic alternate atlas cue", "synthetic alternate contact anchor"], + ["synthetic second alternate cue", "synthetic second contact anchor"], + ], + "Krea2 pose-variant catalog should expose baseline and optional prompt variant cue sets", + ) + original_get_variant = krea2_pose_variant_catalog.get_variant + try: + def fake_get_variant(key: str, **kwargs): + if key == "pov_synthetic_seeded_variant": + return synthetic_variant + return original_get_variant(key, **kwargs) + + krea2_pose_variant_catalog.get_variant = fake_get_variant + baseline_sentence = krea_pov_actions._krea2_atlas_variant_sentence( + { + "krea2_variant_keys": ["pov_synthetic_seeded_variant"], + "krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 0}, + } + ) + alternate_sentence = krea_pov_actions._krea2_atlas_variant_sentence( + { + "krea2_variant_keys": ["pov_synthetic_seeded_variant"], + "krea2_prompt_variant_indices": {"pov_synthetic_seeded_variant": 1}, + } + ) + _expect("synthetic baseline atlas cue" in baseline_sentence, "Krea2 atlas sentence lost baseline prompt cue set") + _expect("synthetic alternate atlas cue" in alternate_sentence, "Krea2 atlas sentence lost selected prompt variant cue set") + selected_indices: set[int] = set() + for seed in range(4510, 4560): + axis_values = pb._axis_values_with_krea2_prompt_variant_indices( + {"krea2_variant_keys": ["pov_synthetic_seeded_variant"]}, + seed_config={}, + seed=seed, + row_number=1, + ) + selected_indices.add(int(axis_values.get("krea2_prompt_variant_indices", {}).get("pov_synthetic_seeded_variant", 0))) + _expect( + len(selected_indices) >= 2, + f"Krea2 prompt variant selector should vary across pose seeds, got {sorted(selected_indices)}", + ) + finally: + krea2_pose_variant_catalog.get_variant = original_get_variant + proven = krea2_pose_variant_catalog.variant_keys(status="proven") _expect( proven == [