Use atlas cues for exact Krea2 POV routes

This commit is contained in:
2026-06-30 21:11:26 +02:00
parent 3832044256
commit 4689cc7942
5 changed files with 215 additions and 8 deletions
+22 -2
View File
@@ -3,6 +3,11 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
try:
from . import krea2_pose_variant_catalog
except ImportError: # Allows local smoke tests with top-level imports.
import krea2_pose_variant_catalog
MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool")
TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft"
@@ -88,6 +93,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool:
return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row))
def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool:
return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row))
def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool:
for key in _krea2_variant_keys(row):
variant = krea2_pose_variant_catalog.get_variant(key)
geometry = " ".join(
[str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]
).lower()
if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")):
return True
return False
def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any:
if not _has_krea2_oral_contact_variant(row):
return expression
@@ -104,7 +124,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -
def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str:
text = str(camera_scene or "")
if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower():
if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower():
return ""
return text
@@ -190,7 +210,7 @@ def format_insta_pair_result(request: KreaPairFormatRequest, deps: KreaPairForma
hard_axis_values,
hard_detail_density,
)
hard_output_composition = deps.pov_composition_text(hard_composition, pov_labels)
hard_output_composition = "" if _has_krea2_atlas_variant(hard) else deps.pov_composition_text(hard_composition, pov_labels)
same_soft_cast = options.get("softcore_cast") == "same_as_hardcore"
soft_output_composition = deps.pov_composition_text(soft.get("composition"), pov_labels if same_soft_cast else [])
soft_cast_presence = deps.softcore_cast_presence_phrase(