Use atlas cues for exact Krea2 POV routes

This commit is contained in:
2026-06-30 21:11:26 +02:00
parent 3832044256
commit 4689cc7942
5 changed files with 215 additions and 8 deletions
+22 -2
View File
@@ -3,6 +3,11 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
try:
from . import krea2_pose_variant_catalog
except ImportError: # Allows local smoke tests with top-level imports.
import krea2_pose_variant_catalog
MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool")
TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft"
@@ -110,6 +115,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool:
return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row))
def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool:
return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row))
def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool:
for key in _krea2_variant_keys(row):
variant = krea2_pose_variant_catalog.get_variant(key)
geometry = " ".join(
[str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]
).lower()
if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")):
return True
return False
def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any:
if not _has_krea2_oral_contact_variant(row):
return expression
@@ -126,7 +146,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -
def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str:
text = str(camera_scene or "")
if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower():
if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower():
return ""
return text
@@ -193,7 +213,7 @@ def format_configured_cast_result(
action,
)
camera_scene = _filter_camera_scene_for_krea2_variant(row, request.camera_scene)
output_composition = deps.pov_composition_text(composition, pov_labels)
output_composition = "" if _has_krea2_atlas_variant(row) else deps.pov_composition_text(composition, pov_labels)
parts = [
action,
scene_anchor,