Extract filter config policy
This commit is contained in:
+52
-194
@@ -25,6 +25,7 @@ try:
|
||||
)
|
||||
from . import camera_config as camera_policy
|
||||
from . import category_cast_config as category_cast_policy
|
||||
from . import filter_config as filter_policy
|
||||
from . import generate_prompt_batches as g
|
||||
from . import generation_profile_config as generation_profile_policy
|
||||
from . import location_config as location_policy
|
||||
@@ -65,6 +66,7 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
||||
)
|
||||
import camera_config as camera_policy
|
||||
import category_cast_config as category_cast_policy
|
||||
import filter_config as filter_policy
|
||||
import generate_prompt_batches as g
|
||||
import generation_profile_config as generation_profile_policy
|
||||
import location_config as location_policy
|
||||
@@ -107,60 +109,11 @@ SEED_AXIS_ALIASES = seed_policy.SEED_AXIS_ALIASES
|
||||
SEED_LOCK_AXES = seed_policy.SEED_LOCK_AXES
|
||||
SEED_MODE_CHOICES = seed_policy.SEED_MODE_CHOICES
|
||||
|
||||
ETHNICITY_FILTER_CHOICES = [
|
||||
"any",
|
||||
"european",
|
||||
"mediterranean_mena",
|
||||
"latina",
|
||||
"east_asian",
|
||||
"southeast_asian",
|
||||
"south_asian",
|
||||
"black_african",
|
||||
"indigenous",
|
||||
"mixed",
|
||||
"asian",
|
||||
"white_asian",
|
||||
"western_european",
|
||||
"french_european",
|
||||
"germanic_european",
|
||||
"nordic_european",
|
||||
"celtic_european",
|
||||
"slavic_european",
|
||||
"baltic_european",
|
||||
"alpine_european",
|
||||
"balkan_european",
|
||||
"greek_mediterranean",
|
||||
"italian_mediterranean",
|
||||
"iberian_mediterranean",
|
||||
]
|
||||
ETHNICITY_LIST_KEYS = tuple(choice for choice in ETHNICITY_FILTER_CHOICES if choice != "any")
|
||||
ETHNICITY_BASE_LIST_KEYS = (
|
||||
"european",
|
||||
"mediterranean_mena",
|
||||
"latina",
|
||||
"east_asian",
|
||||
"southeast_asian",
|
||||
"south_asian",
|
||||
"black_african",
|
||||
"indigenous",
|
||||
"mixed",
|
||||
)
|
||||
EUROPEAN_REGIONAL_LIST_KEYS = (
|
||||
"western_european",
|
||||
"french_european",
|
||||
"germanic_european",
|
||||
"nordic_european",
|
||||
"celtic_european",
|
||||
"slavic_european",
|
||||
"baltic_european",
|
||||
"alpine_european",
|
||||
"balkan_european",
|
||||
)
|
||||
MEDITERRANEAN_REGIONAL_LIST_KEYS = (
|
||||
"greek_mediterranean",
|
||||
"italian_mediterranean",
|
||||
"iberian_mediterranean",
|
||||
)
|
||||
ETHNICITY_FILTER_CHOICES = filter_policy.ETHNICITY_FILTER_CHOICES
|
||||
ETHNICITY_LIST_KEYS = filter_policy.ETHNICITY_LIST_KEYS
|
||||
ETHNICITY_BASE_LIST_KEYS = filter_policy.ETHNICITY_BASE_LIST_KEYS
|
||||
EUROPEAN_REGIONAL_LIST_KEYS = filter_policy.EUROPEAN_REGIONAL_LIST_KEYS
|
||||
MEDITERRANEAN_REGIONAL_LIST_KEYS = filter_policy.MEDITERRANEAN_REGIONAL_LIST_KEYS
|
||||
|
||||
CHARACTER_LABEL_CHOICES = [
|
||||
"auto_chain",
|
||||
@@ -1114,38 +1067,21 @@ def build_filter_config_json(
|
||||
include_mixed: bool = True,
|
||||
include_plus_size: bool = True,
|
||||
) -> str:
|
||||
include_flags = {
|
||||
"european": include_european,
|
||||
"mediterranean_mena": include_mediterranean_mena,
|
||||
"latina": include_latina,
|
||||
"east_asian": include_east_asian,
|
||||
"southeast_asian": include_southeast_asian,
|
||||
"south_asian": include_south_asian,
|
||||
"black_african": include_black_african,
|
||||
"indigenous": include_indigenous,
|
||||
"mixed": include_mixed,
|
||||
}
|
||||
selected_ethnicities = [key for key, enabled in include_flags.items() if enabled]
|
||||
disabled_ethnicities = [key for key, enabled in include_flags.items() if not enabled]
|
||||
enabled_ethnicities = list(selected_ethnicities)
|
||||
if enabled_ethnicities:
|
||||
enabled_ethnicities.extend(f"exclude_{key}" for key in disabled_ethnicities)
|
||||
if 0 < len(selected_ethnicities) < len(include_flags):
|
||||
ethnicity = "+".join(enabled_ethnicities)
|
||||
elif not _is_valid_ethnicity_filter(ethnicity):
|
||||
ethnicity = "any"
|
||||
return json.dumps(
|
||||
{
|
||||
"ethnicity": ethnicity,
|
||||
"ethnicity_includes": selected_ethnicities,
|
||||
"figure": figure if figure in ("curvy", "balanced", "bombshell", "random") else "curvy",
|
||||
"include_plus_size": bool(include_plus_size),
|
||||
"include_black_african": bool(include_black_african),
|
||||
"no_plus_women": not bool(include_plus_size) or bool(no_plus_women),
|
||||
"no_black": not bool(include_black_african) or bool(no_black),
|
||||
},
|
||||
ensure_ascii=True,
|
||||
sort_keys=True,
|
||||
return filter_policy.build_filter_config_json(
|
||||
ethnicity=ethnicity,
|
||||
figure=figure,
|
||||
no_plus_women=no_plus_women,
|
||||
no_black=no_black,
|
||||
include_european=include_european,
|
||||
include_mediterranean_mena=include_mediterranean_mena,
|
||||
include_latina=include_latina,
|
||||
include_east_asian=include_east_asian,
|
||||
include_southeast_asian=include_southeast_asian,
|
||||
include_south_asian=include_south_asian,
|
||||
include_black_african=include_black_african,
|
||||
include_indigenous=include_indigenous,
|
||||
include_mixed=include_mixed,
|
||||
include_plus_size=include_plus_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -1260,31 +1196,15 @@ def build_thematic_location_json(
|
||||
|
||||
|
||||
def _ethnicity_text_from_value(value: Any) -> str:
|
||||
if isinstance(value, dict):
|
||||
return str(value.get("ethnicity") or "").strip()
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
return ""
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
raw = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
if isinstance(raw, dict):
|
||||
return str(raw.get("ethnicity") or "").strip()
|
||||
return text
|
||||
return filter_policy.ethnicity_text_from_value(value)
|
||||
|
||||
|
||||
def _is_valid_ethnicity_filter(value: Any) -> bool:
|
||||
text = _ethnicity_text_from_value(value)
|
||||
return text == "any" or text in ETHNICITY_FILTER_CHOICES or "+" in text
|
||||
return filter_policy.is_valid_ethnicity_filter(value)
|
||||
|
||||
|
||||
def normalize_ethnicity_filter(value: Any, default: str = "any", allow_random: bool = False) -> str:
|
||||
text = _ethnicity_text_from_value(value)
|
||||
if text.lower() in CHARACTER_RANDOM_TOKENS:
|
||||
return "random" if allow_random else default
|
||||
return text if _is_valid_ethnicity_filter(text) else default
|
||||
return filter_policy.normalize_ethnicity_filter(value, default, allow_random)
|
||||
|
||||
|
||||
def build_ethnicity_list_json(
|
||||
@@ -1313,98 +1233,36 @@ def build_ethnicity_list_json(
|
||||
include_iberian_mediterranean: bool = False,
|
||||
strict_excludes: bool = True,
|
||||
) -> dict[str, str]:
|
||||
include_flags = {
|
||||
"european": include_european,
|
||||
"mediterranean_mena": include_mediterranean_mena,
|
||||
"latina": include_latina,
|
||||
"east_asian": include_east_asian,
|
||||
"southeast_asian": include_southeast_asian,
|
||||
"south_asian": include_south_asian,
|
||||
"black_african": include_black_african,
|
||||
"indigenous": include_indigenous,
|
||||
"mixed": include_mixed,
|
||||
"asian": include_asian,
|
||||
"white_asian": include_white_asian,
|
||||
"western_european": include_western_european,
|
||||
"french_european": include_french_european,
|
||||
"germanic_european": include_germanic_european,
|
||||
"nordic_european": include_nordic_european,
|
||||
"celtic_european": include_celtic_european,
|
||||
"slavic_european": include_slavic_european,
|
||||
"baltic_european": include_baltic_european,
|
||||
"alpine_european": include_alpine_european,
|
||||
"balkan_european": include_balkan_european,
|
||||
"greek_mediterranean": include_greek_mediterranean,
|
||||
"italian_mediterranean": include_italian_mediterranean,
|
||||
"iberian_mediterranean": include_iberian_mediterranean,
|
||||
}
|
||||
selected = [key for key in ETHNICITY_LIST_KEYS if include_flags.get(key)]
|
||||
if not selected or set(selected) == set(ETHNICITY_LIST_KEYS):
|
||||
ethnicity = "any"
|
||||
else:
|
||||
tokens = list(selected)
|
||||
if strict_excludes:
|
||||
protected: set[str] = set()
|
||||
if "asian" in selected:
|
||||
protected.update(("east_asian", "southeast_asian", "south_asian"))
|
||||
if "white_asian" in selected:
|
||||
protected.update(("european", "east_asian", "southeast_asian", "south_asian", "mixed"))
|
||||
if any(key in selected for key in EUROPEAN_REGIONAL_LIST_KEYS):
|
||||
protected.add("european")
|
||||
if any(key in selected for key in MEDITERRANEAN_REGIONAL_LIST_KEYS):
|
||||
protected.add("mediterranean_mena")
|
||||
if "mixed" in selected:
|
||||
protected.update(ETHNICITY_BASE_LIST_KEYS)
|
||||
tokens.extend(
|
||||
f"exclude_{key}"
|
||||
for key in ETHNICITY_BASE_LIST_KEYS
|
||||
if key not in selected and key not in protected
|
||||
)
|
||||
ethnicity = "+".join(tokens)
|
||||
filter_config = {
|
||||
"ethnicity": ethnicity,
|
||||
"ethnicity_includes": selected,
|
||||
}
|
||||
summary = "any ethnicity" if ethnicity == "any" else "ethnicity list: " + ", ".join(selected)
|
||||
return {
|
||||
"ethnicity": ethnicity,
|
||||
"filter_config": json.dumps(filter_config, ensure_ascii=True, sort_keys=True),
|
||||
"summary": summary,
|
||||
}
|
||||
return filter_policy.build_ethnicity_list_json(
|
||||
include_european=include_european,
|
||||
include_mediterranean_mena=include_mediterranean_mena,
|
||||
include_latina=include_latina,
|
||||
include_east_asian=include_east_asian,
|
||||
include_southeast_asian=include_southeast_asian,
|
||||
include_south_asian=include_south_asian,
|
||||
include_black_african=include_black_african,
|
||||
include_indigenous=include_indigenous,
|
||||
include_mixed=include_mixed,
|
||||
include_asian=include_asian,
|
||||
include_white_asian=include_white_asian,
|
||||
include_western_european=include_western_european,
|
||||
include_french_european=include_french_european,
|
||||
include_germanic_european=include_germanic_european,
|
||||
include_nordic_european=include_nordic_european,
|
||||
include_celtic_european=include_celtic_european,
|
||||
include_slavic_european=include_slavic_european,
|
||||
include_baltic_european=include_baltic_european,
|
||||
include_alpine_european=include_alpine_european,
|
||||
include_balkan_european=include_balkan_european,
|
||||
include_greek_mediterranean=include_greek_mediterranean,
|
||||
include_italian_mediterranean=include_italian_mediterranean,
|
||||
include_iberian_mediterranean=include_iberian_mediterranean,
|
||||
strict_excludes=strict_excludes,
|
||||
)
|
||||
|
||||
|
||||
def _parse_filter_config(filter_config: str | dict[str, Any] | None) -> dict[str, Any]:
|
||||
defaults = {
|
||||
"ethnicity": "any",
|
||||
"figure": "curvy",
|
||||
"no_plus_women": False,
|
||||
"no_black": False,
|
||||
"include_plus_size": True,
|
||||
"include_black_african": True,
|
||||
}
|
||||
if not filter_config:
|
||||
return defaults
|
||||
if isinstance(filter_config, dict):
|
||||
raw = filter_config
|
||||
else:
|
||||
text = str(filter_config).strip()
|
||||
if not text.startswith("{"):
|
||||
raw = {"ethnicity": text}
|
||||
else:
|
||||
try:
|
||||
raw = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Invalid filter_config JSON: {exc}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError("filter_config must be a JSON object")
|
||||
parsed = {**defaults, **raw}
|
||||
parsed["ethnicity"] = normalize_ethnicity_filter(parsed.get("ethnicity"), "any")
|
||||
parsed["figure"] = parsed["figure"] if parsed.get("figure") in ("curvy", "balanced", "bombshell", "random") else "curvy"
|
||||
parsed["include_plus_size"] = bool(parsed.get("include_plus_size"))
|
||||
parsed["include_black_african"] = bool(parsed.get("include_black_african"))
|
||||
parsed["no_plus_women"] = bool(parsed.get("no_plus_women"))
|
||||
parsed["no_black"] = bool(parsed.get("no_black"))
|
||||
return parsed
|
||||
return filter_policy.parse_filter_config(filter_config)
|
||||
|
||||
|
||||
def _normalize_hardcore_position_family(value: Any, default: str = "any") -> str:
|
||||
|
||||
Reference in New Issue
Block a user