Use atlas cues for exact Krea2 POV routes

This commit is contained in:
2026-06-30 21:11:26 +02:00
parent 3832044256
commit 4689cc7942
5 changed files with 215 additions and 8 deletions
+65
View File
@@ -4,6 +4,7 @@ import re
from typing import Any
try:
from . import krea2_pose_variant_catalog
from . import outercourse_action_policy as outercourse_policy
from .krea_action_context import (
axis_values_text,
@@ -15,6 +16,7 @@ try:
)
from .krea_detail import limit_detail_for_density
except ImportError: # Allows local smoke tests with `python -c`.
import krea2_pose_variant_catalog
import outercourse_action_policy as outercourse_policy
from krea_action_context import (
axis_values_text,
@@ -52,6 +54,65 @@ def _has_krea2_variant(axis_values: Any, key: str) -> bool:
return key in _list_values(axis_values.get("krea2_variant_keys"))
def _metadata_text(axis_values: Any, key: str) -> str:
if not isinstance(axis_values, dict):
return ""
return _clean(axis_values.get(key, "")).lower()
def _metadata_values(axis_values: Any, key: str) -> set[str]:
if not isinstance(axis_values, dict):
return set()
return {value for value in _list_values(axis_values.get(key)) if value}
def _variant_matches_route(variant: dict[str, Any], axis_values: Any) -> bool:
action_family = _metadata_text(axis_values, "action_family")
if action_family and _clean(variant.get("action_family", "")).lower() != action_family:
return False
route_positions = _metadata_values(axis_values, "position_keys")
route_position = _metadata_text(axis_values, "position_key")
if route_position:
route_positions.add(route_position)
variant_positions = {str(position) for position in variant.get("position_keys", []) if str(position).strip()}
return not route_positions or not variant_positions or bool(route_positions & variant_positions)
def _selected_krea2_atlas_variant(axis_values: Any) -> dict[str, Any]:
if not isinstance(axis_values, dict):
return {}
keys = _list_values(axis_values.get("krea2_variant_keys"))
variants = [krea2_pose_variant_catalog.get_variant(key) for key in keys]
variants = [variant for variant in variants if variant.get("key")]
if not variants:
return {}
if len(variants) == 1:
return variants[0]
matching = [variant for variant in variants if _variant_matches_route(variant, axis_values)]
return matching[0] if len(matching) == 1 else {}
def _unique_texts(values: list[Any]) -> list[str]:
texts: list[str] = []
seen: set[str] = set()
for value in values:
text = _clean(value).strip(" .;")
lower = text.lower()
if not text or lower in seen:
continue
texts.append(text)
seen.add(lower)
return texts
def _krea2_atlas_variant_sentence(axis_values: Any) -> str:
variant = _selected_krea2_atlas_variant(axis_values)
if not variant:
return ""
cues = _unique_texts([variant.get("canonical_geometry"), *(variant.get("prompt_cues") or [])])
return _clean("; ".join(cues)).rstrip(".")
def pov_ejaculation_target(context: str) -> str:
if any(
token in context
@@ -268,6 +329,10 @@ def pov_hardcore_pose_sentence(
details = pov_clean_oral_detail(action_text.split(";", 1)[1], f"{context} {base}", detail_density)
return _clean(f"{base}; {details}" if details else base).rstrip(".")
atlas_sentence = _krea2_atlas_variant_sentence(axis_values)
if atlas_sentence:
return sentence(atlas_sentence)
def oral_direction() -> tuple[bool, bool]:
oral_context = f"{context} {action_lower}"
woman_gives = any(