Use atlas cues for exact Krea2 POV routes
This commit is contained in:
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import krea2_pose_variant_catalog
|
||||
except ImportError: # Allows local smoke tests with top-level imports.
|
||||
import krea2_pose_variant_catalog
|
||||
|
||||
|
||||
MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool")
|
||||
TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft"
|
||||
@@ -110,6 +115,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool:
|
||||
return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row))
|
||||
|
||||
|
||||
def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool:
|
||||
return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row))
|
||||
|
||||
|
||||
def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool:
|
||||
for key in _krea2_variant_keys(row):
|
||||
variant = krea2_pose_variant_catalog.get_variant(key)
|
||||
geometry = " ".join(
|
||||
[str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]
|
||||
).lower()
|
||||
if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any:
|
||||
if not _has_krea2_oral_contact_variant(row):
|
||||
return expression
|
||||
@@ -126,7 +146,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -
|
||||
|
||||
def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str:
|
||||
text = str(camera_scene or "")
|
||||
if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower():
|
||||
if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower():
|
||||
return ""
|
||||
return text
|
||||
|
||||
@@ -193,7 +213,7 @@ def format_configured_cast_result(
|
||||
action,
|
||||
)
|
||||
camera_scene = _filter_camera_scene_for_krea2_variant(row, request.camera_scene)
|
||||
output_composition = deps.pov_composition_text(composition, pov_labels)
|
||||
output_composition = "" if _has_krea2_atlas_variant(row) else deps.pov_composition_text(composition, pov_labels)
|
||||
parts = [
|
||||
action,
|
||||
scene_anchor,
|
||||
|
||||
+22
-2
@@ -3,6 +3,11 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
try:
|
||||
from . import krea2_pose_variant_catalog
|
||||
except ImportError: # Allows local smoke tests with top-level imports.
|
||||
import krea2_pose_variant_catalog
|
||||
|
||||
|
||||
MOUTH_EXPRESSION_TERMS = ("mouth", "oral", "tongue", "lips", "gagging", "saliva", "drool")
|
||||
TOP_VIEW_ORAL_VARIANT = "pov_blowjob_top_down_vertical_shaft"
|
||||
@@ -88,6 +93,21 @@ def _has_krea2_oral_contact_variant(row: dict[str, Any]) -> bool:
|
||||
return any(key in ORAL_CONTACT_VARIANTS for key in _krea2_variant_keys(row))
|
||||
|
||||
|
||||
def _has_krea2_atlas_variant(row: dict[str, Any]) -> bool:
|
||||
return any(krea2_pose_variant_catalog.get_variant(key) for key in _krea2_variant_keys(row))
|
||||
|
||||
|
||||
def _has_krea2_top_down_variant(row: dict[str, Any]) -> bool:
|
||||
for key in _krea2_variant_keys(row):
|
||||
variant = krea2_pose_variant_catalog.get_variant(key)
|
||||
geometry = " ".join(
|
||||
[str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]
|
||||
).lower()
|
||||
if any(term in geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -> Any:
|
||||
if not _has_krea2_oral_contact_variant(row):
|
||||
return expression
|
||||
@@ -104,7 +124,7 @@ def _filter_expression_for_krea2_variant(row: dict[str, Any], expression: Any) -
|
||||
|
||||
def _filter_camera_scene_for_krea2_variant(row: dict[str, Any], camera_scene: Any) -> str:
|
||||
text = str(camera_scene or "")
|
||||
if _has_krea2_oral_contact_variant(row) and "eye-level" in text.lower():
|
||||
if (_has_krea2_oral_contact_variant(row) or _has_krea2_top_down_variant(row)) and "eye-level" in text.lower():
|
||||
return ""
|
||||
return text
|
||||
|
||||
@@ -190,7 +210,7 @@ def format_insta_pair_result(request: KreaPairFormatRequest, deps: KreaPairForma
|
||||
hard_axis_values,
|
||||
hard_detail_density,
|
||||
)
|
||||
hard_output_composition = deps.pov_composition_text(hard_composition, pov_labels)
|
||||
hard_output_composition = "" if _has_krea2_atlas_variant(hard) else deps.pov_composition_text(hard_composition, pov_labels)
|
||||
same_soft_cast = options.get("softcore_cast") == "same_as_hardcore"
|
||||
soft_output_composition = deps.pov_composition_text(soft.get("composition"), pov_labels if same_soft_cast else [])
|
||||
soft_cast_presence = deps.softcore_cast_presence_phrase(
|
||||
|
||||
@@ -4,6 +4,7 @@ import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from . import krea2_pose_variant_catalog
|
||||
from . import outercourse_action_policy as outercourse_policy
|
||||
from .krea_action_context import (
|
||||
axis_values_text,
|
||||
@@ -15,6 +16,7 @@ try:
|
||||
)
|
||||
from .krea_detail import limit_detail_for_density
|
||||
except ImportError: # Allows local smoke tests with `python -c`.
|
||||
import krea2_pose_variant_catalog
|
||||
import outercourse_action_policy as outercourse_policy
|
||||
from krea_action_context import (
|
||||
axis_values_text,
|
||||
@@ -52,6 +54,65 @@ def _has_krea2_variant(axis_values: Any, key: str) -> bool:
|
||||
return key in _list_values(axis_values.get("krea2_variant_keys"))
|
||||
|
||||
|
||||
def _metadata_text(axis_values: Any, key: str) -> str:
|
||||
if not isinstance(axis_values, dict):
|
||||
return ""
|
||||
return _clean(axis_values.get(key, "")).lower()
|
||||
|
||||
|
||||
def _metadata_values(axis_values: Any, key: str) -> set[str]:
|
||||
if not isinstance(axis_values, dict):
|
||||
return set()
|
||||
return {value for value in _list_values(axis_values.get(key)) if value}
|
||||
|
||||
|
||||
def _variant_matches_route(variant: dict[str, Any], axis_values: Any) -> bool:
|
||||
action_family = _metadata_text(axis_values, "action_family")
|
||||
if action_family and _clean(variant.get("action_family", "")).lower() != action_family:
|
||||
return False
|
||||
route_positions = _metadata_values(axis_values, "position_keys")
|
||||
route_position = _metadata_text(axis_values, "position_key")
|
||||
if route_position:
|
||||
route_positions.add(route_position)
|
||||
variant_positions = {str(position) for position in variant.get("position_keys", []) if str(position).strip()}
|
||||
return not route_positions or not variant_positions or bool(route_positions & variant_positions)
|
||||
|
||||
|
||||
def _selected_krea2_atlas_variant(axis_values: Any) -> dict[str, Any]:
|
||||
if not isinstance(axis_values, dict):
|
||||
return {}
|
||||
keys = _list_values(axis_values.get("krea2_variant_keys"))
|
||||
variants = [krea2_pose_variant_catalog.get_variant(key) for key in keys]
|
||||
variants = [variant for variant in variants if variant.get("key")]
|
||||
if not variants:
|
||||
return {}
|
||||
if len(variants) == 1:
|
||||
return variants[0]
|
||||
matching = [variant for variant in variants if _variant_matches_route(variant, axis_values)]
|
||||
return matching[0] if len(matching) == 1 else {}
|
||||
|
||||
|
||||
def _unique_texts(values: list[Any]) -> list[str]:
|
||||
texts: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for value in values:
|
||||
text = _clean(value).strip(" .;")
|
||||
lower = text.lower()
|
||||
if not text or lower in seen:
|
||||
continue
|
||||
texts.append(text)
|
||||
seen.add(lower)
|
||||
return texts
|
||||
|
||||
|
||||
def _krea2_atlas_variant_sentence(axis_values: Any) -> str:
|
||||
variant = _selected_krea2_atlas_variant(axis_values)
|
||||
if not variant:
|
||||
return ""
|
||||
cues = _unique_texts([variant.get("canonical_geometry"), *(variant.get("prompt_cues") or [])])
|
||||
return _clean("; ".join(cues)).rstrip(".")
|
||||
|
||||
|
||||
def pov_ejaculation_target(context: str) -> str:
|
||||
if any(
|
||||
token in context
|
||||
@@ -268,6 +329,10 @@ def pov_hardcore_pose_sentence(
|
||||
details = pov_clean_oral_detail(action_text.split(";", 1)[1], f"{context} {base}", detail_density)
|
||||
return _clean(f"{base}; {details}" if details else base).rstrip(".")
|
||||
|
||||
atlas_sentence = _krea2_atlas_variant_sentence(axis_values)
|
||||
if atlas_sentence:
|
||||
return sentence(atlas_sentence)
|
||||
|
||||
def oral_direction() -> tuple[bool, bool]:
|
||||
oral_context = f"{context} {action_lower}"
|
||||
woman_gives = any(
|
||||
|
||||
@@ -717,6 +717,26 @@ def _axis_values_with_krea2_variant_keys(
|
||||
return merged
|
||||
|
||||
|
||||
def _axis_values_with_hardcore_route_metadata(
|
||||
axis_values: dict[str, Any],
|
||||
*,
|
||||
action_family: str,
|
||||
position_family: str,
|
||||
position_key: str,
|
||||
position_keys: list[str],
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(axis_values)
|
||||
if action_family:
|
||||
merged["action_family"] = action_family
|
||||
if position_family:
|
||||
merged["position_family"] = position_family
|
||||
if position_key:
|
||||
merged["position_key"] = position_key
|
||||
if position_keys:
|
||||
merged["position_keys"] = [str(key) for key in position_keys if str(key).strip()]
|
||||
return merged
|
||||
|
||||
|
||||
def build_hardcore_position_pool_json(
|
||||
hardcore_position_config: str | dict[str, Any] | None = "",
|
||||
combine_mode: str = "replace",
|
||||
@@ -2473,6 +2493,13 @@ def _build_custom_row(
|
||||
position_keys = list(action_route.position_keys)
|
||||
position_key = action_route.position_key
|
||||
action_family = action_route.action_family
|
||||
item_axis_values = _axis_values_with_hardcore_route_metadata(
|
||||
item_axis_values,
|
||||
action_family=action_family,
|
||||
position_family=position_family,
|
||||
position_key=position_key,
|
||||
position_keys=position_keys,
|
||||
)
|
||||
|
||||
text_fields = _row_text_fields(category, subcategory, item, style_config)
|
||||
assembly_request = row_assembly_policy.CustomRowAssemblyRequest(
|
||||
|
||||
+79
-4
@@ -6781,6 +6781,80 @@ def smoke_krea2_pov_pose_variant_catalog() -> None:
|
||||
_expect(required_keys.issubset(seen_keys), "Krea2 POV pose-variant catalog lost a proven starter variant")
|
||||
|
||||
|
||||
def _atlas_variant_include_key(variant_key: str) -> str:
|
||||
key = "".join(char if char.isalnum() else "_" for char in str(variant_key).lower().removeprefix("pov_")).strip("_")
|
||||
while "__" in key:
|
||||
key = key.replace("__", "_")
|
||||
return f"include_{key}"
|
||||
|
||||
|
||||
def smoke_krea2_pov_atlas_variant_prompt_routes() -> None:
|
||||
filter_by_action_family = {
|
||||
"penetration": "SxCPKrea2POVPenetrationFilter",
|
||||
"oral": "SxCPKrea2POVOralFilter",
|
||||
"outercourse": "SxCPKrea2POVOutercourseFilter",
|
||||
"manual": "SxCPKrea2POVManualFilter",
|
||||
"toy": "SxCPKrea2POVToyFilter",
|
||||
"climax": "SxCPKrea2POVClimaxFilter",
|
||||
"interaction": "SxCPKrea2POVInteractionFilter",
|
||||
}
|
||||
variants = krea2_pose_variant_catalog.variants()
|
||||
_expect(variants, "Krea2 POV atlas prompt route smoke found no variants")
|
||||
for offset, variant in enumerate(variants, start=4510):
|
||||
key = _expect_text("krea2_pov_atlas_variant_prompt_routes.key", variant.get("key"), 8)
|
||||
action_family = _expect_text(f"{key}.action_family", variant.get("action_family"), 3)
|
||||
node_name = filter_by_action_family.get(action_family)
|
||||
_expect(node_name in sxcp_nodes.NODE_CLASS_MAPPINGS, f"{key} has no Krea2 POV filter node for action family {action_family!r}")
|
||||
include_key = _atlas_variant_include_key(key)
|
||||
node_cls = sxcp_nodes.NODE_CLASS_MAPPINGS[node_name]
|
||||
_expect(include_key in (node_cls.INPUT_TYPES().get("required") or {}), f"{node_name} does not expose {include_key}")
|
||||
variant_config = node_cls().build("replace", "", **{include_key: True})[0]
|
||||
pair = pb.build_insta_of_pair(
|
||||
row_number=1,
|
||||
start_index=1,
|
||||
seed=offset,
|
||||
ethnicity="any",
|
||||
figure="random",
|
||||
no_plus_women=False,
|
||||
no_black=False,
|
||||
trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
options_json=_insta_options(
|
||||
softcore_camera_mode="from_camera_config",
|
||||
hardcore_camera_mode="from_camera_config",
|
||||
camera_detail="compact",
|
||||
),
|
||||
character_cast=_character_cast(pov_man=True),
|
||||
hardcore_position_config=variant_config,
|
||||
location_config=_coworking_location_config(),
|
||||
hardcore_camera_config=_orbit_camera(
|
||||
horizontal_angle=45,
|
||||
vertical_angle=0,
|
||||
zoom=7.5,
|
||||
subject_focus="action",
|
||||
),
|
||||
)
|
||||
_expect_pair(pair, f"krea2_pov_atlas_variant_prompt_routes.{key}")
|
||||
hard_row = pair.get("hardcore_row") or {}
|
||||
variant_keys = (hard_row.get("hardcore_position_config") or {}).get("krea2_variant_keys") or []
|
||||
_expect(variant_keys == [key], f"{key} row lost exact variant metadata: {variant_keys}")
|
||||
krea = krea_formatter.format_krea2_prompt("", metadata_json=_json(pair), target="hardcore")
|
||||
prompt = _expect_text(f"{key}.krea_prompt", krea.get("krea_prompt"), 80).lower()
|
||||
for cue in variant.get("prompt_cues") or []:
|
||||
cue_text = _expect_text(f"{key}.prompt_cue", cue, 8).lower()
|
||||
_expect(cue_text in prompt, f"{key} final Krea prompt lost atlas cue {cue_text!r}: {prompt}")
|
||||
_expect(
|
||||
"framed as " not in prompt and "the image is framed as " not in prompt,
|
||||
f"{key} final Krea prompt kept generic composition text after atlas route: {prompt}",
|
||||
)
|
||||
atlas_geometry = " ".join([str(variant.get("canonical_geometry") or ""), *[str(cue) for cue in variant.get("prompt_cues") or []]]).lower()
|
||||
if any(term in atlas_geometry for term in ("top-down", "top view", "top-view", "nadir", "overhead")):
|
||||
_expect("eye-level shot" not in prompt, f"{key} final Krea prompt kept contradictory eye-level camera text: {prompt}")
|
||||
for avoid in variant.get("avoid_cues") or []:
|
||||
avoid_text = _expect_text(f"{key}.avoid_cue", avoid, 4).lower()
|
||||
_expect(avoid_text not in prompt, f"{key} final Krea prompt leaked avoid cue {avoid_text!r}: {prompt}")
|
||||
|
||||
|
||||
def smoke_krea2_pose_variant_catalog_policy() -> None:
|
||||
catalog = krea2_pose_variant_catalog.load_catalog()
|
||||
_expect(catalog.get("version") == 1, "Krea2 pose-variant loader returned wrong catalog")
|
||||
@@ -8390,10 +8464,10 @@ def smoke_pov_oral_position_routes() -> None:
|
||||
).lower()
|
||||
for term in (
|
||||
"pov upright sitting oral position",
|
||||
"woman sits low between his open thighs",
|
||||
"face lowered close to the exact center contact point",
|
||||
"open mouth covers the tip",
|
||||
"hands stay low at the base",
|
||||
"woman sits low between the viewer's open thighs",
|
||||
"face lowers close to the exact center contact point",
|
||||
"open mouth covers the centered tip",
|
||||
"hands wrapped low at the base",
|
||||
):
|
||||
_expect(term in sitting_variant_prompt, f"Sitting oral variant prompt missing {term!r}: {sitting_variant_prompt}")
|
||||
_expect(
|
||||
@@ -12058,6 +12132,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("insta_pair_camera_split", smoke_insta_pair_camera_split),
|
||||
("pov_camera_scene", smoke_pov_camera_scene),
|
||||
("krea2_pov_pose_variant_catalog", smoke_krea2_pov_pose_variant_catalog),
|
||||
("krea2_pov_atlas_variant_prompt_routes", smoke_krea2_pov_atlas_variant_prompt_routes),
|
||||
("krea2_pose_variant_catalog_policy", smoke_krea2_pose_variant_catalog_policy),
|
||||
("krea2_eval_log_policy", smoke_krea2_eval_log_policy),
|
||||
("krea2_prompt_guide_policy", smoke_krea2_prompt_guide_policy),
|
||||
|
||||
Reference in New Issue
Block a user