diff --git a/krea_configured_cast_formatter.py b/krea_configured_cast_formatter.py index ac6dcda..532e491 100644 --- a/krea_configured_cast_formatter.py +++ b/krea_configured_cast_formatter.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable +try: + from . import krea2_pose_variant_catalog +except ImportError: # Allows local smoke tests with top-level imports. + import krea2_pose_variant_catalog + MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool") TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft" @@ -110,6 +115,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool: return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row)) +def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool: + return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row)) + + +def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool: + for key in _krea2_variant_keys(row): + variant = krea2_pose_variant_catalog.get_variant(key) + geometry = " ".join( + [str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]] + ).lower() + if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")): + return True + return False + + def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any: if not _has_krea2_oral_contact_variant(row): return expression @@ -126,7 +146,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) - def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str: text = str(camera_scene or "") - if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower(): + if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower(): return "" return text @@ -193,7 +213,7 @@ def format_configured_cast_result( action, ) camera_scene = _filter_camera_scene_for_krea2_variant(row, request.camera_scene) - output_composition = deps.pov_composition_text(composition, pov_labels) + output_composition = "" if _has_krea2_atlas_variant(row) else deps.pov_composition_text(composition, pov_labels) parts = [ action, scene_anchor, diff --git a/krea_pair_formatter.py b/krea_pair_formatter.py index 672ca24..171113e 100644 --- a/krea_pair_formatter.py +++ b/krea_pair_formatter.py @@ -3,6 +3,11 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable +try: + from . import krea2_pose_variant_catalog +except ImportError: # Allows local smoke tests with top-level imports. + import krea2_pose_variant_catalog + MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool") TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft" @@ -88,6 +93,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool: return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row)) +def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool: + return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row)) + + +def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool: + for key in _krea2_variant_keys(row): + variant = krea2_pose_variant_catalog.get_variant(key) + geometry = " ".join( + [str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]] + ).lower() + if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")): + return True + return False + + def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any: if not _has_krea2_oral_contact_variant(row): return expression @@ -104,7 +124,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) - def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str: text = str(camera_scene or "") - if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower(): + if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower(): return "" return text @@ -190,7 +210,7 @@ def format_insta_pair_result(request: KreaPairFormatRequest, deps: KreaPairForma hard_axis_values, hard_detail_density, ) - hard_output_composition = deps.pov_composition_text(hard_composition, pov_labels) + hard_output_composition = "" if _has_krea2_atlas_variant(hard) else deps.pov_composition_text(hard_composition, pov_labels) same_soft_cast = options.get("softcore_cast") == "same_as_hardcore" soft_output_composition = deps.pov_composition_text(soft.get("composition"), pov_labels if same_soft_cast else []) soft_cast_presence = deps.softcore_cast_presence_phrase( diff --git a/krea_pov_actions.py b/krea_pov_actions.py index 3e79e09..5e347de 100644 --- a/krea_pov_actions.py +++ b/krea_pov_actions.py @@ -4,6 +4,7 @@ import re from typing import Any try: + from . import krea2_pose_variant_catalog from . import outercourse_action_policy as outercourse_policy from .krea_action_context import ( axis_values_text, @@ -15,6 +16,7 @@ try: ) from .krea_detail import limit_detail_for_density except ImportError: # Allows local smoke tests with `python -c`. + import krea2_pose_variant_catalog import outercourse_action_policy as outercourse_policy from krea_action_context import ( axis_values_text, @@ -52,6 +54,65 @@ def _has_krea2_variant(axis_values: Any, key: str) -> bool: return key in _list_values(axis_values.get("krea2_variant_keys")) +def _metadata_text(axis_values: Any, key: str) -> str: + if not isinstance(axis_values, dict): + return "" + return _clean(axis_values.get(key, "")).lower() + + +def _metadata_values(axis_values: Any, key: str) -> set[str]: + if not isinstance(axis_values, dict): + return set() + return {value for value in _list_values(axis_values.get(key)) if value} + + +def _variant_matches_route(variant: dict[str, Any], axis_values: Any) -> bool: + action_family = _metadata_text(axis_values, "action_family") + if action_family and _clean(variant.get("action_family", "")).lower() != action_family: + return False + route_positions = _metadata_values(axis_values, "position_keys") + route_position = _metadata_text(axis_values, "position_key") + if route_position: + route_positions.add(route_position) + variant_positions = {str(position) for position in variant.get("position_keys", []) if str(position).strip()} + return not route_positions or not variant_positions or bool(route_positions & variant_positions) + + +def _selected_krea2_atlas_variant(axis_values: Any) -> dict[str, Any]: + if not isinstance(axis_values, dict): + return {} + keys = _list_values(axis_values.get("krea2_variant_keys")) + variants = [krea2_pose_variant_catalog.get_variant(key) for key in keys] + variants = [variant for variant in variants if variant.get("key")] + if not variants: + return {} + if len(variants) == 1: + return variants[0] + matching = [variant for variant in variants if _variant_matches_route(variant, axis_values)] + return matching[0] if len(matching) == 1 else {} + + +def _unique_texts(values: list[Any]) -> list[str]: + texts: list[str] = [] + seen: set[str] = set() + for value in values: + text = _clean(value).strip(" .;") + lower = text.lower() + if not text or lower in seen: + continue + texts.append(text) + seen.add(lower) + return texts + + +def _krea2_atlas_variant_sentence(axis_values: Any) -> str: + variant = _selected_krea2_atlas_variant(axis_values) + if not variant: + return "" + cues = _unique_texts([variant.get("canonical_geometry"), *(variant.get("prompt_cues") or [])]) + return _clean("; ".join(cues)).rstrip(".") + + def pov_ejaculation_target(context: str) -> str: if any( token in context @@ -268,6 +329,10 @@ def pov_hardcore_pose_sentence( details = pov_clean_oral_detail(action_text.split(";", 1)[1], f"{context} {base}", detail_density) return _clean(f"{base}; {details}" if details else base).rstrip(".") + atlas_sentence = _krea2_atlas_variant_sentence(axis_values) + if atlas_sentence: + return sentence(atlas_sentence) + def oral_direction() -> tuple[bool, bool]: oral_context = f"{context} {action_lower}" woman_gives = any( diff --git a/prompt_builder.py b/prompt_builder.py index fd8bad4..bae9243 100644 --- a/prompt_builder.py +++ b/prompt_builder.py @@ -717,6 +717,26 @@ def _axis_values_with_krea2_variant_keys( return merged +def _axis_values_with_hardcore_route_metadata( + axis_values: dict[str, Any], + *, + action_family: str, + position_family: str, + position_key: str, + position_keys: list[str], +) -> dict[str, Any]: + merged = dict(axis_values) + if action_family: + merged["action_family"] = action_family + if position_family: + merged["position_family"] = position_family + if position_key: + merged["position_key"] = position_key + if position_keys: + merged["position_keys"] = [str(key) for key in position_keys if str(key).strip()] + return merged + + def build_hardcore_position_pool_json( hardcore_position_config: str | dict[str, Any] | None = "", combine_mode: str = "replace", @@ -2473,6 +2493,13 @@ def _build_custom_row( position_keys = list(action_route.position_keys) position_key = action_route.position_key action_family = action_route.action_family + item_axis_values = _axis_values_with_hardcore_route_metadata( + item_axis_values, + action_family=action_family, + position_family=position_family, + position_key=position_key, + position_keys=position_keys, + ) text_fields = _row_text_fields(category, subcategory, item, style_config) assembly_request = row_assembly_policy.CustomRowAssemblyRequest( diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 07a1652..de6f38c 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -6781,6 +6781,80 @@ def smoke_krea2_pov_pose_variant_catalog() -> None: _expect(required_keys.issubset(seen_keys), "Krea2 POV pose-variant catalog lost a proven starter variant") +def _atlas_variant_include_key(variant_key: str) -> str: + key = "".join(char if char.isalnum() else "_" for char in str(variant_key).lower().removeprefix("pov_")).strip("_") + while "__" in key: + key = key.replace("__", "_") + return f"include_{key}" + + +def smoke_krea2_pov_atlas_variant_prompt_routes() -> None: + filter_by_action_family = { + "penetration": "SxCPKrea2POVPenetrationFilter", + "oral": "SxCPKrea2POVOralFilter", + "outercourse": "SxCPKrea2POVOutercourseFilter", + "manual": "SxCPKrea2POVManualFilter", + "toy": "SxCPKrea2POVToyFilter", + "climax": "SxCPKrea2POVClimaxFilter", + "interaction": "SxCPKrea2POVInteractionFilter", + } + variants = krea2_pose_variant_catalog.variants() + _expect(variants, "Krea2 POV atlas prompt route smoke found no variants") + for offset, variant in enumerate(variants, start=4510): + key = _expect_text("krea2_pov_atlas_variant_prompt_routes.key", variant.get("key"), 8) + action_family = _expect_text(f"{key}.action_family", variant.get("action_family"), 3) + node_name = filter_by_action_family.get(action_family) + _expect(node_name in sxcp_nodes.NODE_CLASS_MAPPINGS, f"{key} has no Krea2 POV filter node for action family {action_family!r}") + include_key = _atlas_variant_include_key(key) + node_cls = sxcp_nodes.NODE_CLASS_MAPPINGS[node_name] + _expect(include_key in (node_cls.INPUT_TYPES().get("required") or {}), f"{node_name} does not expose {include_key}") + variant_config = node_cls().build("replace", "", **{include_key: True})[0] + pair = pb.build_insta_of_pair( + row_number=1, + start_index=1, + seed=offset, + ethnicity="any", + figure="random", + no_plus_women=False, + no_black=False, + trigger=Trigger, + prepend_trigger_to_prompt=True, + options_json=_insta_options( + softcore_camera_mode="from_camera_config", + hardcore_camera_mode="from_camera_config", + camera_detail="compact", + ), + character_cast=_character_cast(pov_man=True), + hardcore_position_config=variant_config, + location_config=_coworking_location_config(), + hardcore_camera_config=_orbit_camera( + horizontal_angle=45, + vertical_angle=0, + zoom=7.5, + subject_focus="action", + ), + ) + _expect_pair(pair, f"krea2_pov_atlas_variant_prompt_routes.{key}") + hard_row = pair.get("hardcore_row") or {} + variant_keys = (hard_row.get("hardcore_position_config") or {}).get("krea2_variant_keys") or [] + _expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}") + krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore") + prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower() + for cue in variant.get("prompt_cues") or []: + cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower() + _expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}") + _expect( + "framed as " not in prompt and "the image is framed as " not in prompt, + f"{key} final Krea prompt kept generic composition text after atlas route: {prompt}", + ) + atlas_geometry = " ".join([str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]).lower() + if any(term in atlas_geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")): + _expect("eye-level shot" not in prompt, f"{key} final Krea prompt kept contradictory eye-level camera text: {prompt}") + for avoid in variant.get("avoid_cues") or []: + avoid_text = _expect_text(f"{key}.avoid_cue", avoid, 4).lower() + _expect(avoid_text not in prompt, f"{key} final Krea prompt leaked avoid cue {avoid_text!r}: {prompt}") + + def smoke_krea2_pose_variant_catalog_policy() -> None: catalog = krea2_pose_variant_catalog.load_catalog() _expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog") @@ -8390,10 +8464,10 @@ def smoke_pov_oral_position_routes() -> None: ).lower() for term in ( "pov upright sitting oral position", - "woman sits low between his open thighs", - "face lowered close to the exact center contact point", - "open mouth covers the tip", - "hands stay low at the base", + "woman sits low between the viewer's open thighs", + "face lowers close to the exact center contact point", + "open mouth covers the centered tip", + "hands wrapped low at the base", ): _expect(term in sitting_variant_prompt, f"Sitting oral variant prompt missing {term!r}: {sitting_variant_prompt}") _expect( @@ -12058,6 +12132,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("insta_pair_camera_split", smoke_insta_pair_camera_split), ("pov_camera_scene", smoke_pov_camera_scene), ("krea2_pov_pose_variant_catalog", smoke_krea2_pov_pose_variant_catalog), + ("krea2_pov_atlas_variant_prompt_routes", smoke_krea2_pov_atlas_variant_prompt_routes), ("krea2_pose_variant_catalog_policy", smoke_krea2_pose_variant_catalog_policy), ("krea2_eval_log_policy", smoke_krea2_eval_log_policy), ("krea2_prompt_guide_policy", smoke_krea2_prompt_guide_policy),