diff --git a/caption_policy.py b/caption_policy.py index 042a7a4..d066eba 100644 --- a/caption_policy.py +++ b/caption_policy.py @@ -77,7 +77,12 @@ def normalize_detail_level(value: str) -> str: return detail_policy.normalize_detail_level(value) +def _choice_key(value: Any) -> str: + return str(value or "").strip().lower().replace("-", "_").replace(" ", "_") + + def normalize_style_policy(value: str) -> str: + value = _choice_key(value) return value if value in STYLE_POLICIES else "drop_style_tail" @@ -90,6 +95,7 @@ def caption_profile_choices() -> list[str]: def normalize_caption_profile(value: str) -> str: + value = _choice_key(value) return value if value in CAPTION_PROFILES else CAPTION_PROFILE_DEFAULT diff --git a/sdxl_presets.py b/sdxl_presets.py index df17812..a4b5699 100644 --- a/sdxl_presets.py +++ b/sdxl_presets.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + DEFAULT_STYLE_PRESET = "flat_vector_pony" DEFAULT_QUALITY_PRESET = "pony_high" @@ -79,15 +81,22 @@ def sdxl_formatter_profile_choices() -> list[str]: return list(SDXL_FORMATTER_PROFILES) +def _choice_key(value: Any) -> str: + return str(value or "").strip().lower().replace("-", "_").replace(" ", "_") + + def normalize_style_preset(value: str) -> str: + value = _choice_key(value) return value if value in SDXL_STYLE_PRESETS else DEFAULT_STYLE_PRESET def normalize_quality_preset(value: str) -> str: + value = _choice_key(value) return value if value in SDXL_QUALITY_PRESETS else DEFAULT_QUALITY_PRESET def normalize_formatter_profile(value: str) -> str: + value = _choice_key(value) return value if value in SDXL_FORMATTER_PROFILES else DEFAULT_FORMATTER_PROFILE diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 7147adc..317e4dc 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -3039,6 +3039,10 @@ def smoke_caption_policy() -> None: ) _expect(caption_policy.normalize_detail_level("bad") == "balanced", "Caption invalid detail fallback changed") _expect(caption_policy.normalize_style_policy("bad") == "drop_style_tail", "Caption invalid style fallback changed") + _expect( + caption_policy.normalize_style_policy("Keep Style Terms") == "keep_style_terms", + "Caption style policy should normalize spaces/case", + ) _expect(caption_policy.style_policy_choices() == ["drop_style_tail", "keep_style_terms"], "Caption style policy choices changed") _expect(caption_policy.keep_style_terms("keep_style_terms") is True, "Caption style policy keep flag changed") _expect(caption_policy.detail_allows("concise") is False, "Caption concise detail gate changed") @@ -3048,6 +3052,10 @@ def smoke_caption_policy() -> None: caption_policy.normalize_caption_profile("bad") == caption_policy.CAPTION_PROFILE_DEFAULT, "Caption invalid profile fallback changed", ) + _expect( + caption_policy.normalize_caption_profile("training-dense") == "training_dense", + "Caption profile should normalize hyphen spelling", + ) _expect( caption_policy.apply_caption_profile( "training_dense", @@ -3381,8 +3389,11 @@ def smoke_sdxl_presets_policy() -> None: sdxl_presets.normalize_formatter_profile("bad") == sdxl_presets.DEFAULT_FORMATTER_PROFILE, "SDXL invalid profile fallback changed", ) + _expect(sdxl_presets.normalize_formatter_profile("SDXL Photo") == "sdxl_photo", "SDXL profile should normalize spaces/case") _expect(sdxl_presets.normalize_style_preset("bad") == sdxl_presets.DEFAULT_STYLE_PRESET, "SDXL invalid style fallback changed") + _expect(sdxl_presets.normalize_style_preset("flat-vector-pony") == "flat_vector_pony", "SDXL style should normalize hyphens") _expect(sdxl_presets.normalize_quality_preset("bad") == sdxl_presets.DEFAULT_QUALITY_PRESET, "SDXL invalid quality fallback changed") + _expect(sdxl_presets.normalize_quality_preset("Pony High") == "pony_high", "SDXL quality should normalize spaces/case") _expect( sdxl_presets.apply_formatter_profile( "sdxl_photo",