"""Read, normalize and validate metadata attached to template entries.

Covers the action family, the position family and position keys, and
per-route formatter hints (the ``formatter_hint`` key).
"""

from __future__ import annotations

import re
from typing import Any

try:
    from .hardcore_action_metadata import normalize_hardcore_action_family
    from .hardcore_position_config import (
        normalize_hardcore_position_family,
        normalize_hardcore_position_values,
    )
except ImportError:
    # Allows local smoke tests from the repository root.
    from hardcore_action_metadata import normalize_hardcore_action_family
    from hardcore_position_config import (
        normalize_hardcore_position_family,
        normalize_hardcore_position_values,
    )

# Keys copied verbatim from a template entry by ``template_metadata``.
TEMPLATE_METADATA_KEYS = (
    "action_family",
    "action_type",
    "family",
    "position_family",
    "position_key",
    "position_keys",
    "formatter_hint",
)

# Canonical formatter-hint routes. Hints stored under "all" are returned for
# every route by ``formatter_hints_for_route``.
FORMATTER_HINT_ROUTES = ("all", "krea", "sdxl", "caption")

# Alternate route spellings, mapped to their canonical route name.
FORMATTER_HINT_ROUTE_ALIASES = {
    "krea2": "krea",
    "naturalizer": "caption",
    "training_caption": "caption",
}

# Runs of characters outside [a-z0-9]; each run collapses to one "_" in a slug.
_NON_SLUG_CHARS = re.compile(r"[^a-z0-9]+")


def _slug(value: Any) -> str:
    """Lower-case ``value`` and collapse non-alphanumeric runs to single underscores.

    Falsy values (including ``None``) become ``""``. Leading and trailing
    underscores are stripped from the result.
    """
    return _NON_SLUG_CHARS.sub("_", str(value or "").strip().lower()).strip("_")


def _list_from(value: Any) -> list[Any]:
    """Wrap ``value`` in a list: ``None`` -> ``[]``, list -> itself, else ``[value]``."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _clean_hint(value: Any) -> str:
    """Stringify and strip a single hint. Falsy values become ``""``."""
    return str(value or "").strip()


def template_metadata(item: Any) -> dict[str, Any]:
    """Return the entries of ``item`` whose keys are in ``TEMPLATE_METADATA_KEYS``.

    Non-dict inputs yield an empty dict.
    """
    if not isinstance(item, dict):
        return {}
    return {key: item[key] for key in TEMPLATE_METADATA_KEYS if key in item}


def template_position_family(metadata: dict[str, Any]) -> str:
    """Normalize ``position_family`` (falling back to ``family``); ``""`` when unrecognized."""
    return normalize_hardcore_position_family(
        metadata.get("position_family") or metadata.get("family"),
        "",
    )


def _raw_position_keys(metadata: dict[str, Any]) -> list[Any]:
    """Collect un-normalized position keys: ``position_keys`` entries, then ``position_key``.

    A non-list ``position_keys`` value counts as a single key. ``None`` values
    are skipped. Always returns a new list, so the caller's data is never mutated.
    """
    keys = list(_list_from(metadata.get("position_keys")))
    single_key = metadata.get("position_key")
    if single_key is not None:
        keys.append(single_key)
    return keys


def template_position_keys(metadata: dict[str, Any]) -> list[str]:
    """Return the normalized position keys declared by ``position_keys``/``position_key``."""
    return normalize_hardcore_position_values(_raw_position_keys(metadata))


def template_action_family(metadata: dict[str, Any]) -> str:
    """Normalize ``action_family`` (falling back to ``action_type``); ``""`` when unrecognized."""
    return normalize_hardcore_action_family(
        metadata.get("action_family") or metadata.get("action_type"),
        "",
    )


def normalize_formatter_route(value: Any) -> str:
    """Map a route name (any case/punctuation, or an alias) to its canonical route.

    Returns ``""`` when the value is not one of ``FORMATTER_HINT_ROUTES``.
    """
    route = _slug(value)
    route = FORMATTER_HINT_ROUTE_ALIASES.get(route, route)
    return route if route in FORMATTER_HINT_ROUTES else ""


def _add_hints(normalized: dict[str, list[str]], route: Any, values: Any) -> None:
    """Append cleaned, de-duplicated hints from ``values`` under the canonical ``route``.

    Unknown routes and blank hints are ignored. A route's list is created only
    when it receives its first non-blank hint, so no empty lists are stored.
    """
    canonical = normalize_formatter_route(route)
    if not canonical:
        return
    for value in _list_from(values):
        hint = _clean_hint(value)
        if hint and hint not in normalized.setdefault(canonical, []):
            normalized[canonical].append(hint)


def formatter_hints(metadata: dict[str, Any]) -> dict[str, list[str]]:
    """Normalize ``metadata["formatter_hint"]`` into ``{canonical_route: [hint, ...]}``.

    A dict value is read as ``{route: hint-or-list}``; any other value is
    treated as hint(s) for the "all" route. Unknown routes and blank hints are
    dropped. Returns ``{}`` when the key is absent or ``None``.
    """
    raw = metadata.get("formatter_hint")
    if raw is None:
        return {}
    normalized: dict[str, list[str]] = {}
    if isinstance(raw, dict):
        for route, values in raw.items():
            _add_hints(normalized, str(route), values)
    else:
        _add_hints(normalized, "all", raw)
    return {route: hints for route, hints in normalized.items() if hints}


def formatter_hints_for_route(row_or_hints: Any, route: str) -> list[str]:
    """Return the hints for ``route``: the "all" hints first, then route-specific ones.

    ``row_or_hints`` may be any of:
      * a row carrying a ``formatter_hints`` dict (already keyed by route);
      * a row/metadata dict carrying a raw ``formatter_hint`` value;
      * a hints dict whose every key is a recognized route.

    Hints are de-duplicated in first-seen order. Returns ``[]`` for an unknown
    route or an input that matches none of the shapes above.
    """
    route = normalize_formatter_route(route)
    if not route or not isinstance(row_or_hints, dict):
        return []
    if isinstance(row_or_hints.get("formatter_hints"), dict):
        raw_hints = row_or_hints["formatter_hints"]
    elif "formatter_hint" in row_or_hints:
        raw_hints = formatter_hints(row_or_hints)
    elif row_or_hints and all(normalize_formatter_route(key) for key in row_or_hints):
        raw_hints = row_or_hints
    else:
        return []

    # Re-normalize: the first and third input shapes may hold aliases or unclean values.
    normalized: dict[str, list[str]] = {}
    for raw_route, values in raw_hints.items():
        _add_hints(normalized, raw_route, values)

    # dict.fromkeys de-duplicates while keeping first-seen order (hints are strings).
    return list(dict.fromkeys([*normalized.get("all", []), *normalized.get(route, [])]))


def merge_position_keys(primary: list[str], fallback: list[str]) -> list[str]:
    """Concatenate ``primary`` then ``fallback``, dropping falsy keys and duplicates."""
    merged: list[str] = []
    for key in [*primary, *fallback]:
        if key and key not in merged:
            merged.append(key)
    return merged


def _position_key_slug(value: Any) -> str:
    """Slug form of a raw position key, used to match it against normalized keys."""
    return _slug(value)


def _position_key_errors(metadata: dict[str, Any]) -> list[str]:
    """Report raw position keys whose slug is missing from the normalized key list."""
    normalized_keys = template_position_keys(metadata)
    invalid_keys: list[str] = []
    for value in _raw_position_keys(metadata):
        text = str(value or "").strip()
        # Blank keys and "any" are never reported (presumably "any" is a
        # wildcard). Compare case-insensitively: the membership check below
        # uses the lower-cased slug, so "Any" must be exempt just like "any".
        if not text or text.lower() == "any":
            continue
        if _position_key_slug(value) not in normalized_keys:
            invalid_keys.append(str(value))
    if invalid_keys:
        return ["unknown position key(s): " + ", ".join(invalid_keys)]
    return []


def _invalid_hint_values(values: Any) -> list[str]:
    """Return ``repr`` of each entry in ``values`` that is not a non-blank string."""
    return [
        repr(value)
        for value in _list_from(values)
        if not isinstance(value, str) or not value.strip()
    ]


def _formatter_hint_errors(raw_hint: Any) -> list[str]:
    """Validate a raw ``formatter_hint`` value (hint/list, or a dict keyed by route)."""
    if raw_hint is None:
        return []
    errors: list[str] = []
    if isinstance(raw_hint, dict):
        for route, values in raw_hint.items():
            if not normalize_formatter_route(route):
                errors.append(f"unknown formatter_hint route: {route}")
            invalid_values = _invalid_hint_values(values)
            if invalid_values:
                errors.append(
                    f"invalid formatter_hint value(s) for {route}: " + ", ".join(invalid_values)
                )
    else:
        invalid_values = _invalid_hint_values(raw_hint)
        if invalid_values:
            errors.append("invalid formatter_hint value(s): " + ", ".join(invalid_values))
    return errors


def template_metadata_errors(metadata: dict[str, Any]) -> list[str]:
    """Return human-readable validation errors for template metadata (``[]`` if valid).

    Checks, in order:
      1. action family;
      2. position family;
      3. position keys;
      4. formatter hints.
    Unlike ``formatter_hints``, which stringifies any value, hint values here
    must be non-blank strings.
    """
    errors: list[str] = []

    raw_action_family = metadata.get("action_family") or metadata.get("action_type")
    if raw_action_family and not template_action_family(metadata):
        errors.append(f"unknown action_family/action_type: {raw_action_family}")

    raw_position_family = metadata.get("position_family") or metadata.get("family")
    if raw_position_family and not template_position_family(metadata):
        errors.append(f"unknown position_family/family: {raw_position_family}")

    errors.extend(_position_key_errors(metadata))
    errors.extend(_formatter_hint_errors(metadata.get("formatter_hint")))
    return errors