diff --git a/sdxl_tag_policy.py b/sdxl_tag_policy.py index 8ef733d..edd692d 100644 --- a/sdxl_tag_policy.py +++ b/sdxl_tag_policy.py @@ -21,6 +21,17 @@ except ImportError: # Allows local smoke tests with `python -c`. PROMPT_FIELD_LABELS = input_policy.prompt_field_labels() +INCOMPATIBLE_ROUTE_TAGS = { + "action:penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"), + "action:oral": ("penetrative sex", "penetration", "anal sex", "outercourse"), + "action:outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex", "manual stimulation"), + "position:penetrative": ("oral sex", "outercourse", "anal sex", "manual stimulation"), + "position:oral": ("penetrative sex", "penetration", "anal sex", "outercourse"), + "position:outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex", "manual stimulation"), + "position:manual": ("penetrative sex", "penetration", "oral sex", "anal sex", "outercourse"), + "position:anal": ("oral sex", "outercourse", "manual stimulation"), +} + def clean(value: Any) -> str: return input_policy.clean_text(value) @@ -238,7 +249,33 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]: tags.append("penetration") if "vaginal" in lower: tags.append("pussy") - if "oral" in lower or "mouth" in lower: + oral_terms = ( + "oral sex", + "oral-sex", + "blowjob", + "deepthroat", + "fellatio", + "cunnilingus", + "pussy licking", + "mouth on", + "mouth pressed", + "mouth contact", + "mouth around", + "lips wrapped", + "takes the penis in her mouth", + "takes the man's penis", + "takes the viewer's penis", + "penis in her mouth", + "tongue on pussy", + "tongue along the penis", + "tongue along the penis shaft", + "tongue touches the underside", + "licking the penis", + "testicle sucking", + "balls licking", + "balls-licking", + ) + if any(token in lower for token in oral_terms): tags.append("oral sex") if "anal" in lower: tags.append("anal sex") @@ -247,6 +284,18 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]: return tags +def filter_incompatible_route_tags(tags: list[str], row: dict[str, Any]) -> list[str]: + action_family = route_metadata_policy.row_action_family(row) + position_family = route_metadata_policy.row_position_family(row) + blocked: set[str] = set() + for scope, family in (("action", action_family), ("position", position_family)): + for tag in INCOMPATIBLE_ROUTE_TAGS.get(f"{scope}:{family}", ()): + blocked.add(tag_key(tag)) + if not blocked: + return tags + return [tag for tag in tags if tag_key(tag) not in blocked] + + def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]: tags = ["softcore teaser", softcore_text_policy.softcore_style_tag()] options = root.get("options") if isinstance(root.get("options"), dict) else {} @@ -274,5 +323,6 @@ def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies: axis_value_tags=axis_value_tags, camera_tags=camera_tags, explicit_tags=explicit_tags, + filter_incompatible_route_tags=filter_incompatible_route_tags, softcore_pair_tags=softcore_pair_tags, ) diff --git a/sdxl_tag_routes.py b/sdxl_tag_routes.py index 2918781..47c2e56 100644 --- a/sdxl_tag_routes.py +++ b/sdxl_tag_routes.py @@ -42,6 +42,7 @@ class SDXLTagRouteDependencies: axis_value_tags: Callable[[dict[str, Any]], list[str]] camera_tags: Callable[..., list[str]] explicit_tags: Callable[[str, float], list[str]] + filter_incompatible_route_tags: Callable[[list[str], dict[str, Any]], list[str]] softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]] @@ -148,6 +149,7 @@ def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependenc ) for tag in deps.explicit_tags(combined, request.nude_weight): deps.add_one(tags, seen, tag) + tags = deps.filter_incompatible_route_tags(tags, row) return SDXLTagRoute(tags) @@ -231,6 +233,7 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition]) for tag in deps.explicit_tags(combined, request.nude_weight): deps.add_one(tags, seen, tag) + tags = deps.filter_incompatible_route_tags(tags, row) return SDXLTagRoute(tags) diff --git a/tools/prompt_route_simulation.py b/tools/prompt_route_simulation.py index 8a7846d..b76234d 100644 --- a/tools/prompt_route_simulation.py +++ b/tools/prompt_route_simulation.py @@ -52,6 +52,14 @@ HARDCORE_NOISE_TERMS = ( "the scene contains", ) +INCOMPATIBLE_SDXL_TAGS = { + "penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"), + "oral": ("penetrative sex", "penetration", "anal sex", "outercourse"), + "outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex"), + "manual": ("penetrative sex", "penetration", "oral sex", "anal sex"), + "anal": ("oral sex", "outercourse"), +} + def _json(value: Any) -> str: return json.dumps(value, ensure_ascii=True, sort_keys=True) @@ -235,7 +243,13 @@ def _text_issues(label: str, value: Any, *, min_len: int = 8) -> list[str]: return issues -def _formatter_issues(name: str, formats: dict[str, Any], *, is_pov: bool = False) -> list[str]: +def _formatter_issues( + name: str, + formats: dict[str, Any], + *, + row: dict[str, Any] | None = None, + is_pov: bool = False, +) -> list[str]: issues: list[str] = [] krea = formats["krea"] sdxl = formats["sdxl"] @@ -274,6 +288,12 @@ def _formatter_issues(name: str, formats: dict[str, Any], *, is_pov: bool = Fals for noise in HARDCORE_NOISE_TERMS: if noise in lower_krea: issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}") + if isinstance(row, dict): + family = str(row.get("action_family") or "").strip() + sdxl_lower = f", {sdxl_prompt.lower()}, " + for tag in INCOMPATIBLE_SDXL_TAGS.get(family, ()): + if f", {tag}, " in sdxl_lower: + issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{family}:{tag}") if is_pov: if "viewer" not in lower_krea or "first-person" not in lower_krea: issues.append(f"{name}.krea_prompt: pov_wording_missing") @@ -324,7 +344,7 @@ def _case_report( is_pov: bool = False, ) -> dict[str, Any]: formats = _format_metadata(metadata, target) - issues = _formatter_issues(name, formats, is_pov=is_pov) + issues = _formatter_issues(name, formats, row=metadata, is_pov=is_pov) issues.extend(_route_metadata_issues(name, metadata)) if target == "softcore": issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt"))) @@ -354,11 +374,11 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) -> hard_row = dict(pair.get("hardcore_row") or {}) soft_formats = _format_metadata(pair, "softcore") hard_formats = _format_metadata(pair, "hardcore") - soft_issues = _formatter_issues(f"{name}.softcore", soft_formats) + soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row) soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row)) soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt"))) hard_is_pov = bool(hard_row.get("pov_character_labels")) - hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, is_pov=hard_is_pov) + hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov) hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row)) reports = [ { diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 6df3182..0b28028 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -4604,7 +4604,30 @@ def smoke_sdxl_tag_policy() -> None: _expect(deps.axis_value_tags is sdxl_tag_policy.axis_value_tags, "SDXL route deps lost axis-value tag policy") _expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy") _expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy") + _expect( + deps.filter_incompatible_route_tags is sdxl_tag_policy.filter_incompatible_route_tags, + "SDXL route deps lost route-family tag filter", + ) _expect(deps.softcore_pair_tags is sdxl_tag_policy.softcore_pair_tags, "SDXL route deps lost softcore pair tag policy") + mouth_nearby_tags = sdxl_tag_policy.explicit_tags( + "missionary penetration with mouth close to the ear", + 1.29, + ) + _expect("penetration" in mouth_nearby_tags, "SDXL explicit tags lost penetration signal") + _expect("oral sex" not in mouth_nearby_tags, "SDXL explicit tags should not treat nearby mouth wording as oral") + outercourse_filtered_tags = sdxl_tag_policy.filter_incompatible_route_tags( + ["outercourse", "penis licking", "oral sex", "penetration"], + _fixture_hardcore_row( + action_family="outercourse", + position_family="outercourse", + position_key="penis_licking", + position_keys=["penis_licking"], + ), + ) + _expect("outercourse" in outercourse_filtered_tags, "SDXL route filter removed matching outercourse tag") + _expect("penis licking" in outercourse_filtered_tags, "SDXL route filter removed specific outercourse key") + _expect("oral sex" not in outercourse_filtered_tags, "SDXL route filter kept incompatible oral tag") + _expect("penetration" not in outercourse_filtered_tags, "SDXL route filter kept incompatible penetration tag") stale_character_row = { "prompt": "Characters: 99-year-old adult man, stale body, stale skin, stale hair, stale eyes.", @@ -4731,6 +4754,28 @@ def smoke_sdxl_tag_routes() -> None: ).as_text() for required in ("(naked:1.29)", "pussy", "penis", "penetration"): _expect(required in metadata_tags, f"SDXL row tags lost structured explicit metadata tag: {required}") + outercourse_noise_row = _fixture_hardcore_row( + item="penis-licking outercourse position with tongue along the penis shaft", + pose="configured outercourse pose", + role_graph="Woman A bends low while her tongue runs along Man A's penis shaft.", + source_role_graph="Woman A bends low while her tongue runs along Man A's penis shaft.", + item_axis_values={ + "position": "penis-licking outercourse position", + "outer_act": "tongue along the penis shaft", + }, + action_family="outercourse", + position_family="outercourse", + position_key="penis_licking", + position_keys=["penis_licking"], + ) + outercourse_noise_tags = sdxl_tag_routes.row_core_tags_result( + sdxl_tag_routes.SDXLRowTagRequest(outercourse_noise_row, 1.29), + deps, + ).as_text() + _expect("outercourse" in outercourse_noise_tags, "SDXL outercourse row lost matching family tag") + _expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key") + _expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag") + _expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag") pair = pb.build_insta_of_pair( row_number=1,