diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index be93b6a..9c74591 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -610,6 +610,40 @@ def smoke_config_route_location_theme() -> None: def smoke_builder_prompt_route_policy() -> None: + def legacy_from_request(request: builder_prompt_route.PromptBuildRequest) -> dict[str, Any]: + return pb.build_prompt( + category=request.category, + subcategory=request.subcategory, + row_number=request.row_number, + start_index=request.start_index, + seed=request.seed, + clothing=request.clothing, + ethnicity=request.ethnicity, + poses=request.poses, + backside_bias=request.backside_bias, + figure=request.figure, + no_plus_women=request.no_plus_women, + no_black=request.no_black, + minimal_clothing_ratio=request.minimal_clothing_ratio, + standard_pose_ratio=request.standard_pose_ratio, + trigger=request.trigger, + prepend_trigger_to_prompt=request.prepend_trigger_to_prompt, + extra_positive=request.extra_positive, + extra_negative=request.extra_negative, + seed_config=request.seed_config, + women_count=request.women_count, + men_count=request.men_count, + camera_config=request.camera_config, + expression_intensity=request.expression_intensity, + character_profile=request.character_profile, + character_cast=request.character_cast, + expression_enabled=request.expression_enabled, + expression_phase=request.expression_phase, + hardcore_position_config=request.hardcore_position_config, + location_config=request.location_config, + composition_config=request.composition_config, + ) + seed_config_json = pb.build_seed_lock_config_json(base_seed=3501, reroll_axis="content", reroll_seed=3502) request = builder_prompt_route.PromptBuildRequest( category="Casual clothes", @@ -638,32 +672,7 @@ def smoke_builder_prompt_route_policy() -> None: expression_enabled=True, ) typed_route = builder_prompt_route.build_prompt_result(request, pb._prompt_build_dependencies()) - legacy_row = pb.build_prompt( - category=request.category, - subcategory=request.subcategory, - row_number=request.row_number, - start_index=request.start_index, - seed=request.seed, - clothing=request.clothing, - ethnicity=request.ethnicity, - poses=request.poses, - backside_bias=request.backside_bias, - figure=request.figure, - no_plus_women=request.no_plus_women, - no_black=request.no_black, - minimal_clothing_ratio=request.minimal_clothing_ratio, - standard_pose_ratio=request.standard_pose_ratio, - trigger=request.trigger, - prepend_trigger_to_prompt=request.prepend_trigger_to_prompt, - extra_positive=request.extra_positive, - extra_negative=request.extra_negative, - seed_config=request.seed_config, - women_count=request.women_count, - men_count=request.men_count, - camera_config=request.camera_config, - expression_intensity=request.expression_intensity, - expression_enabled=request.expression_enabled, - ) + legacy_row = legacy_from_request(request) _expect(typed_route.row == legacy_row, "Typed builder prompt route should match public wrapper output") _expect(typed_route.category == "Casual clothes", "Builder prompt route changed category") _expect(typed_route.subcategory == "Casual clothes / Smart casual", "Builder prompt route changed subcategory") @@ -677,6 +686,64 @@ def smoke_builder_prompt_route_policy() -> None: ) _expect_trigger_once("builder_prompt_route_policy.prompt", typed_route.row.get("prompt"), "sxcpinup_coloredpencil") + built_in_request = builder_prompt_route.PromptBuildRequest( + category="woman", + subcategory="random", + row_number=1, + start_index=1, + seed=3503, + clothing="full", + ethnicity="any", + poses="standard", + backside_bias=0.0, + figure="curvy", + no_plus_women=False, + no_black=False, + minimal_clothing_ratio=0.0, + standard_pose_ratio=1.0, + trigger=Trigger, + prepend_trigger_to_prompt=True, + extra_positive="built-in route marker", + extra_negative="built-in route negative", + expression_intensity=0.5, + expression_enabled=False, + ) + built_in_route = builder_prompt_route.build_prompt_result(built_in_request, pb._prompt_build_dependencies()) + _expect(built_in_route.row == legacy_from_request(built_in_request), "Builder built-in route should match public wrapper") + _expect(built_in_route.branch == "built_in", "Builder prompt route lost built-in branch") + _expect(built_in_route.row.get("source") == "built_in_generator", "Builder built-in branch changed source") + _expect(built_in_route.row.get("expression_disabled") is True, "Builder built-in branch lost expression disable") + _expect("built-in route marker" in built_in_route.row.get("prompt", ""), "Builder built-in branch lost extra positive") + + auto_weighted_request = builder_prompt_route.PromptBuildRequest( + category="auto_weighted", + subcategory="random", + row_number=2, + start_index=10, + seed=3504, + clothing="random", + ethnicity="any", + poses="random", + backside_bias=0.35, + figure="random", + no_plus_women=False, + no_black=False, + minimal_clothing_ratio=0.4, + standard_pose_ratio=0.6, + trigger=Trigger, + prepend_trigger_to_prompt=True, + extra_positive="auto route marker", + extra_negative="auto route negative", + seed_config=pb.build_seed_lock_config_json(base_seed=3504, reroll_axis="person", reroll_seed=3505), + expression_intensity=0.7, + expression_enabled=True, + ) + auto_route = builder_prompt_route.build_prompt_result(auto_weighted_request, pb._prompt_build_dependencies()) + _expect(auto_route.row == legacy_from_request(auto_weighted_request), "Builder auto-weighted route should match public wrapper") + _expect(auto_route.branch == "auto_weighted", "Builder prompt route lost auto-weighted branch") + _expect(auto_route.parsed_seed_config.get("person_seed") == 3505, "Builder auto-weighted branch lost person seed lock") + _expect("auto route marker" in auto_route.row.get("prompt", ""), "Builder auto-weighted branch lost extra positive") + def smoke_builder_config_route_policy() -> None: category_config = pb.build_category_config_json("women_casual", "Casual clothes / Smart casual")