120 lines
3.4 KiB
Python
120 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
|
|
# Strings that value_texts() treats as "no value": a cleaned string whose
# lower-cased form is in this set produces no text.
PLACEHOLDER_VALUES = {
    "",
    "any",
    "auto",
    "random",
    "none",
    "null",
}

# Keys that value_texts() tries first, in this order, when handed a dict; the
# first one that yields any text wins.
PREFERRED_VALUE_KEYS = (
    "text",
    "prompt",
    "template",
    "value",
    "name",
)

# NOTE(review): not referenced in the visible part of this module; presumably
# passed by callers as ``skip_keys`` so these axes emit no text - confirm.
METADATA_AXIS_KEYS = {
    "action_family",
    "position_family",
    "position_key",
    "position_keys",
}
|
|
# Axis keys that action_context_text() reads, in the order their texts are
# emitted. Keys absent from this tuple are ignored by that function.
ACTION_CONTEXT_PRIORITY = (
    "position", "body_position", "body_arrangement", "arrangement",
    "angle", "surface", "body_contact", "leg_detail",
    "outer_act", "contact_detail", "texture_detail", "hand_detail",
    "visibility", "expression_detail", "oral_act", "oral_detail",
    "penetration_act", "penetration_detail", "anal_act", "double_act",
    "threesome_act", "group_act",
)
|
|
|
|
|
|
def clean_text(value: Any) -> str:
    """Return *value* as single-line text with normalised whitespace.

    ``None`` becomes an empty string; any other value is converted with
    ``str``. Runs of whitespace (newlines included) collapse to one space, the
    ends are trimmed, and whitespace directly before ``,``, ``.``, ``;`` or
    ``:`` is removed.
    """
    if value is None:
        return ""
    # A newline is ordinary whitespace to the regex, so a single substitution
    # both flattens the text onto one line and collapses repeated spaces.
    collapsed = re.sub(r"\s+", " ", str(value)).strip()
    # Pull punctuation back against the preceding word: "a , b" -> "a, b".
    return re.sub(r"\s+([,.;:])", r"\1", collapsed)
|
|
|
|
|
|
def value_texts(value: Any) -> list[str]:
    """Flatten *value* into a list of usable text fragments.

    * ``str``: cleaned with ``clean_text`` and stripped of surrounding spaces
      and periods; dropped when empty or when its lower-cased form is in
      ``PLACEHOLDER_VALUES``.
    * ``list``: the fragments of every item, in order.
    * ``dict``: the fragments of the first key in ``PREFERRED_VALUE_KEYS`` that
      produces any; otherwise the fragments of all values, in order.
    * anything else (numbers, booleans, ``None``, other types): no fragments.
    """
    if isinstance(value, str):
        candidate = clean_text(value).strip(" .")
        if not candidate or candidate.lower() in PLACEHOLDER_VALUES:
            return []
        return [candidate]

    if isinstance(value, dict):
        # A preferred key short-circuits the search as soon as it yields text.
        for name in PREFERRED_VALUE_KEYS:
            found = value_texts(value.get(name))
            if found:
                return found
        children = value.values()
    elif isinstance(value, list):
        children = value
    else:
        # Scalars, None and unsupported containers contribute nothing.
        return []

    return [fragment for child in children for fragment in value_texts(child)]
|
|
|
|
|
|
def axis_value_texts(
    axis_values: Any,
    *,
    priority: tuple[str, ...] = (),
    include_unprioritized: bool = True,
    skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (),
    existing_text: Any = "",
) -> list[str]:
    """Collect de-duplicated texts from a mapping of axis key to value.

    Args:
        axis_values: Mapping of axis key to value. Anything that is not a
            ``dict`` yields an empty list.
        priority: Keys whose texts come first, in this order, when present in
            ``axis_values``.
        include_unprioritized: When true, the remaining keys follow in the
            mapping's own iteration order; when false they are ignored.
        skip_keys: Keys to leave out. Both these and the mapping's keys are
            compared by their ``str`` form. A bare string is treated as one
            key rather than as a sequence of characters.
        existing_text: Text already in use. A candidate whose lower-cased form
            occurs anywhere inside the cleaned, lower-cased ``existing_text``
            (substring match) is dropped.

    Returns:
        Cleaned texts in key order, with case-insensitive duplicates removed.
    """
    if not isinstance(axis_values, dict):
        return []

    if isinstance(skip_keys, str):
        # Iterating a str yields its characters; wrap it to keep the key whole.
        skip_keys = (skip_keys,)
    skipped = {str(key) for key in skip_keys}

    # A dict keeps insertion order, so it serves as an ordered set of keys:
    # prioritised keys first, then (optionally) the rest in mapping order.
    ordered: dict[Any, None] = {}
    for key in priority:
        if key in axis_values and str(key) not in skipped:
            ordered[key] = None
    if include_unprioritized:
        for key in axis_values:
            if str(key) not in skipped:
                ordered.setdefault(key, None)

    existing = clean_text(existing_text).lower()
    texts: list[str] = []
    seen: set[str] = set()
    for key in ordered:
        for text in value_texts(axis_values.get(key)):
            normalized = clean_text(text).strip(" .")
            lower = normalized.lower()
            # Drop empties, repeats (ignoring case), and anything the existing
            # text already contains.
            if not normalized or lower in seen or (existing and lower in existing):
                continue
            texts.append(normalized)
            seen.add(lower)
    return texts
|
|
|
|
|
|
def action_context_text(axis_values: Any) -> str:
    """Join the prioritised action-context texts of *axis_values* with spaces.

    Only keys listed in ``ACTION_CONTEXT_PRIORITY`` contribute, in that order.
    The result is an empty string when no text qualifies.
    """
    fragments = axis_value_texts(
        axis_values,
        priority=ACTION_CONTEXT_PRIORITY,
        include_unprioritized=False,
    )
    return " ".join(fragments)
|
|
|
|
|
|
def row_axis_value_texts(
    row: dict[str, Any],
    *,
    skip_keys: set[str] | frozenset[str] | tuple[str, ...] = (),
    existing_text: Any = "",
) -> list[str]:
    """Return the axis texts stored under ``row["item_axis_values"]``.

    ``skip_keys`` and ``existing_text`` are forwarded unchanged to
    ``axis_value_texts``. A *row* that is not a ``dict`` yields an empty list.
    """
    if isinstance(row, dict):
        item_axes = row.get("item_axis_values")
        return axis_value_texts(
            item_axes,
            skip_keys=skip_keys,
            existing_text=existing_text,
        )
    return []
|