Extract SDXL tag route assembly
This commit is contained in:
@@ -73,6 +73,7 @@ import row_subject_route # noqa: E402
|
||||
import server_routes # noqa: E402
|
||||
import sdxl_formatter # noqa: E402
|
||||
import sdxl_presets # noqa: E402
|
||||
import sdxl_tag_routes # noqa: E402
|
||||
import seed_config # noqa: E402
|
||||
import krea_pov # noqa: E402
|
||||
import subject_context # noqa: E402
|
||||
@@ -2336,6 +2337,67 @@ def smoke_sdxl_presets_policy() -> None:
|
||||
_expect("score_9" not in profiled_prompt, "SDXL photo profile should switch away from Pony score quality tail")
|
||||
|
||||
|
||||
def smoke_sdxl_tag_routes() -> None:
|
||||
row = _fixture_hardcore_row(
|
||||
formatter_hints={
|
||||
"all": ["shared route anchor"],
|
||||
"sdxl": ["sdxl route tag"],
|
||||
}
|
||||
)
|
||||
deps = sdxl_formatter._sdxl_tag_route_dependencies()
|
||||
typed_row = sdxl_tag_routes.row_core_tags_result(
|
||||
sdxl_tag_routes.SDXLRowTagRequest(row, 1.29),
|
||||
deps,
|
||||
)
|
||||
_expect(
|
||||
typed_row.tags == sdxl_formatter._row_core_tags(row, 1.29),
|
||||
"Typed SDXL row tag route should match legacy wrapper output",
|
||||
)
|
||||
_expect("sdxl route tag" in typed_row.as_text(), "Typed SDXL row tag route lost route-specific formatter hint")
|
||||
|
||||
pair = pb.build_insta_of_pair(
|
||||
row_number=1,
|
||||
start_index=1,
|
||||
seed=3511,
|
||||
ethnicity="any",
|
||||
figure="random",
|
||||
no_plus_women=False,
|
||||
no_black=False,
|
||||
trigger=Trigger,
|
||||
prepend_trigger_to_prompt=True,
|
||||
options_json=_insta_options(hardcore_clothing_continuity="partially_removed"),
|
||||
character_cast=_character_cast(),
|
||||
hardcore_position_config=_action_filter("penetration_only"),
|
||||
)
|
||||
_expect_pair(pair, "sdxl_tag_routes_pair")
|
||||
soft_row = pair.get("softcore_row") if isinstance(pair.get("softcore_row"), dict) else {}
|
||||
hard_row = pair.get("hardcore_row") if isinstance(pair.get("hardcore_row"), dict) else {}
|
||||
typed_soft = sdxl_tag_routes.soft_tags_result(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(soft_row, pair, 1.29),
|
||||
deps,
|
||||
)
|
||||
typed_hard = sdxl_tag_routes.hard_tags_result(
|
||||
sdxl_tag_routes.SDXLPairTagRequest(hard_row, pair, 1.29),
|
||||
deps,
|
||||
)
|
||||
_expect(
|
||||
typed_soft.as_text() == sdxl_formatter._soft_tags(soft_row, pair, 1.29),
|
||||
"Typed SDXL pair soft tag route should match legacy wrapper output",
|
||||
)
|
||||
_expect(
|
||||
typed_hard.as_text() == sdxl_formatter._hard_tags(hard_row, pair, 1.29),
|
||||
"Typed SDXL pair hard tag route should match legacy wrapper output",
|
||||
)
|
||||
formatted = sdxl_formatter.format_sdxl_prompt(
|
||||
"",
|
||||
metadata_json=_json(pair),
|
||||
target="hardcore",
|
||||
trigger=SdxlTrigger,
|
||||
prepend_trigger=True,
|
||||
)
|
||||
_expect("sdxl(insta_of_pair)" in formatted.get("method", ""), "SDXL pair formatter route changed method")
|
||||
|
||||
|
||||
def smoke_hardcore_position_config_policy() -> None:
|
||||
_expect(
|
||||
pb.HARDCORE_POSITION_FAMILY_CHOICES is hardcore_position_config.HARDCORE_POSITION_FAMILY_CHOICES,
|
||||
@@ -5090,6 +5152,7 @@ SMOKE_CASES: list[tuple[str, Callable[[], None]]] = [
|
||||
("formatter_cast_policy", smoke_formatter_cast_policy),
|
||||
("caption_policy", smoke_caption_policy),
|
||||
("sdxl_presets_policy", smoke_sdxl_presets_policy),
|
||||
("sdxl_tag_routes", smoke_sdxl_tag_routes),
|
||||
("hardcore_position_config_policy", smoke_hardcore_position_config_policy),
|
||||
("row_route_metadata_policy", smoke_row_route_metadata_policy),
|
||||
("category_library_route", smoke_category_library_route),
|
||||
|
||||
Reference in New Issue
Block a user