Add multi-seed route simulation sweep

This commit is contained in:
2026-06-27 19:58:11 +02:00
parent 4a3610fbc9
commit 1ca9c95bfe
6 changed files with 111 additions and 20 deletions
+4
View File
@@ -114,6 +114,10 @@ AUDIT_DOC_SNIPPETS: tuple[tuple[str, str], ...] = (
"docs/prompt-pool-routing-map.md",
"repeated cast descriptors in training-caption formatter output",
),
(
"docs/prompt-pool-routing-map.md",
"multi-seed route sweeps",
),
)
PROMPT_ROW_READ_SCAN_GLOBS: tuple[str, ...] = (
+80 -6
View File
@@ -308,6 +308,7 @@ HARDCORE_ROUTE_CASES = (
"subcategory": "Foreplay and teasing",
"focus": "foreplay_only",
"family": "foreplay",
"positions": ("undressing",),
"expected_route": {"action_family": "foreplay", "position_family": "foreplay"},
"expected_terms": {
"krea": ("clothing",),
@@ -317,12 +318,13 @@ HARDCORE_ROUTE_CASES = (
},
{
"name": "hardcore.single.interaction",
"subcategory": "Aftercare and cleanup",
"subcategory": "Clothing and position transitions",
"focus": "interaction_only",
"family": "interaction",
"positions": ("position_transition",),
"expected_route": {"action_family": "foreplay", "position_family": "interaction"},
"expected_terms": {
"krea": ("mid-transition",),
"krea": ("clothing",),
"sdxl": ("interaction",),
"caption": ("interaction beat",),
},
@@ -332,9 +334,10 @@ HARDCORE_ROUTE_CASES = (
"subcategory": "Anal and double penetration",
"focus": "anal_only",
"family": "anal",
"positions": ("doggy", "face_down_ass_up"),
"expected_route": {"action_family": "anal", "position_family": "anal"},
"expected_terms": {
"krea": ("anal",),
"krea": ("ass",),
"sdxl": ("anal sex",),
"caption": ("anal action",),
},
@@ -853,7 +856,13 @@ def _regular_single_case(seed: int) -> dict[str, Any]:
)
def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str) -> dict[str, Any]:
def _hardcore_single_case(
seed: int,
subcategory: str,
focus: str,
family: str,
positions: list[str] | tuple[str, ...] | str = (),
) -> dict[str, Any]:
women_count, men_count, character_cast = {
"threesome": (1, 2, _character_cast_subjects(("woman", "man", "man"))),
"group": (2, 2, _character_cast_subjects(("woman", "woman", "man", "man"))),
@@ -881,7 +890,7 @@ def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str)
women_count=women_count,
men_count=men_count,
character_cast=character_cast,
hardcore_position_config=_position_filter(focus, family, []),
hardcore_position_config=_position_filter(focus, family, positions),
location_config=_coworking_location_config(),
camera_config=_orbit_camera(horizontal_angle=35, vertical_angle=0, zoom=6.5),
)
@@ -1404,6 +1413,7 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
str(route_case["subcategory"]),
str(route_case["focus"]),
str(route_case["family"]),
route_case.get("positions") or (),
)
cases.append(
_case_report(
@@ -1459,6 +1469,38 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
}
def run_simulation_sweep(
seed: int = 3901,
*,
count: int = 3,
seed_step: int = 101,
include_prompts: bool = False,
) -> dict[str, Any]:
count = max(1, int(count))
seed_step = int(seed_step)
seeds = [seed + index * seed_step for index in range(count)]
runs = [run_simulation(seed=current_seed, include_prompts=include_prompts) for current_seed in seeds]
issues: list[dict[str, Any]] = []
for run in runs:
run_seed = (run.get("summary") or {}).get("seed")
issues.extend({"seed": run_seed, **issue} for issue in run.get("issues") or [])
return {
"summary": {
"seed": seed,
"seed_step": seed_step,
"seeds": seeds,
"runs": len(runs),
"cases": sum((run.get("summary") or {}).get("cases", 0) for run in runs),
"coverage_checks": sum((run.get("summary") or {}).get("coverage_checks", 0) for run in runs),
"axis_checks": sum((run.get("summary") or {}).get("axis_checks", 0) for run in runs),
"pair_seed_checks": sum((run.get("summary") or {}).get("pair_seed_checks", 0) for run in runs),
"issues": len(issues),
},
"issues": issues,
"runs": runs,
}
def _print_text_report(report: dict[str, Any]) -> None:
summary = report.get("summary") or {}
print(
@@ -1490,17 +1532,49 @@ def _print_text_report(report: dict[str, Any]) -> None:
print(f" ISSUE {issue}")
def _print_sweep_report(report: dict[str, Any]) -> None:
summary = report.get("summary") or {}
seeds = ", ".join(str(seed) for seed in (summary.get("seeds") or []))
print(
f"Prompt route simulation sweep: seed={summary.get('seed')} "
f"seed_step={summary.get('seed_step')} runs={summary.get('runs')} "
f"seeds={seeds} cases={summary.get('cases')} coverage_checks={summary.get('coverage_checks')} "
f"axis_checks={summary.get('axis_checks')} pair_seed_checks={summary.get('pair_seed_checks')} "
f"issues={summary.get('issues')}"
)
for run in report.get("runs") or []:
run_summary = run.get("summary") or {}
print(
f"- seed {run_summary.get('seed')}: "
f"cases={run_summary.get('cases')} issues={run_summary.get('issues')}"
)
for issue in run.get("issues") or []:
print(f" ISSUE {issue.get('case')}: {issue.get('issue')}")
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--seed", type=int, default=3901, help="Base seed for deterministic simulations.")
parser.add_argument("--sweep-count", type=int, default=1, help="Run this many seed-spaced simulations.")
parser.add_argument("--seed-step", type=int, default=101, help="Seed increment used by --sweep-count.")
parser.add_argument("--json", action="store_true", help="Print the full JSON report.")
parser.add_argument("--include-prompts", action="store_true", help="Include raw and formatted prompt text in the report.")
parser.add_argument("--fail-on-issues", action="store_true", help="Exit with code 1 when any issue is reported.")
args = parser.parse_args(argv)
report = run_simulation(seed=args.seed, include_prompts=args.include_prompts)
if args.sweep_count > 1:
report = run_simulation_sweep(
seed=args.seed,
count=args.sweep_count,
seed_step=args.seed_step,
include_prompts=args.include_prompts,
)
else:
report = run_simulation(seed=args.seed, include_prompts=args.include_prompts)
if args.json:
print(json.dumps(report, ensure_ascii=True, indent=2, sort_keys=True))
elif args.sweep_count > 1:
_print_sweep_report(report)
else:
_print_text_report(report)
return 1 if args.fail_on_issues and report.get("issues") else 0
+6
View File
@@ -8025,6 +8025,12 @@ def smoke_prompt_route_simulation_policy() -> None:
pair_seed_checks["pair_seed.pose_reroll"].get("changed") is True,
"Pair pose reroll should prove hard action can reroll while soft/cast/scene axes stay locked",
)
sweep = prompt_route_simulation.run_simulation_sweep(seed=3901, count=3, seed_step=101, include_prompts=False)
sweep_summary = sweep.get("summary") or {}
_expect(sweep_summary.get("runs") == 3, "Prompt route simulation sweep lost run coverage")
_expect(sweep_summary.get("seeds") == [3901, 4002, 4103], "Prompt route simulation sweep seed sequence changed")
_expect(sweep_summary.get("cases") == 42, "Prompt route simulation sweep case count changed")
_expect(sweep_summary.get("issues") == 0, f"Prompt route simulation sweep reported issues: {sweep.get('issues')}")
def smoke_node_camera_registration() -> None: