Extract row normalization policy
This commit is contained in:
@@ -35,6 +35,7 @@ import generation_profile_config # noqa: E402
|
||||
import krea_formatter # noqa: E402
|
||||
import location_config # noqa: E402
|
||||
import prompt_builder as pb # noqa: E402
|
||||
import row_normalization # noqa: E402
|
||||
import sdxl_formatter # noqa: E402
|
||||
import seed_config # noqa: E402
|
||||
|
||||
@@ -770,6 +771,82 @@ def smoke_character_profile_policy() -> None:
|
||||
_expect(applied_profile.get("profile_type") == "character", "Profile context returned wrong profile")
|
||||
|
||||
|
||||
def smoke_row_normalization_policy() -> None:
|
||||
_expect(
|
||||
pb._prepend_trigger("base prompt", Trigger, True) == row_normalization.prepend_trigger("base prompt", Trigger, True),
|
||||
"Prompt builder trigger helper should delegate to row normalization policy",
|
||||
)
|
||||
_expect(
|
||||
pb._combined_negative("bad anatomy", "low quality") == row_normalization.combined_negative("bad anatomy", "low quality"),
|
||||
"Prompt builder negative helper should delegate to row normalization policy",
|
||||
)
|
||||
|
||||
row = row_normalization.normalize_prompt_row(
|
||||
{
|
||||
"prompt": f"{Trigger}, {Trigger}, base prompt.",
|
||||
"caption": f"{Trigger}, {Trigger}, base caption.",
|
||||
"negative_prompt": "bad anatomy, bad anatomy",
|
||||
},
|
||||
active_trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
extra_positive="extra detail",
|
||||
extra_negative="low quality, bad anatomy",
|
||||
default_negative="bad anatomy",
|
||||
)
|
||||
_expect_trigger_once("row_normalization.prompt", row.get("prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.caption", row.get("caption"), Trigger)
|
||||
_expect("extra detail" in row.get("prompt", ""), "Row normalization lost extra positive text")
|
||||
_expect(row.get("trigger") == Trigger, "Row normalization lost active trigger")
|
||||
_expect_no_duplicate_comma_items("row_normalization.negative", row.get("negative_prompt"))
|
||||
|
||||
outputs = row_normalization.normalize_pair_text_outputs(
|
||||
active_trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
extra_positive="pair extra",
|
||||
extra_negative="low quality, bad anatomy",
|
||||
soft_prompt="soft prompt.",
|
||||
hard_prompt="hard prompt.",
|
||||
soft_negative_base="bad anatomy, bad anatomy",
|
||||
hard_negative_base="bad anatomy, low quality",
|
||||
soft_caption_parts=[Trigger, "soft caption"],
|
||||
hard_caption_parts=[Trigger, "hard caption"],
|
||||
)
|
||||
_expect_trigger_once("row_normalization.soft_prompt", outputs.get("soft_prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.hard_prompt", outputs.get("hard_prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.soft_caption", outputs.get("soft_caption"), Trigger)
|
||||
_expect_trigger_once("row_normalization.hard_caption", outputs.get("hard_caption"), Trigger)
|
||||
_expect_no_duplicate_comma_items("row_normalization.soft_negative", outputs.get("soft_negative"))
|
||||
_expect_no_duplicate_comma_items("row_normalization.hard_negative", outputs.get("hard_negative"))
|
||||
|
||||
pair = row_normalization.normalize_pair_metadata(
|
||||
{
|
||||
"softcore_prompt": f"{Trigger}, {Trigger}, soft pair.",
|
||||
"hardcore_prompt": f"{Trigger}, {Trigger}, hard pair.",
|
||||
"softcore_caption": f"{Trigger}, {Trigger}, soft caption.",
|
||||
"hardcore_caption": f"{Trigger}, {Trigger}, hard caption.",
|
||||
"softcore_negative_prompt": "bad anatomy, bad anatomy",
|
||||
"hardcore_negative_prompt": "bad anatomy, low quality, bad anatomy",
|
||||
"softcore_row": {
|
||||
"prompt": f"{Trigger}, {Trigger}, embedded soft.",
|
||||
"caption": f"{Trigger}, {Trigger}, embedded soft caption.",
|
||||
"negative_prompt": "bad anatomy, bad anatomy",
|
||||
},
|
||||
"hardcore_row": {
|
||||
"prompt": f"{Trigger}, {Trigger}, embedded hard.",
|
||||
"caption": f"{Trigger}, {Trigger}, embedded hard caption.",
|
||||
"negative_prompt": "low quality, bad anatomy, low quality",
|
||||
},
|
||||
},
|
||||
active_trigger=Trigger,
|
||||
)
|
||||
_expect_trigger_once("row_normalization.pair.softcore_prompt", pair.get("softcore_prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.pair.hardcore_prompt", pair.get("hardcore_prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.pair.softcore_row.prompt", pair["softcore_row"].get("prompt"), Trigger)
|
||||
_expect_trigger_once("row_normalization.pair.hardcore_row.caption", pair["hardcore_row"].get("caption"), Trigger)
|
||||
_expect_no_duplicate_comma_items("row_normalization.pair.soft_negative", pair.get("softcore_negative_prompt"))
|
||||
_expect_no_duplicate_comma_items("row_normalization.pair.hard_row_negative", pair["hardcore_row"].get("negative_prompt"))
|
||||
|
||||
|
||||
def smoke_hardcore_position_config_policy() -> None:
|
||||
_expect(
|
||||
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
||||
@@ -2740,6 +2817,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("filter_config_policy", smoke_filter_config_policy),
|
||||
("character_config_policy", smoke_character_config_policy),
|
||||
("character_profile_policy", smoke_character_profile_policy),
|
||||
("row_normalization_policy", smoke_row_normalization_policy),
|
||||
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
||||
("category_library_route", smoke_category_library_route),
|
||||
("hardcore_category_routes", smoke_hardcore_category_routes),
|
||||
|
||||
Reference in New Issue
Block a user