190 lines
6.8 KiB
Python
190 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any
|
|
|
|
try:
|
|
from .hardcore_action_metadata import normalize_hardcore_action_family
|
|
from .hardcore_position_config import normalize_hardcore_position_family, normalize_hardcore_position_values
|
|
except ImportError: # Allows local smoke tests from the repository root.
|
|
from hardcore_action_metadata import normalize_hardcore_action_family
|
|
from hardcore_position_config import normalize_hardcore_position_family, normalize_hardcore_position_values
|
|
|
|
|
|
# Keys that template_metadata() copies out of a template item, in this order.
TEMPLATE_METADATA_KEYS = (
    # Action classification keys.
    "action_family", "action_type",
    # Position classification keys.
    "family", "position_family", "position_key", "position_keys",
    # Per-route formatter hints.
    "formatter_hint",
)

# Canonical formatter route names accepted by normalize_formatter_route().
FORMATTER_HINT_ROUTES = ("all", "krea", "sdxl", "caption")

# Alternate route spellings (already slugified) mapped to their canonical route.
FORMATTER_HINT_ROUTE_ALIASES = dict(
    krea2="krea",
    naturalizer="caption",
    training_caption="caption",
)
|
|
|
|
|
|
def template_metadata(item: Any) -> dict[str, Any]:
    """Extract the recognised metadata entries from a template item.

    Args:
        item: Any value; only ``dict`` instances can carry metadata.

    Returns:
        A new dict holding the entries of ``item`` whose key appears in
        ``TEMPLATE_METADATA_KEYS``, ordered as in that tuple. Empty when
        ``item`` is not a dict or has none of those keys.
    """
    extracted: dict[str, Any] = {}
    if isinstance(item, dict):
        for name in TEMPLATE_METADATA_KEYS:
            if name in item:
                extracted[name] = item[name]
    return extracted
|
|
|
|
|
|
def template_position_family(metadata: dict[str, Any]) -> str:
    """Return the normalized position family declared by ``metadata``.

    The ``position_family`` entry is used when truthy; otherwise the
    ``family`` entry is used. The chosen value is handed to
    ``normalize_hardcore_position_family`` together with ``""`` as its
    second argument (presumably the fallback for unrecognised values --
    confirm against hardcore_position_config).
    """
    declared = metadata.get("position_family") or metadata.get("family")
    return normalize_hardcore_position_family(declared, "")
|
|
|
|
|
|
def template_position_keys(metadata: dict[str, Any]) -> list[str]:
    """Collect and normalize the position keys declared by ``metadata``.

    Values from ``position_keys`` come first (a list contributes its items,
    any other non-None value contributes itself), followed by the single
    ``position_key`` value when it is not None. The combined raw list is
    passed to ``normalize_hardcore_position_values`` and its result returned.
    """
    plural = metadata.get("position_keys")
    singular = metadata.get("position_key")

    collected: list[Any] = []
    if plural is not None:
        if isinstance(plural, list):
            collected += plural
        else:
            collected.append(plural)
    if singular is not None:
        collected.append(singular)
    return normalize_hardcore_position_values(collected)
|
|
|
|
|
|
def template_action_family(metadata: dict[str, Any]) -> str:
    """Return the normalized action family declared by ``metadata``.

    The ``action_family`` entry is used when truthy; otherwise the
    ``action_type`` entry is used. The chosen value is handed to
    ``normalize_hardcore_action_family`` together with ``""`` as its second
    argument (presumably the fallback for unrecognised values -- confirm
    against hardcore_action_metadata).
    """
    declared = metadata.get("action_family") or metadata.get("action_type")
    return normalize_hardcore_action_family(declared, "")
|
|
|
|
|
|
def _list_from(value: Any) -> list[Any]:
    """Coerce ``value`` to a list.

    Returns:
        ``value`` itself (not a copy) when it already is a list, an empty
        list for ``None``, and a one-element list wrapping anything else
        (including tuples and other non-list containers).
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
|
|
|
|
|
|
def _clean_hint(value: Any) -> str:
    """Return ``value`` as text with surrounding whitespace removed.

    Any falsy value (``None``, ``""``, ``0``, empty containers, ...) yields
    an empty string; everything else is converted with ``str()`` first.
    """
    text = str(value) if value else ""
    return text.strip()
|
|
|
|
|
|
def normalize_formatter_route(value: Any) -> str:
    """Map a raw route name to one of ``FORMATTER_HINT_ROUTES``.

    The value is stringified (falsy values become ``""``), stripped,
    lower-cased and slugified: every run of characters outside ``a-z0-9``
    becomes a single underscore, and edge underscores are dropped. The slug
    is then resolved through ``FORMATTER_HINT_ROUTE_ALIASES``.

    Returns:
        The canonical route name, or ``""`` when the result is not a known
        route.
    """
    lowered = str(value or "").strip().lower()
    slug = re.sub(r"[^a-z0-9]+", "_", lowered).strip("_")
    canonical = FORMATTER_HINT_ROUTE_ALIASES.get(slug, slug)
    if canonical in FORMATTER_HINT_ROUTES:
        return canonical
    return ""
|
|
|
|
|
|
def formatter_hints(metadata: dict[str, Any]) -> dict[str, list[str]]:
    """Normalize the ``formatter_hint`` entry of ``metadata`` per route.

    A dict value is read as ``{route: hint or list of hints}``; any other
    non-None value is treated as hint(s) for the ``"all"`` route. Routes
    that ``normalize_formatter_route`` rejects are skipped, hints are
    cleaned with ``_clean_hint``, and empty or repeated hints are dropped
    while first-seen order is kept.

    Returns:
        A new dict mapping canonical route names to their non-empty hint
        lists; ``{}`` when ``formatter_hint`` is absent or None.
    """
    raw = metadata.get("formatter_hint")
    if raw is None:
        return {}

    # Reduce both accepted shapes to one stream of (route, values) pairs.
    if isinstance(raw, dict):
        pairs = ((str(name), values) for name, values in raw.items())
    else:
        pairs = iter([("all", raw)])

    collected: dict[str, list[str]] = {}
    for raw_route, values in pairs:
        route = normalize_formatter_route(raw_route)
        if not route:
            continue
        for value in _list_from(values):
            hint = _clean_hint(value)
            if not hint:
                continue
            # The bucket is only created once a usable hint exists.
            bucket = collected.setdefault(route, [])
            if hint not in bucket:
                bucket.append(hint)

    return {route: bucket for route, bucket in collected.items() if bucket}
|
|
|
|
|
|
def formatter_hints_for_route(row_or_hints: Any, route: str) -> list[str]:
    """Return the hints that apply to ``route``: ``"all"`` hints, then its own.

    ``row_or_hints`` may be, checked in this order:

    * a dict whose ``formatter_hints`` entry is a dict of per-route hints;
    * a dict with a ``formatter_hint`` key, normalized via
      ``formatter_hints``;
    * a non-empty dict whose every key is itself a valid route name, used
      directly as the per-route hint mapping.

    Anything else, or a ``route`` that does not normalize to a known route,
    yields an empty list. Hints are cleaned and de-duplicated, keeping
    first-seen order.
    """
    route = normalize_formatter_route(route)
    if not (route and isinstance(row_or_hints, dict)):
        return []

    precomputed = row_or_hints.get("formatter_hints")
    if isinstance(precomputed, dict):
        raw_hints = precomputed or {}
    elif "formatter_hint" in row_or_hints:
        raw_hints = formatter_hints(row_or_hints)
    elif row_or_hints and all(normalize_formatter_route(name) for name in row_or_hints):
        raw_hints = row_or_hints
    else:
        return []

    # Re-normalize routes and hints, since the mapping may come from a caller.
    by_route: dict[str, list[str]] = {}
    if isinstance(raw_hints, dict):
        for name, values in raw_hints.items():
            canonical = normalize_formatter_route(name)
            if not canonical:
                continue
            for value in _list_from(values):
                hint = _clean_hint(value)
                if not hint:
                    continue
                bucket = by_route.setdefault(canonical, [])
                if hint not in bucket:
                    bucket.append(hint)

    # dict.fromkeys keeps first-seen order while dropping hints shared
    # between "all" and the requested route (or when route is "all").
    combined = by_route.get("all", []) + by_route.get(route, [])
    return list(dict.fromkeys(combined))
|
|
|
|
|
|
def merge_position_keys(primary: list[str], fallback: list[str]) -> list[str]:
    """Combine two key lists into a new list without duplicates.

    Keys from ``primary`` come first, then keys from ``fallback``. Falsy
    keys (such as ``""`` or ``None``) and keys already taken are skipped.
    """
    merged: list[str] = []
    for source in (primary, fallback):
        for key in source:
            if not key or key in merged:
                continue
            merged.append(key)
    return merged
|
|
|
|
|
|
def _position_key_slug(value: Any) -> str:
    """Slugify ``value`` for comparison against normalized position keys.

    The value is stringified (falsy values become ``""``), stripped and
    lower-cased; each run of characters outside ``a-z0-9`` becomes one
    underscore, and underscores at either end are removed.
    """
    lowered = str(value or "").strip().lower()
    slug = re.sub(r"[^a-z0-9]+", "_", lowered)
    return slug.strip("_")
|
|
|
|
|
|
def template_metadata_errors(metadata: dict[str, Any]) -> list[str]:
    """Validate template metadata and describe every problem found.

    Checks, in order:

    1. ``action_family`` / ``action_type`` -- a truthy raw value for which
       ``template_action_family`` returns a falsy result.
    2. ``position_family`` / ``family`` -- same check via
       ``template_position_family``.
    3. ``position_keys`` / ``position_key`` -- every non-blank raw key other
       than the literal ``"any"`` must have its slug present in
       ``template_position_keys(metadata)``.
    4. ``formatter_hint`` -- dict keys must be valid routes, and every hint
       value must be a non-blank ``str``.

    Returns:
        Human-readable error messages in check order; empty when the
        metadata passes every check.
    """
    problems: list[str] = []

    def bad_hint_values(values: Any) -> list[str]:
        # repr() of each hint that is not a non-blank string.
        return [
            repr(candidate)
            for candidate in _list_from(values)
            if not (isinstance(candidate, str) and candidate.strip())
        ]

    action_raw = metadata.get("action_family") or metadata.get("action_type")
    if action_raw and not template_action_family(metadata):
        problems.append(f"unknown action_family/action_type: {action_raw}")

    family_raw = metadata.get("position_family") or metadata.get("family")
    if family_raw and not template_position_family(metadata):
        problems.append(f"unknown position_family/family: {family_raw}")

    # Gather the raw keys in the same order template_position_keys reads them.
    plural = metadata.get("position_keys")
    singular = metadata.get("position_key")
    candidates: list[Any] = []
    if plural is not None:
        candidates.extend(plural if isinstance(plural, list) else [plural])
    if singular is not None:
        candidates.append(singular)

    known_keys = template_position_keys(metadata)
    rejected: list[str] = []
    for candidate in candidates:
        text = str(candidate or "").strip()
        # Blank entries and the exact, case-sensitive literal "any" are never reported.
        if not text or text == "any":
            continue
        if _position_key_slug(candidate) not in known_keys:
            rejected.append(str(candidate))
    if rejected:
        problems.append("unknown position key(s): " + ", ".join(rejected))

    hint_raw = metadata.get("formatter_hint")
    if hint_raw is None:
        return problems

    if not isinstance(hint_raw, dict):
        bad = bad_hint_values(hint_raw)
        if bad:
            problems.append("invalid formatter_hint value(s): " + ", ".join(bad))
        return problems

    for route, values in hint_raw.items():
        if not normalize_formatter_route(route):
            problems.append(f"unknown formatter_hint route: {route}")
        # Values are still validated under an unknown route.
        bad = bad_hint_values(values)
        if bad:
            problems.append(f"invalid formatter_hint value(s) for {route}: " + ", ".join(bad))
    return problems
|