Add multi-seed route simulation sweep
This commit is contained in:
@@ -114,6 +114,10 @@ AUDIT_DOC_SNIPPETS: tuple[tuple[str, str], ...] = (
|
||||
"docs/prompt-pool-routing-map.md",
|
||||
"repeated cast descriptors in training-caption formatter output",
|
||||
),
|
||||
(
|
||||
"docs/prompt-pool-routing-map.md",
|
||||
"multi-seed route sweeps",
|
||||
),
|
||||
)
|
||||
|
||||
PROMPT_ROW_READ_SCAN_GLOBS: tuple[str, ...] = (
|
||||
|
||||
@@ -308,6 +308,7 @@ HARDCORE_ROUTE_CASES = (
|
||||
"subcategory": "Foreplay and teasing",
|
||||
"focus": "foreplay_only",
|
||||
"family": "foreplay",
|
||||
"positions": ("undressing",),
|
||||
"expected_route": {"action_family": "foreplay", "position_family": "foreplay"},
|
||||
"expected_terms": {
|
||||
"krea": ("clothing",),
|
||||
@@ -317,12 +318,13 @@ HARDCORE_ROUTE_CASES = (
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.interaction",
|
||||
"subcategory": "Aftercare and cleanup",
|
||||
"subcategory": "Clothing and position transitions",
|
||||
"focus": "interaction_only",
|
||||
"family": "interaction",
|
||||
"positions": ("position_transition",),
|
||||
"expected_route": {"action_family": "foreplay", "position_family": "interaction"},
|
||||
"expected_terms": {
|
||||
"krea": ("mid-transition",),
|
||||
"krea": ("clothing",),
|
||||
"sdxl": ("interaction",),
|
||||
"caption": ("interaction beat",),
|
||||
},
|
||||
@@ -332,9 +334,10 @@ HARDCORE_ROUTE_CASES = (
|
||||
"subcategory": "Anal and double penetration",
|
||||
"focus": "anal_only",
|
||||
"family": "anal",
|
||||
"positions": ("doggy", "face_down_ass_up"),
|
||||
"expected_route": {"action_family": "anal", "position_family": "anal"},
|
||||
"expected_terms": {
|
||||
"krea": ("anal",),
|
||||
"krea": ("ass",),
|
||||
"sdxl": ("anal sex",),
|
||||
"caption": ("anal action",),
|
||||
},
|
||||
@@ -853,7 +856,13 @@ def _regular_single_case(seed: int) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str) -> dict[str, Any]:
|
||||
def _hardcore_single_case(
|
||||
seed: int,
|
||||
subcategory: str,
|
||||
focus: str,
|
||||
family: str,
|
||||
positions: list[str] | tuple[str, ...] | str = (),
|
||||
) -> dict[str, Any]:
|
||||
women_count, men_count, character_cast = {
|
||||
"threesome": (1, 2, _character_cast_subjects(("woman", "man", "man"))),
|
||||
"group": (2, 2, _character_cast_subjects(("woman", "woman", "man", "man"))),
|
||||
@@ -881,7 +890,7 @@ def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str)
|
||||
women_count=women_count,
|
||||
men_count=men_count,
|
||||
character_cast=character_cast,
|
||||
hardcore_position_config=_position_filter(focus, family, []),
|
||||
hardcore_position_config=_position_filter(focus, family, positions),
|
||||
location_config=_coworking_location_config(),
|
||||
camera_config=_orbit_camera(horizontal_angle=35, vertical_angle=0, zoom=6.5),
|
||||
)
|
||||
@@ -1404,6 +1413,7 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
|
||||
str(route_case["subcategory"]),
|
||||
str(route_case["focus"]),
|
||||
str(route_case["family"]),
|
||||
route_case.get("positions") or (),
|
||||
)
|
||||
cases.append(
|
||||
_case_report(
|
||||
@@ -1459,6 +1469,38 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
|
||||
}
|
||||
|
||||
|
||||
def run_simulation_sweep(
|
||||
seed: int = 3901,
|
||||
*,
|
||||
count: int = 3,
|
||||
seed_step: int = 101,
|
||||
include_prompts: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
count = max(1, int(count))
|
||||
seed_step = int(seed_step)
|
||||
seeds = [seed + index * seed_step for index in range(count)]
|
||||
runs = [run_simulation(seed=current_seed, include_prompts=include_prompts) for current_seed in seeds]
|
||||
issues: list[dict[str, Any]] = []
|
||||
for run in runs:
|
||||
run_seed = (run.get("summary") or {}).get("seed")
|
||||
issues.extend({"seed": run_seed, **issue} for issue in run.get("issues") or [])
|
||||
return {
|
||||
"summary": {
|
||||
"seed": seed,
|
||||
"seed_step": seed_step,
|
||||
"seeds": seeds,
|
||||
"runs": len(runs),
|
||||
"cases": sum((run.get("summary") or {}).get("cases", 0) for run in runs),
|
||||
"coverage_checks": sum((run.get("summary") or {}).get("coverage_checks", 0) for run in runs),
|
||||
"axis_checks": sum((run.get("summary") or {}).get("axis_checks", 0) for run in runs),
|
||||
"pair_seed_checks": sum((run.get("summary") or {}).get("pair_seed_checks", 0) for run in runs),
|
||||
"issues": len(issues),
|
||||
},
|
||||
"issues": issues,
|
||||
"runs": runs,
|
||||
}
|
||||
|
||||
|
||||
def _print_text_report(report: dict[str, Any]) -> None:
|
||||
summary = report.get("summary") or {}
|
||||
print(
|
||||
@@ -1490,17 +1532,49 @@ def _print_text_report(report: dict[str, Any]) -> None:
|
||||
print(f" ISSUE {issue}")
|
||||
|
||||
|
||||
def _print_sweep_report(report: dict[str, Any]) -> None:
|
||||
summary = report.get("summary") or {}
|
||||
seeds = ", ".join(str(seed) for seed in (summary.get("seeds") or []))
|
||||
print(
|
||||
f"Prompt route simulation sweep: seed={summary.get('seed')} "
|
||||
f"seed_step={summary.get('seed_step')} runs={summary.get('runs')} "
|
||||
f"seeds={seeds} cases={summary.get('cases')} coverage_checks={summary.get('coverage_checks')} "
|
||||
f"axis_checks={summary.get('axis_checks')} pair_seed_checks={summary.get('pair_seed_checks')} "
|
||||
f"issues={summary.get('issues')}"
|
||||
)
|
||||
for run in report.get("runs") or []:
|
||||
run_summary = run.get("summary") or {}
|
||||
print(
|
||||
f"- seed {run_summary.get('seed')}: "
|
||||
f"cases={run_summary.get('cases')} issues={run_summary.get('issues')}"
|
||||
)
|
||||
for issue in run.get("issues") or []:
|
||||
print(f" ISSUE {issue.get('case')}: {issue.get('issue')}")
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--seed", type=int, default=3901, help="Base seed for deterministic simulations.")
|
||||
parser.add_argument("--sweep-count", type=int, default=1, help="Run this many seed-spaced simulations.")
|
||||
parser.add_argument("--seed-step", type=int, default=101, help="Seed increment used by --sweep-count.")
|
||||
parser.add_argument("--json", action="store_true", help="Print the full JSON report.")
|
||||
parser.add_argument("--include-prompts", action="store_true", help="Include raw and formatted prompt text in the report.")
|
||||
parser.add_argument("--fail-on-issues", action="store_true", help="Exit with code 1 when any issue is reported.")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
report = run_simulation(seed=args.seed, include_prompts=args.include_prompts)
|
||||
if args.sweep_count > 1:
|
||||
report = run_simulation_sweep(
|
||||
seed=args.seed,
|
||||
count=args.sweep_count,
|
||||
seed_step=args.seed_step,
|
||||
include_prompts=args.include_prompts,
|
||||
)
|
||||
else:
|
||||
report = run_simulation(seed=args.seed, include_prompts=args.include_prompts)
|
||||
if args.json:
|
||||
print(json.dumps(report, ensure_ascii=True, indent=2, sort_keys=True))
|
||||
elif args.sweep_count > 1:
|
||||
_print_sweep_report(report)
|
||||
else:
|
||||
_print_text_report(report)
|
||||
return 1 if args.fail_on_issues and report.get("issues") else 0
|
||||
|
||||
@@ -8025,6 +8025,12 @@ def smoke_prompt_route_simulation_policy() -> None:
|
||||
pair_seed_checks["pair_seed.pose_reroll"].get("changed") is True,
|
||||
"Pair pose reroll should prove hard action can reroll while soft/cast/scene axes stay locked",
|
||||
)
|
||||
sweep = prompt_route_simulation.run_simulation_sweep(seed=3901, count=3, seed_step=101, include_prompts=False)
|
||||
sweep_summary = sweep.get("summary") or {}
|
||||
_expect(sweep_summary.get("runs") == 3, "Prompt route simulation sweep lost run coverage")
|
||||
_expect(sweep_summary.get("seeds") == [3901, 4002, 4103], "Prompt route simulation sweep seed sequence changed")
|
||||
_expect(sweep_summary.get("cases") == 42, "Prompt route simulation sweep case count changed")
|
||||
_expect(sweep_summary.get("issues") == 0, f"Prompt route simulation sweep reported issues: {sweep.get('issues')}")
|
||||
|
||||
|
||||
def smoke_node_camera_registration() -> None:
|
||||
|
||||
Reference in New Issue
Block a user