Files
ComfyUI-Ethanfel-Prompt-Bui…/item_axis_policy.py
T

133 lines
3.9 KiB
Python

from __future__ import annotations
import re
from typing import Any
# String axis values that carry no usable text.  They are matched against the
# value after clean_text() + strip(" .") and lower-casing.
PLACEHOLDER_VALUES = {
    "",
    "any",
    "auto",
    "random",
    "none",
    "null",
}

# For dict-shaped axis values, these keys are tried in this order and the
# first one producing usable text is used exclusively.
PREFERRED_VALUE_KEYS = (
    "text",
    "prompt",
    "template",
    "value",
    "name",
)

# Axis keys excluded from the text gathered by context_text().
METADATA_AXIS_KEYS = {
    "action_family",
    "position_family",
    "position_key",
    "position_keys",
    "krea2_variant_keys",
}

# Ordered axis keys read by action_context_text(); axis keys absent from this
# tuple contribute nothing to that text.
ACTION_CONTEXT_PRIORITY = (
    "position", "body_position", "body_arrangement", "arrangement",
    "angle", "surface", "body_contact", "leg_detail",
    "outer_act", "contact_detail", "texture_detail", "hand_detail",
    "visibility", "expression_detail", "oral_act", "oral_detail",
    "penetration_act", "penetration_detail", "anal_act", "double_act",
    "threesome_act", "group_act",
)
def clean_text(value: Any) -> str:
    """Normalise ``value`` into single-spaced text.

    ``None`` becomes ``""``; any other value is converted with ``str()``.
    Every run of whitespace (newlines and tabs included) collapses to one
    space, leading/trailing whitespace is trimmed, and whitespace directly
    before ``,``, ``.``, ``;`` or ``:`` is removed.

    Args:
        value: Anything; typically a string or ``None``.

    Returns:
        The normalised string (possibly empty).
    """
    if value is None:
        return ""
    # ``\s+`` already matches newlines, so no separate newline pass is needed.
    single_spaced = re.sub(r"\s+", " ", str(value)).strip()
    return re.sub(r"\s+([,.;:])", r"\1", single_spaced)
def value_texts(value: Any) -> list[str]:
    """Flatten an axis value into its usable text fragments.

    * ``str``: cleaned with ``clean_text`` and stripped of surrounding spaces
      and dots; dropped when empty or a member of ``PLACEHOLDER_VALUES``
      (case-insensitive).
    * ``list``: the fragments of every element, recursively, in order.
    * ``dict``: the fragments of the first key in ``PREFERRED_VALUE_KEYS``
      that yields any; otherwise the fragments of every value in order.
    * anything else (numbers, booleans, ``None``, other types): nothing.

    Returns:
        A new list of fragments in their original casing.
    """
    if isinstance(value, str):
        cleaned = clean_text(value).strip(" .")
        if not cleaned or cleaned.lower() in PLACEHOLDER_VALUES:
            return []
        return [cleaned]

    if isinstance(value, list):
        return [fragment for element in value for fragment in value_texts(element)]

    if isinstance(value, dict):
        for preferred in PREFERRED_VALUE_KEYS:
            found = value_texts(value.get(preferred))
            if found:
                return found
        return [fragment for element in value.values() for fragment in value_texts(element)]

    # Scalars, None and unsupported containers carry no prompt text.
    return []
def axis_value_texts(
    axis_values: Any,
    *,
    priority: tuple[str, ...] = (),
    include_unprioritized: bool = True,
    skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (),
    existing_text: Any = "",
) -> list[str]:
    """Collect de-duplicated text fragments from an axis-value mapping.

    Keys are visited in ``priority`` order first (keys missing from
    ``axis_values`` are ignored).  When ``include_unprioritized`` is true, the
    remaining keys follow in the mapping's own iteration order.  Each key is
    visited at most once, and keys in ``skip_keys`` are never visited.  Every
    visited value is flattened with ``value_texts``.

    A fragment is dropped when it is empty, when a case-insensitively equal
    fragment was already collected, or when it occurs as a case-insensitive
    substring of ``existing_text``.

    Args:
        axis_values: Mapping of axis key to value; non-dicts yield ``[]``.
        priority: Keys to visit first, in this order.
        include_unprioritized: Also visit keys that are not in ``priority``.
        skip_keys: Keys to exclude; both these and the mapping's keys are
            compared as strings.
        existing_text: Text already present; fragments contained in it are
            omitted.

    Returns:
        The collected fragments in visit order, original casing preserved.
    """
    if not isinstance(axis_values, dict):
        return []
    skipped = {str(key) for key in skip_keys}

    # A dict keeps insertion order and gives O(1) duplicate checks, so a key
    # listed in ``priority`` is not visited a second time below.
    ordered_keys: dict[Any, None] = {}
    for key in priority:
        if key in axis_values and str(key) not in skipped:
            ordered_keys.setdefault(key, None)
    if include_unprioritized:
        for key in axis_values:
            # Normalise with str() to match how ``skipped`` was built.
            if str(key) not in skipped:
                ordered_keys.setdefault(key, None)

    existing = clean_text(existing_text).lower()
    texts: list[str] = []
    seen: set[str] = set()
    for key in ordered_keys:
        for text in value_texts(axis_values.get(key)):
            normalized = clean_text(text).strip(" .")
            lower = normalized.lower()
            # NOTE: ``lower in existing`` is a plain substring test, so a short
            # fragment is also suppressed when it only appears inside a longer
            # word of ``existing_text``.
            if not normalized or lower in seen or (existing and lower in existing):
                continue
            texts.append(normalized)
            seen.add(lower)
    return texts
def action_context_text(axis_values: Any) -> str:
    """Join the action-related axis texts into one space-separated string.

    Only keys listed in ``ACTION_CONTEXT_PRIORITY`` are read, in that order;
    other axis keys are ignored.  Returns ``""`` when nothing qualifies.
    """
    fragments = axis_value_texts(
        axis_values,
        include_unprioritized=False,
        priority=ACTION_CONTEXT_PRIORITY,
    )
    return " ".join(fragments)
def context_text(*parts: Any, axis_values: Any = None) -> str:
    """Build one lower-cased context string from free-text parts and axis values.

    Each positional part is normalised with ``clean_text`` and dropped if
    empty.  Fragments from ``axis_values`` are appended via
    ``axis_value_texts``, skipping the keys in ``METADATA_AXIS_KEYS``.  All
    pieces are lower-cased and joined with single spaces.

    Args:
        *parts: Arbitrary values; ``None`` counts as empty.
        axis_values: Optional axis mapping; non-dicts contribute nothing.

    Returns:
        The combined lower-cased text, or ``""`` when nothing remains.
    """
    # Clean each part exactly once, then keep the non-empty results.
    cleaned_parts = (clean_text(part) for part in parts)
    text_parts = [part for part in cleaned_parts if part]
    text_parts.extend(axis_value_texts(axis_values, skip_keys=METADATA_AXIS_KEYS))
    return " ".join(part.lower() for part in text_parts if part)
def key_text(axis_values: Any, key: str) -> str:
    """Return the first usable text of ``axis_values[key]``, lower-cased.

    Returns ``""`` when ``axis_values`` is not a dict, the key is absent, or
    ``value_texts`` finds no usable text for it.
    """
    if isinstance(axis_values, dict):
        found = value_texts(axis_values.get(key))
        if found:
            return found[0].lower()
    return ""
def row_axis_value_texts(
    row: dict[str, Any],
    *,
    skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (),
    existing_text: Any = "",
) -> list[str]:
    """Collect axis texts from a row's ``"item_axis_values"`` entry.

    Thin wrapper around ``axis_value_texts`` that forwards ``skip_keys`` and
    ``existing_text`` unchanged.  A non-dict ``row`` yields ``[]``.
    """
    if isinstance(row, dict):
        axis_values = row.get("item_axis_values")
        return axis_value_texts(
            axis_values,
            skip_keys=skip_keys,
            existing_text=existing_text,
        )
    return []