From 867916ee510c29636cd7e0f6e89b9b8041f4a368 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 18:12:34 +0200 Subject: [PATCH] Centralize item axis value flattening --- caption_text_policy.py | 49 ++----------- docs/prompt-pool-routing-map.md | 3 +- item_axis_policy.py | 119 ++++++++++++++++++++++++++++++++ krea_action_context.py | 34 ++------- sdxl_tag_policy.py | 35 ++-------- tools/prompt_smoke.py | 32 +++++++++ 6 files changed, 170 insertions(+), 102 deletions(-) create mode 100644 item_axis_policy.py diff --git a/caption_text_policy.py b/caption_text_policy.py index 60dfd66..6f40c3b 100644 --- a/caption_text_policy.py +++ b/caption_text_policy.py @@ -7,6 +7,7 @@ try: from . import caption_metadata_routes from . import caption_policy from . import formatter_input as input_policy + from . import item_axis_policy from . import krea_cast as cast_policy from . import route_metadata as route_metadata_policy from . import softcore_text_policy @@ -14,6 +15,7 @@ except ImportError: # Allows local smoke tests with `python -c`. import caption_metadata_routes import caption_policy import formatter_input as input_policy + import item_axis_policy import krea_cast as cast_policy import route_metadata as route_metadata_policy import softcore_text_policy @@ -97,49 +99,12 @@ def metadata_action_label(row: dict[str, Any], default: str = "sexual pose") -> return caption_policy.metadata_action_label(row, default) -def _axis_value_texts(value: Any) -> list[str]: - if isinstance(value, str): - text = clean_text(value).strip(" .") - return [text] if text and text.lower() not in ("any", "auto", "random", "none") else [] - if isinstance(value, (int, float, bool)) or value is None: - return [] - if isinstance(value, list): - texts: list[str] = [] - for item in value: - texts.extend(_axis_value_texts(item)) - return texts - if isinstance(value, dict): - for preferred in ("text", "prompt", "template", "value", "name"): - preferred_texts = _axis_value_texts(value.get(preferred)) - if preferred_texts: - return preferred_texts - texts: list[str] = [] - for item in value.values(): - texts.extend(_axis_value_texts(item)) - return texts - return [] - - def item_axis_detail_text(row: dict[str, Any], existing_text: str = "") -> str: - if not isinstance(row, dict): - return "" - axis_values = row.get("item_axis_values") - if not isinstance(axis_values, dict): - return "" - existing = clean_text(existing_text).lower() - details: list[str] = [] - seen: set[str] = set() - skipped_keys = {"action_family", "position_family", "position_key", "position_keys"} - for key, value in axis_values.items(): - if str(key) in skipped_keys: - continue - for text in _axis_value_texts(value): - normalized = clean_text(text).strip(" .") - lower = normalized.lower() - if not normalized or lower in seen or lower in existing: - continue - details.append(normalized) - seen.add(lower) + details = item_axis_policy.row_axis_value_texts( + row, + skip_keys=item_axis_policy.METADATA_AXIS_KEYS, + existing_text=existing_text, + ) return human_join(details) diff --git a/docs/prompt-pool-routing-map.md b/docs/prompt-pool-routing-map.md index ce62066..f3f2359 100644 --- a/docs/prompt-pool-routing-map.md +++ b/docs/prompt-pool-routing-map.md @@ -144,6 +144,7 @@ Core helper ownership: | `caption_format_route.py` | Top-level caption dispatch, input-hint and target normalization, caption profile application, metadata-vs-text branching, trigger wrapping, final prose hygiene, and method/output shape. | | `caption_policy.py` | Caption naturalizer policy data and helpers: caption profiles, style tails, item labels, metadata-family caption labels, detail/style-policy normalization, clothing cleanup, and composition cleanup. | | `caption_text_policy.py` | Caption sentence helpers, trigger wrapping, formatter-hint append, item-axis detail prose, row-value fallback wrappers, cast text wrappers, single-caption front parsing, and metadata-route dependency assembly used by `caption_naturalizer.py` and `caption_metadata_routes.py`. | +| `item_axis_policy.py` | Shared `item_axis_values` flattening, placeholder filtering, preferred dict-value extraction, priority-ordered Krea action context text, and row-axis text extraction used by Krea2, SDXL, and caption routes. | ## Node IO Map @@ -542,7 +543,7 @@ plain prompt text. When debugging, inspect these fields before editing pools. | `category_slug`, `subcategory_slug` | `row_category_route.select_category_item_route` | Debug/filtering | Stable-ish machine labels for selected category route. | | `content_seed_axis` | `row_category_route.select_category_item_route` | Debug | Shows whether the item/action was driven by `content` or `pose`. Critical for hardcore pose categories. | | `item` | `row_category_route.select_category_item_route` or Insta override | Krea/SDXL/Naturalizer | Clothing item, category item, or sexual scene/action text. | -| `item_axis_values` | `row_category_route.select_category_item_route` | Krea hardcore rewrite, SDXL tags | Filled template axes such as position/action/detail values. | +| `item_axis_values` | `row_category_route.select_category_item_route` | Krea hardcore rewrite, SDXL tags, natural captions | Filled template axes such as position/action/detail values. Shared flattening lives in `item_axis_policy.py`. | | `item_template_metadata` | `row_category_route.select_category_item_route` | Debug, Krea/SDXL/Naturalizer route metadata | Metadata inherited from category/subcategory/item `item_template_metadata` plus selected object-template metadata; used to prefer explicit action/position families and keys before inference. | | `formatter_hints` | `row_category_route.select_category_item_route` | Krea/SDXL/Naturalizer route specialization, debug | Normalized route-specific hints inherited from template metadata, keyed by `all`, `krea`, `sdxl`, or `caption`; each formatter consumes `all` plus its own route only. | | `action_family` | `row_route_metadata.resolve_action_position_route` | Krea hardcore rewrite, SDXL tags, natural captions, debug | Source-aware formatter semantic family such as `foreplay`, `outercourse`, `oral`, `penetration`, `toy_double`, or `climax`. | diff --git a/item_axis_policy.py b/item_axis_policy.py new file mode 100644 index 0000000..5afe38b --- /dev/null +++ b/item_axis_policy.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import re +from typing import Any + + +PLACEHOLDER_VALUES = {"", "any", "auto", "random", "none", "null"} +PREFERRED_VALUE_KEYS = ("text", "prompt", "template", "value", "name") +METADATA_AXIS_KEYS = {"action_family", "position_family", "position_key", "position_keys"} +ACTION_CONTEXT_PRIORITY = ( + "position", + "body_position", + "body_arrangement", + "arrangement", + "angle", + "surface", + "body_contact", + "leg_detail", + "outer_act", + "contact_detail", + "texture_detail", + "hand_detail", + "visibility", + "expression_detail", + "oral_act", + "oral_detail", + "penetration_act", + "penetration_detail", + "anal_act", + "double_act", + "threesome_act", + "group_act", +) + + +def clean_text(value: Any) -> str: + text = "" if value is None else str(value) + text = text.replace("\n", " ") + text = re.sub(r"\s+", " ", text).strip() + text = re.sub(r"\s+([,.;:])", r"\1", text) + return text + + +def value_texts(value: Any) -> list[str]: + if isinstance(value, str): + text = clean_text(value).strip(" .") + return [text] if text and text.lower() not in PLACEHOLDER_VALUES else [] + if isinstance(value, (int, float, bool)) or value is None: + return [] + if isinstance(value, list): + texts: list[str] = [] + for item in value: + texts.extend(value_texts(item)) + return texts + if isinstance(value, dict): + for preferred in PREFERRED_VALUE_KEYS: + preferred_texts = value_texts(value.get(preferred)) + if preferred_texts: + return preferred_texts + texts: list[str] = [] + for item in value.values(): + texts.extend(value_texts(item)) + return texts + return [] + + +def axis_value_texts( + axis_values: Any, + *, + priority: tuple[str, ...] = (), + include_unprioritized: bool = True, + skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (), + existing_text: Any = "", +) -> list[str]: + if not isinstance(axis_values, dict): + return [] + skipped = {str(key) for key in skip_keys} + keys: list[str] = [] + for key in priority: + if key in axis_values and key not in skipped and key not in keys: + keys.append(key) + if include_unprioritized: + for key in axis_values: + if key not in skipped and key not in keys: + keys.append(key) + + existing = clean_text(existing_text).lower() + texts: list[str] = [] + seen: set[str] = set() + for key in keys: + for text in value_texts(axis_values.get(key)): + normalized = clean_text(text).strip(" .") + lower = normalized.lower() + if not normalized or lower in seen or (existing and lower in existing): + continue + texts.append(normalized) + seen.add(lower) + return texts + + +def action_context_text(axis_values: Any) -> str: + return " ".join( + axis_value_texts( + axis_values, + priority=ACTION_CONTEXT_PRIORITY, + include_unprioritized=False, + ) + ) + + +def row_axis_value_texts( + row: dict[str, Any], + *, + skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (), + existing_text: Any = "", +) -> list[str]: + if not isinstance(row, dict): + return [] + return axis_value_texts(row.get("item_axis_values"), skip_keys=skip_keys, existing_text=existing_text) diff --git a/krea_action_context.py b/krea_action_context.py index c9d8e9f..ddb2eb1 100644 --- a/krea_action_context.py +++ b/krea_action_context.py @@ -3,6 +3,11 @@ from __future__ import annotations import re from typing import Any +try: + from . import item_axis_policy +except ImportError: # Allows local smoke tests with top-level imports. + import item_axis_policy + HARDCORE_DETAIL_DENSITY_CHOICES = {"compact", "balanced", "dense"} @@ -21,34 +26,7 @@ def normalize_hardcore_detail_density(value: Any) -> str: def axis_values_text(axis_values: Any) -> str: - if not isinstance(axis_values, dict): - return "" - priority = ( - "position", - "body_position", - "body_arrangement", - "arrangement", - "angle", - "surface", - "body_contact", - "leg_detail", - "outer_act", - "contact_detail", - "texture_detail", - "hand_detail", - "visibility", - "expression_detail", - "oral_act", - "oral_detail", - "penetration_act", - "penetration_detail", - "anal_act", - "double_act", - "threesome_act", - "group_act", - ) - parts = [_clean(axis_values.get(key)) for key in priority if _clean(axis_values.get(key))] - return " ".join(parts) + return item_axis_policy.action_context_text(axis_values) def position_context_text(role_graph: str, hard_item: str, composition: str = "", axis_values: Any = None) -> str: diff --git a/sdxl_tag_policy.py b/sdxl_tag_policy.py index a9b890e..8ef733d 100644 --- a/sdxl_tag_policy.py +++ b/sdxl_tag_policy.py @@ -5,12 +5,14 @@ from typing import Any try: from . import formatter_input as input_policy + from . import item_axis_policy from . import route_metadata as route_metadata_policy from . import sdxl_presets as sdxl_policy from . import sdxl_tag_routes from . import softcore_text_policy except ImportError: # Allows local smoke tests with `python -c`. import formatter_input as input_policy + import item_axis_policy import route_metadata as route_metadata_policy import sdxl_presets as sdxl_policy import sdxl_tag_routes @@ -102,40 +104,11 @@ def formatter_hint_tags(*rows: dict[str, Any]) -> list[str]: return tags -def _axis_value_texts(value: Any) -> list[str]: - if isinstance(value, str): - text = clean(value) - return [text] if text and text.lower() not in ("any", "auto", "random", "none") else [] - if isinstance(value, (int, float, bool)) or value is None: - return [] - if isinstance(value, list): - texts: list[str] = [] - for item in value: - texts.extend(_axis_value_texts(item)) - return texts - if isinstance(value, dict): - for preferred in ("text", "prompt", "template", "value", "name"): - preferred_texts = _axis_value_texts(value.get(preferred)) - if preferred_texts: - return preferred_texts - texts: list[str] = [] - for item in value.values(): - texts.extend(_axis_value_texts(item)) - return texts - return [] - - def axis_value_tags(row: dict[str, Any]) -> list[str]: - if not isinstance(row, dict): - return [] - axis_values = row.get("item_axis_values") - if not isinstance(axis_values, dict): - return [] tags: list[str] = [] seen: set[str] = set() - for value in axis_values.values(): - for text in _axis_value_texts(value): - add(tags, seen, text) + for text in item_axis_policy.row_axis_value_texts(row): + add(tags, seen, text) return tags diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index b7174c6..ba8ac27 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -50,6 +50,7 @@ import __init__ as sxcp_nodes # noqa: E402 import generation_profile_config # noqa: E402 import hardcore_role_outercourse # noqa: E402 import index_switch_policy # noqa: E402 +import item_axis_policy # noqa: E402 import node_tooltips # noqa: E402 import krea_cast # noqa: E402 import krea_action_details # noqa: E402 @@ -1605,6 +1606,36 @@ def smoke_outercourse_action_policy() -> None: _expect("wet lips" in deduped, "Krea outercourse dedupe removed useful texture clause") +def smoke_item_axis_policy() -> None: + axis_values = { + "ignored": "random", + "position": "kneeling oral position", + "contact_detail": {"text": "mouth contact at hip height"}, + "nested": {"unused": "fallback body detail"}, + "list_detail": ["hands on hips", "auto"], + "unprioritized_detail": "extra unprioritized cue", + } + texts = item_axis_policy.axis_value_texts(axis_values) + _expect("kneeling oral position" in texts, "Item axis policy lost position value") + _expect("mouth contact at hip height" in texts, "Item axis policy lost preferred dict text") + _expect("fallback body detail" in texts, "Item axis policy lost nested fallback text") + _expect("hands on hips" in texts, "Item axis policy lost list text") + _expect("random" not in texts and "auto" not in texts, "Item axis policy leaked placeholder values") + _expect( + item_axis_policy.axis_value_texts(axis_values, existing_text="kneeling oral position already present")[0] + == "mouth contact at hip height", + "Item axis policy should skip details already present in existing text", + ) + context_text = item_axis_policy.action_context_text(axis_values) + _expect("kneeling oral position" in context_text, "Item axis policy context lost priority position") + _expect("mouth contact at hip height" in context_text, "Item axis policy context lost priority contact") + _expect("extra unprioritized cue" not in context_text, "Item axis policy context should ignore unprioritized values") + _expect( + krea_action_context.axis_values_text(axis_values) == context_text, + "Krea action context should delegate to shared item axis policy", + ) + + def smoke_krea_row_fields_policy() -> None: row = { "subject_type": "configured_cast", @@ -8472,6 +8503,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [ ("krea_normal_row_routes", smoke_krea_normal_row_routes), ("krea_action_details_policy", smoke_krea_action_details_policy), ("outercourse_action_policy", smoke_outercourse_action_policy), + ("item_axis_policy", smoke_item_axis_policy), ("krea_row_fields_policy", smoke_krea_row_fields_policy), ("location_config_policy", smoke_location_config_policy), ("row_location_policy", smoke_row_location_policy),