diff --git a/tools/prompt_route_simulation.py b/tools/prompt_route_simulation.py index d3f186c..aa78813 100644 --- a/tools/prompt_route_simulation.py +++ b/tools/prompt_route_simulation.py @@ -21,6 +21,8 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import caption_naturalizer # noqa: E402 +import hardcore_action_metadata # noqa: E402 +import hardcore_position_config # noqa: E402 import krea_formatter # noqa: E402 import prompt_builder as pb # noqa: E402 import sdxl_formatter # noqa: E402 @@ -313,6 +315,18 @@ HARDCORE_ROUTE_CASES = ( "caption": ("foreplay action",), }, }, + { + "name": "hardcore.single.interaction", + "subcategory": "Aftercare and cleanup", + "focus": "interaction_only", + "family": "interaction", + "expected_route": {"action_family": "foreplay", "position_family": "interaction"}, + "expected_terms": { + "krea": ("mid-transition",), + "sdxl": ("interaction",), + "caption": ("interaction beat",), + }, + }, { "name": "hardcore.single.anal", "subcategory": "Anal and double penetration", @@ -363,6 +377,14 @@ HARDCORE_ROUTE_CASES = ( }, ) +ROUTE_SIM_ACTION_FAMILY_EXCLUSIONS = { + "default", + # Dedicated double-contact route assertions live in prompt_smoke because they + # need a multi-person, position-specific fixture rather than a broad family case. + "toy_double", +} +ROUTE_SIM_POSITION_FAMILY_EXCLUSIONS = {"any"} + def _format_metadata(metadata: dict[str, Any], target: str) -> dict[str, Any]: metadata_json = _json(metadata) @@ -996,6 +1018,61 @@ def _seed_axis_checks(seed: int) -> list[dict[str, Any]]: ] +def _route_family_coverage_check( + name: str, + *, + expected: set[str], + observed: set[str], +) -> dict[str, Any]: + missing = sorted(expected - observed) + unexpected = sorted(observed - expected) + issues: list[str] = [] + if missing: + issues.append(f"{name}: missing_family_coverage:{missing}") + if unexpected: + issues.append(f"{name}: unexpected_family_coverage:{unexpected}") + return { + "name": name, + "expected": sorted(expected), + "observed": sorted(observed), + "missing": missing, + "unexpected": unexpected, + "issues": issues, + } + + +def _route_family_coverage_checks(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: + summaries = [ + case.get("summary") or {} + for case in cases + if case.get("target") in ("single", "hardcore") + ] + observed_actions = { + hardcore_action_metadata.normalize_hardcore_action_family(summary.get("action_family"), "") + for summary in summaries + } + observed_actions.discard("") + observed_positions = { + hardcore_position_config.normalize_hardcore_position_family(summary.get("position_family"), "") + for summary in summaries + } + observed_positions.discard("") + expected_actions = set(hardcore_action_metadata.HARDCORE_ACTION_FAMILY_CHOICES) - ROUTE_SIM_ACTION_FAMILY_EXCLUSIONS + expected_positions = set(hardcore_position_config.hardcore_position_family_choices()) - ROUTE_SIM_POSITION_FAMILY_EXCLUSIONS + return [ + _route_family_coverage_check( + "route_coverage.action_families", + expected=expected_actions, + observed=observed_actions, + ), + _route_family_coverage_check( + "route_coverage.position_families", + expected=expected_positions, + observed=observed_positions, + ), + ] + + def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[str, Any]: cases: list[dict[str, Any]] = [] regular = _regular_single_case(seed) @@ -1021,12 +1098,18 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s cases.extend(_pair_reports("insta_pair.penetration", penetration_pair, include_prompts=include_prompts)) pov_pair = _insta_pair_case(seed + 2, pov=True, position="penis_licking", focus="outercourse_only", family="outercourse") cases.extend(_pair_reports("insta_pair.pov_outercourse", pov_pair, include_prompts=include_prompts)) + coverage_checks = _route_family_coverage_checks(cases) axis_checks = _seed_axis_checks(seed + 3) issues = [ {"case": case["name"], "issue": issue} for case in cases for issue in case.get("issues", []) ] + issues.extend( + {"case": check["name"], "issue": issue} + for check in coverage_checks + for issue in check.get("issues", []) + ) issues.extend( {"case": check["name"], "issue": issue} for check in axis_checks @@ -1036,11 +1119,13 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s "summary": { "seed": seed, "cases": len(cases), + "coverage_checks": len(coverage_checks), "axis_checks": len(axis_checks), "issues": len(issues), }, "issues": issues, "cases": cases, + "coverage_checks": coverage_checks, "axis_checks": axis_checks, } @@ -1049,7 +1134,8 @@ def _print_text_report(report: dict[str, Any]) -> None: summary = report.get("summary") or {} print( f"Prompt route simulation: seed={summary.get('seed')} " - f"cases={summary.get('cases')} axis_checks={summary.get('axis_checks')} issues={summary.get('issues')}" + f"cases={summary.get('cases')} coverage_checks={summary.get('coverage_checks')} " + f"axis_checks={summary.get('axis_checks')} issues={summary.get('issues')}" ) for case in report.get("cases") or []: summary_text = case.get("summary") or {} @@ -1057,6 +1143,13 @@ def _print_text_report(report: dict[str, Any]) -> None: print(f"- {case.get('name')} [{case.get('target')}]: {route}") for issue in case.get("issues") or []: print(f" ISSUE {issue}") + for check in report.get("coverage_checks") or []: + print( + f"- {check.get('name')}: " + f"observed={', '.join(check.get('observed') or [])}" + ) + for issue in check.get("issues") or []: + print(f" ISSUE {issue}") for check in report.get("axis_checks") or []: print(f"- {check.get('name')}: changed={check.get('changed')}") for issue in check.get("issues") or []: diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 58b149e..bc90331 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -7911,7 +7911,8 @@ def smoke_seed_config_policy() -> None: def smoke_prompt_route_simulation_policy() -> None: report = prompt_route_simulation.run_simulation(seed=3901, include_prompts=False) summary = report.get("summary") or {} - _expect(summary.get("cases") == 13, "Prompt route simulation case count changed unexpectedly") + _expect(summary.get("cases") == 14, "Prompt route simulation case count changed unexpectedly") + _expect(summary.get("coverage_checks") == 2, "Prompt route simulation lost family coverage checks") _expect(summary.get("axis_checks") == 6, "Prompt route simulation lost axis check coverage") _expect(summary.get("issues") == 0, f"Prompt route simulation reported issues: {report.get('issues')}") cases = {case.get("name"): case for case in report.get("cases") or []} @@ -7920,10 +7921,29 @@ def smoke_prompt_route_simulation_policy() -> None: "hardcore.single.manual", "hardcore.single.outercourse", "hardcore.single.foreplay", + "hardcore.single.interaction", "hardcore.single.anal", + "hardcore.single.threesome", + "hardcore.single.group", "hardcore.single.climax", + "insta_pair.penetration.hardcore", ): _expect(route_name in cases, f"Prompt route simulation lost route family case {route_name}") + coverage_checks = {check.get("name"): check for check in report.get("coverage_checks") or []} + action_coverage = coverage_checks.get("route_coverage.action_families") or {} + position_coverage = coverage_checks.get("route_coverage.position_families") or {} + _expect(not action_coverage.get("issues"), f"Prompt route simulation action coverage failed: {action_coverage}") + _expect(not position_coverage.get("issues"), f"Prompt route simulation position coverage failed: {position_coverage}") + expected_actions = ( + set(hardcore_action_metadata.HARDCORE_ACTION_FAMILY_CHOICES) + - prompt_route_simulation.ROUTE_SIM_ACTION_FAMILY_EXCLUSIONS + ) + expected_positions = ( + set(hardcore_position_config.hardcore_position_family_choices()) + - prompt_route_simulation.ROUTE_SIM_POSITION_FAMILY_EXCLUSIONS + ) + _expect(set(action_coverage.get("observed") or []) == expected_actions, "Prompt route simulation action coverage drifted") + _expect(set(position_coverage.get("observed") or []) == expected_positions, "Prompt route simulation position coverage drifted") pov_hard = cases.get("insta_pair.pov_outercourse.hardcore") or {} pov_summary = pov_hard.get("summary") or {} _expect(