diff --git a/sdxl_tag_routes.py b/sdxl_tag_routes.py index 47c2e56..f5621c7 100644 --- a/sdxl_tag_routes.py +++ b/sdxl_tag_routes.py @@ -96,6 +96,14 @@ def _row_explicit_signal_text( return " ".join(deps.clean(value) for value in values if deps.clean(value)) +def _uses_hardcore_action_route(row: dict[str, Any]) -> bool: + return ( + str(row.get("category_slug") or "").strip() == "hardcore_sexual_poses" + or bool(str(row.get("action_family") or "").strip()) + or bool(str(row.get("position_family") or "").strip()) + ) + + def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: row = request.row tags: list[str] = [] @@ -117,7 +125,7 @@ def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependenc item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean( row.get("custom_item") ) - pose = deps.row_value(row, "pose", ("Sexual pose", "Pose")) + pose = "" if _uses_hardcore_action_route(row) else deps.row_value(row, "pose", ("Sexual pose", "Pose")) role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph")) scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene")) expression = deps.row_value(row, "character_expression_text") or deps.row_value( diff --git a/tools/prompt_route_simulation.py b/tools/prompt_route_simulation.py index b76234d..3110be8 100644 --- a/tools/prompt_route_simulation.py +++ b/tools/prompt_route_simulation.py @@ -24,6 +24,7 @@ import caption_naturalizer # noqa: E402 import krea_formatter # noqa: E402 import prompt_builder as pb # noqa: E402 import sdxl_formatter # noqa: E402 +import sdxl_tag_policy # noqa: E402 TRIGGER = "sxcppnl7" @@ -52,14 +53,6 @@ HARDCORE_NOISE_TERMS = ( "the scene contains", ) -INCOMPATIBLE_SDXL_TAGS = { - "penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"), - "oral": ("penetrative sex", "penetration", "anal sex", "outercourse"), - "outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex"), - "manual": ("penetrative sex", "penetration", "oral sex", "anal sex"), - "anal": ("oral sex", "outercourse"), -} - def _json(value: Any) -> str: return json.dumps(value, ensure_ascii=True, sort_keys=True) @@ -178,6 +171,82 @@ def _insta_options() -> str: ) +HARDCORE_ROUTE_CASES = ( + { + "name": "hardcore.single.oral", + "subcategory": "Oral sex", + "focus": "oral_only", + "family": "oral", + "expected_route": {"action_family": "oral", "position_family": "oral"}, + "expected_terms": { + "krea": ("mouth",), + "sdxl": ("oral sex",), + "caption": ("oral action",), + }, + }, + { + "name": "hardcore.single.manual", + "subcategory": "Manual stimulation", + "focus": "manual_only", + "family": "manual", + "expected_route": {"position_family": "manual"}, + "expected_terms": { + "krea": ("hand",), + "sdxl": ("manual stimulation",), + "caption": ("manual action",), + }, + }, + { + "name": "hardcore.single.outercourse", + "subcategory": "Outercourse and genital teasing", + "focus": "outercourse_only", + "family": "outercourse", + "expected_route": {"action_family": "outercourse", "position_family": "outercourse"}, + "expected_terms": { + "krea": ("penis",), + "sdxl": ("outercourse",), + "caption": ("non-penetrative action",), + }, + }, + { + "name": "hardcore.single.foreplay", + "subcategory": "Foreplay and teasing", + "focus": "foreplay_only", + "family": "foreplay", + "expected_route": {"action_family": "foreplay", "position_family": "foreplay"}, + "expected_terms": { + "krea": ("clothing",), + "sdxl": ("foreplay",), + "caption": ("foreplay action",), + }, + }, + { + "name": "hardcore.single.anal", + "subcategory": "Anal and double penetration", + "focus": "anal_only", + "family": "anal", + "expected_route": {"position_family": "anal"}, + "expected_terms": { + "krea": ("anal",), + "sdxl": ("anal sex",), + "caption": ("anal action",), + }, + }, + { + "name": "hardcore.single.climax", + "subcategory": "Cumshot and climax", + "focus": "climax_only", + "family": "climax", + "expected_route": {"action_family": "climax", "position_family": "climax"}, + "expected_terms": { + "krea": ("ejaculation",), + "sdxl": ("climax", "semen"), + "caption": ("climax action",), + }, + }, +) + + def _format_metadata(metadata: dict[str, Any], target: str) -> dict[str, Any]: metadata_json = _json(metadata) krea = krea_formatter.format_krea2_prompt( @@ -243,11 +312,36 @@ def _text_issues(label: str, value: Any, *, min_len: int = 8) -> list[str]: return issues +def _contains_all(text: str, required: tuple[str, ...]) -> bool: + lower = text.lower() + return all(term.lower() in lower for term in required) + + +def _formatter_expectation_issues( + name: str, + formats: dict[str, Any], + expected_terms: dict[str, tuple[str, ...]] | None, +) -> list[str]: + if not expected_terms: + return [] + prompts = { + "krea": str(formats["krea"].get("krea_prompt") or ""), + "sdxl": str(formats["sdxl"].get("sdxl_prompt") or ""), + "caption": str(formats["caption"].get("natural_caption") or ""), + } + issues: list[str] = [] + for formatter_name, required in expected_terms.items(): + if required and not _contains_all(prompts.get(formatter_name, ""), required): + issues.append(f"{name}.{formatter_name}: missing_route_terms:{required}") + return issues + + def _formatter_issues( name: str, formats: dict[str, Any], *, row: dict[str, Any] | None = None, + expected_terms: dict[str, tuple[str, ...]] | None = None, is_pov: bool = False, ) -> list[str]: issues: list[str] = [] @@ -289,11 +383,13 @@ def _formatter_issues( if noise in lower_krea: issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}") if isinstance(row, dict): - family = str(row.get("action_family") or "").strip() sdxl_lower = f", {sdxl_prompt.lower()}, " - for tag in INCOMPATIBLE_SDXL_TAGS.get(family, ()): - if f", {tag}, " in sdxl_lower: - issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{family}:{tag}") + for scope, family in (("action", row.get("action_family")), ("position", row.get("position_family"))): + route_key = f"{scope}:{str(family or '').strip()}" + for tag in sdxl_tag_policy.INCOMPATIBLE_ROUTE_TAGS.get(route_key, ()): + if f", {tag}, " in sdxl_lower: + issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{route_key}:{tag}") + issues.extend(_formatter_expectation_issues(name, formats, expected_terms)) if is_pov: if "viewer" not in lower_krea or "first-person" not in lower_krea: issues.append(f"{name}.krea_prompt: pov_wording_missing") @@ -335,17 +431,36 @@ def _route_metadata_issues(name: str, row: dict[str, Any]) -> list[str]: return [] +def _route_expectation_issues(name: str, row: dict[str, Any], expected_route: dict[str, Any] | None) -> list[str]: + if not expected_route: + return [] + issues: list[str] = [] + for key in ("action_family", "position_family", "position_key"): + expected = expected_route.get(key) + if expected and row.get(key) != expected: + issues.append(f"{name}: {key}_mismatch:{row.get(key)} != {expected}") + for key, expected_values in (("position_keys", expected_route.get("position_keys") or ()),): + current = set(str(value) for value in (row.get(key) or [])) + for value in expected_values: + if str(value) not in current: + issues.append(f"{name}: missing_{key}:{value}") + return issues + + def _case_report( name: str, metadata: dict[str, Any], *, target: str, include_prompts: bool, + expected_route: dict[str, Any] | None = None, + expected_terms: dict[str, tuple[str, ...]] | None = None, is_pov: bool = False, ) -> dict[str, Any]: formats = _format_metadata(metadata, target) - issues = _formatter_issues(name, formats, row=metadata, is_pov=is_pov) + issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov) issues.extend(_route_metadata_issues(name, metadata)) + issues.extend(_route_expectation_issues(name, metadata, expected_route)) if target == "softcore": issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt"))) report = { @@ -435,6 +550,36 @@ def _regular_single_case(seed: int) -> dict[str, Any]: ) +def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str) -> dict[str, Any]: + return pb.build_prompt( + category="Hardcore sexual poses", + subcategory=subcategory, + row_number=1, + start_index=1, + seed=seed, + clothing="random", + ethnicity="any", + poses="random", + backside_bias=0.0, + figure="random", + no_plus_women=False, + no_black=False, + minimal_clothing_ratio=-1, + standard_pose_ratio=-1, + trigger=TRIGGER, + prepend_trigger_to_prompt=True, + extra_positive="", + extra_negative="", + seed_config=pb.build_seed_lock_config_json(base_seed=seed), + women_count=1, + men_count=1, + character_cast=_character_cast(), + hardcore_position_config=_position_filter(focus, family, []), + location_config=_coworking_location_config(), + camera_config=_orbit_camera(horizontal_angle=35, vertical_angle=0, zoom=6.5), + ) + + def _insta_pair_case(seed: int, *, pov: bool, position: str, focus: str, family: str) -> dict[str, Any]: return pb.build_insta_of_pair( row_number=1, @@ -539,6 +684,23 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s cases: list[dict[str, Any]] = [] regular = _regular_single_case(seed) cases.append(_case_report("regular.single.casual", regular, target="single", include_prompts=include_prompts)) + for offset, route_case in enumerate(HARDCORE_ROUTE_CASES, start=10): + row = _hardcore_single_case( + seed + offset, + str(route_case["subcategory"]), + str(route_case["focus"]), + str(route_case["family"]), + ) + cases.append( + _case_report( + str(route_case["name"]), + row, + target="single", + include_prompts=include_prompts, + expected_route=route_case.get("expected_route"), + expected_terms=route_case.get("expected_terms"), + ) + ) penetration_pair = _insta_pair_case(seed + 1, pov=False, position="doggy", focus="penetration_only", family="penetration") cases.extend(_pair_reports("insta_pair.penetration", penetration_pair, include_prompts=include_prompts)) pov_pair = _insta_pair_case(seed + 2, pov=True, position="penis_licking", focus="outercourse_only", family="outercourse") diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 0b28028..29dae03 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -4776,6 +4776,27 @@ def smoke_sdxl_tag_routes() -> None: _expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key") _expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag") _expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag") + stale_hardcore_pose_row = _fixture_hardcore_row( + item="oral contact with mouth on the visible genitals in side-lying oral position", + pose="kneeling and balancing a cucumber upright on an open palm held overhead", + role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.", + source_role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.", + item_axis_values={ + "position": "side-lying oral position", + "oral_act": "oral contact with mouth on the visible genitals", + }, + action_family="oral", + position_family="oral", + position_key="side_lying", + position_keys=["side_lying"], + ) + stale_hardcore_pose_tags = sdxl_tag_routes.row_core_tags_result( + sdxl_tag_routes.SDXLRowTagRequest(stale_hardcore_pose_row, 1.29), + deps, + ).as_text() + _expect("oral sex" in stale_hardcore_pose_tags, "SDXL hardcore route lost oral family tag") + _expect("side lying" in stale_hardcore_pose_tags, "SDXL hardcore route lost structured position key") + _expect("cucumber" not in stale_hardcore_pose_tags, "SDXL hardcore route leaked generic stale pose text") pair = pb.build_insta_of_pair( row_number=1, @@ -7812,10 +7833,19 @@ def smoke_seed_config_policy() -> None: def smoke_prompt_route_simulation_policy() -> None: report = prompt_route_simulation.run_simulation(seed=3901, include_prompts=False) summary = report.get("summary") or {} - _expect(summary.get("cases") == 5, "Prompt route simulation case count changed unexpectedly") + _expect(summary.get("cases") == 11, "Prompt route simulation case count changed unexpectedly") _expect(summary.get("axis_checks") == 1, "Prompt route simulation lost axis check coverage") _expect(summary.get("issues") == 0, f"Prompt route simulation reported issues: {report.get('issues')}") cases = {case.get("name"): case for case in report.get("cases") or []} + for route_name in ( + "hardcore.single.oral", + "hardcore.single.manual", + "hardcore.single.outercourse", + "hardcore.single.foreplay", + "hardcore.single.anal", + "hardcore.single.climax", + ): + _expect(route_name in cases, f"Prompt route simulation lost route family case {route_name}") pov_hard = cases.get("insta_pair.pov_outercourse.hardcore") or {} pov_summary = pov_hard.get("summary") or {} _expect(