Filter incompatible SDXL route tags
This commit is contained in:
+51
-1
@@ -21,6 +21,17 @@ except ImportError: # Allows local smoke tests with `python -c`.
|
|||||||
|
|
||||||
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
|
PROMPT_FIELD_LABELS = input_policy.prompt_field_labels()
|
||||||
|
|
||||||
|
INCOMPATIBLE_ROUTE_TAGS = {
|
||||||
|
"action:penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"),
|
||||||
|
"action:oral": ("penetrative sex", "penetration", "anal sex", "outercourse"),
|
||||||
|
"action:outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex", "manual stimulation"),
|
||||||
|
"position:penetrative": ("oral sex", "outercourse", "anal sex", "manual stimulation"),
|
||||||
|
"position:oral": ("penetrative sex", "penetration", "anal sex", "outercourse"),
|
||||||
|
"position:outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex", "manual stimulation"),
|
||||||
|
"position:manual": ("penetrative sex", "penetration", "oral sex", "anal sex", "outercourse"),
|
||||||
|
"position:anal": ("oral sex", "outercourse", "manual stimulation"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def clean(value: Any) -> str:
|
def clean(value: Any) -> str:
|
||||||
return input_policy.clean_text(value)
|
return input_policy.clean_text(value)
|
||||||
@@ -238,7 +249,33 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]:
|
|||||||
tags.append("penetration")
|
tags.append("penetration")
|
||||||
if "vaginal" in lower:
|
if "vaginal" in lower:
|
||||||
tags.append("pussy")
|
tags.append("pussy")
|
||||||
if "oral" in lower or "mouth" in lower:
|
oral_terms = (
|
||||||
|
"oral sex",
|
||||||
|
"oral-sex",
|
||||||
|
"blowjob",
|
||||||
|
"deepthroat",
|
||||||
|
"fellatio",
|
||||||
|
"cunnilingus",
|
||||||
|
"pussy licking",
|
||||||
|
"mouth on",
|
||||||
|
"mouth pressed",
|
||||||
|
"mouth contact",
|
||||||
|
"mouth around",
|
||||||
|
"lips wrapped",
|
||||||
|
"takes the penis in her mouth",
|
||||||
|
"takes the man's penis",
|
||||||
|
"takes the viewer's penis",
|
||||||
|
"penis in her mouth",
|
||||||
|
"tongue on pussy",
|
||||||
|
"tongue along the penis",
|
||||||
|
"tongue along the penis shaft",
|
||||||
|
"tongue touches the underside",
|
||||||
|
"licking the penis",
|
||||||
|
"testicle sucking",
|
||||||
|
"balls licking",
|
||||||
|
"balls-licking",
|
||||||
|
)
|
||||||
|
if any(token in lower for token in oral_terms):
|
||||||
tags.append("oral sex")
|
tags.append("oral sex")
|
||||||
if "anal" in lower:
|
if "anal" in lower:
|
||||||
tags.append("anal sex")
|
tags.append("anal sex")
|
||||||
@@ -247,6 +284,18 @@ def explicit_tags(text: str, nude_weight: float) -> list[str]:
|
|||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
def filter_incompatible_route_tags(tags: list[str], row: dict[str, Any]) -> list[str]:
|
||||||
|
action_family = route_metadata_policy.row_action_family(row)
|
||||||
|
position_family = route_metadata_policy.row_position_family(row)
|
||||||
|
blocked: set[str] = set()
|
||||||
|
for scope, family in (("action", action_family), ("position", position_family)):
|
||||||
|
for tag in INCOMPATIBLE_ROUTE_TAGS.get(f"{scope}:{family}", ()):
|
||||||
|
blocked.add(tag_key(tag))
|
||||||
|
if not blocked:
|
||||||
|
return tags
|
||||||
|
return [tag for tag in tags if tag_key(tag) not in blocked]
|
||||||
|
|
||||||
|
|
||||||
def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]:
|
def softcore_pair_tags(row: dict[str, Any], root: dict[str, Any]) -> list[str]:
|
||||||
tags = ["softcore teaser", softcore_text_policy.softcore_style_tag()]
|
tags = ["softcore teaser", softcore_text_policy.softcore_style_tag()]
|
||||||
options = root.get("options") if isinstance(root.get("options"), dict) else {}
|
options = root.get("options") if isinstance(root.get("options"), dict) else {}
|
||||||
@@ -274,5 +323,6 @@ def tag_route_dependencies() -> sdxl_tag_routes.SDXLTagRouteDependencies:
|
|||||||
axis_value_tags=axis_value_tags,
|
axis_value_tags=axis_value_tags,
|
||||||
camera_tags=camera_tags,
|
camera_tags=camera_tags,
|
||||||
explicit_tags=explicit_tags,
|
explicit_tags=explicit_tags,
|
||||||
|
filter_incompatible_route_tags=filter_incompatible_route_tags,
|
||||||
softcore_pair_tags=softcore_pair_tags,
|
softcore_pair_tags=softcore_pair_tags,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class SDXLTagRouteDependencies:
|
|||||||
axis_value_tags: Callable[[dict[str, Any]], list[str]]
|
axis_value_tags: Callable[[dict[str, Any]], list[str]]
|
||||||
camera_tags: Callable[..., list[str]]
|
camera_tags: Callable[..., list[str]]
|
||||||
explicit_tags: Callable[[str, float], list[str]]
|
explicit_tags: Callable[[str, float], list[str]]
|
||||||
|
filter_incompatible_route_tags: Callable[[list[str], dict[str, Any]], list[str]]
|
||||||
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
|
softcore_pair_tags: Callable[[dict[str, Any], dict[str, Any]], list[str]]
|
||||||
|
|
||||||
|
|
||||||
@@ -148,6 +149,7 @@ def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependenc
|
|||||||
)
|
)
|
||||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||||
deps.add_one(tags, seen, tag)
|
deps.add_one(tags, seen, tag)
|
||||||
|
tags = deps.filter_incompatible_route_tags(tags, row)
|
||||||
return SDXLTagRoute(tags)
|
return SDXLTagRoute(tags)
|
||||||
|
|
||||||
|
|
||||||
@@ -231,6 +233,7 @@ def hard_tags_result(request: SDXLPairTagRequest, deps: SDXLTagRouteDependencies
|
|||||||
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition])
|
combined = " ".join([hard_role, hard_item, hard_clothing, expression, composition])
|
||||||
for tag in deps.explicit_tags(combined, request.nude_weight):
|
for tag in deps.explicit_tags(combined, request.nude_weight):
|
||||||
deps.add_one(tags, seen, tag)
|
deps.add_one(tags, seen, tag)
|
||||||
|
tags = deps.filter_incompatible_route_tags(tags, row)
|
||||||
return SDXLTagRoute(tags)
|
return SDXLTagRoute(tags)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,14 @@ HARDCORE_NOISE_TERMS = (
|
|||||||
"the scene contains",
|
"the scene contains",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
INCOMPATIBLE_SDXL_TAGS = {
|
||||||
|
"penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"),
|
||||||
|
"oral": ("penetrative sex", "penetration", "anal sex", "outercourse"),
|
||||||
|
"outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex"),
|
||||||
|
"manual": ("penetrative sex", "penetration", "oral sex", "anal sex"),
|
||||||
|
"anal": ("oral sex", "outercourse"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _json(value: Any) -> str:
|
def _json(value: Any) -> str:
|
||||||
return json.dumps(value, ensure_ascii=True, sort_keys=True)
|
return json.dumps(value, ensure_ascii=True, sort_keys=True)
|
||||||
@@ -235,7 +243,13 @@ def _text_issues(label: str, value: Any, *, min_len: int = 8) -> list[str]:
|
|||||||
return issues
|
return issues
|
||||||
|
|
||||||
|
|
||||||
def _formatter_issues(name: str, formats: dict[str, Any], *, is_pov: bool = False) -> list[str]:
|
def _formatter_issues(
|
||||||
|
name: str,
|
||||||
|
formats: dict[str, Any],
|
||||||
|
*,
|
||||||
|
row: dict[str, Any] | None = None,
|
||||||
|
is_pov: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
issues: list[str] = []
|
issues: list[str] = []
|
||||||
krea = formats["krea"]
|
krea = formats["krea"]
|
||||||
sdxl = formats["sdxl"]
|
sdxl = formats["sdxl"]
|
||||||
@@ -274,6 +288,12 @@ def _formatter_issues(name: str, formats: dict[str, Any], *, is_pov: bool = Fals
|
|||||||
for noise in HARDCORE_NOISE_TERMS:
|
for noise in HARDCORE_NOISE_TERMS:
|
||||||
if noise in lower_krea:
|
if noise in lower_krea:
|
||||||
issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}")
|
issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}")
|
||||||
|
if isinstance(row, dict):
|
||||||
|
family = str(row.get("action_family") or "").strip()
|
||||||
|
sdxl_lower = f", {sdxl_prompt.lower()}, "
|
||||||
|
for tag in INCOMPATIBLE_SDXL_TAGS.get(family, ()):
|
||||||
|
if f", {tag}, " in sdxl_lower:
|
||||||
|
issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{family}:{tag}")
|
||||||
if is_pov:
|
if is_pov:
|
||||||
if "viewer" not in lower_krea or "first-person" not in lower_krea:
|
if "viewer" not in lower_krea or "first-person" not in lower_krea:
|
||||||
issues.append(f"{name}.krea_prompt: pov_wording_missing")
|
issues.append(f"{name}.krea_prompt: pov_wording_missing")
|
||||||
@@ -324,7 +344,7 @@ def _case_report(
|
|||||||
is_pov: bool = False,
|
is_pov: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
formats = _format_metadata(metadata, target)
|
formats = _format_metadata(metadata, target)
|
||||||
issues = _formatter_issues(name, formats, is_pov=is_pov)
|
issues = _formatter_issues(name, formats, row=metadata, is_pov=is_pov)
|
||||||
issues.extend(_route_metadata_issues(name, metadata))
|
issues.extend(_route_metadata_issues(name, metadata))
|
||||||
if target == "softcore":
|
if target == "softcore":
|
||||||
issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt")))
|
issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt")))
|
||||||
@@ -354,11 +374,11 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) ->
|
|||||||
hard_row = dict(pair.get("hardcore_row") or {})
|
hard_row = dict(pair.get("hardcore_row") or {})
|
||||||
soft_formats = _format_metadata(pair, "softcore")
|
soft_formats = _format_metadata(pair, "softcore")
|
||||||
hard_formats = _format_metadata(pair, "hardcore")
|
hard_formats = _format_metadata(pair, "hardcore")
|
||||||
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats)
|
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row)
|
||||||
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
|
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
|
||||||
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
|
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
|
||||||
hard_is_pov = bool(hard_row.get("pov_character_labels"))
|
hard_is_pov = bool(hard_row.get("pov_character_labels"))
|
||||||
hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, is_pov=hard_is_pov)
|
hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov)
|
||||||
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
|
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
|
||||||
reports = [
|
reports = [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -4604,7 +4604,30 @@ def smoke_sdxl_tag_policy() -> None:
|
|||||||
_expect(deps.axis_value_tags is sdxl_tag_policy.axis_value_tags, "SDXL route deps lost axis-value tag policy")
|
_expect(deps.axis_value_tags is sdxl_tag_policy.axis_value_tags, "SDXL route deps lost axis-value tag policy")
|
||||||
_expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy")
|
_expect(deps.camera_tags is sdxl_tag_policy.camera_tags, "SDXL route deps lost camera tag policy")
|
||||||
_expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy")
|
_expect(deps.explicit_tags is sdxl_tag_policy.explicit_tags, "SDXL route deps lost explicit tag policy")
|
||||||
|
_expect(
|
||||||
|
deps.filter_incompatible_route_tags is sdxl_tag_policy.filter_incompatible_route_tags,
|
||||||
|
"SDXL route deps lost route-family tag filter",
|
||||||
|
)
|
||||||
_expect(deps.softcore_pair_tags is sdxl_tag_policy.softcore_pair_tags, "SDXL route deps lost softcore pair tag policy")
|
_expect(deps.softcore_pair_tags is sdxl_tag_policy.softcore_pair_tags, "SDXL route deps lost softcore pair tag policy")
|
||||||
|
mouth_nearby_tags = sdxl_tag_policy.explicit_tags(
|
||||||
|
"missionary penetration with mouth close to the ear",
|
||||||
|
1.29,
|
||||||
|
)
|
||||||
|
_expect("penetration" in mouth_nearby_tags, "SDXL explicit tags lost penetration signal")
|
||||||
|
_expect("oral sex" not in mouth_nearby_tags, "SDXL explicit tags should not treat nearby mouth wording as oral")
|
||||||
|
outercourse_filtered_tags = sdxl_tag_policy.filter_incompatible_route_tags(
|
||||||
|
["outercourse", "penis licking", "oral sex", "penetration"],
|
||||||
|
_fixture_hardcore_row(
|
||||||
|
action_family="outercourse",
|
||||||
|
position_family="outercourse",
|
||||||
|
position_key="penis_licking",
|
||||||
|
position_keys=["penis_licking"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_expect("outercourse" in outercourse_filtered_tags, "SDXL route filter removed matching outercourse tag")
|
||||||
|
_expect("penis licking" in outercourse_filtered_tags, "SDXL route filter removed specific outercourse key")
|
||||||
|
_expect("oral sex" not in outercourse_filtered_tags, "SDXL route filter kept incompatible oral tag")
|
||||||
|
_expect("penetration" not in outercourse_filtered_tags, "SDXL route filter kept incompatible penetration tag")
|
||||||
|
|
||||||
stale_character_row = {
|
stale_character_row = {
|
||||||
"prompt": "Characters: 99-year-old adult man, stale body, stale skin, stale hair, stale eyes.",
|
"prompt": "Characters: 99-year-old adult man, stale body, stale skin, stale hair, stale eyes.",
|
||||||
@@ -4731,6 +4754,28 @@ def smoke_sdxl_tag_routes() -> None:
|
|||||||
).as_text()
|
).as_text()
|
||||||
for required in ("(naked:1.29)", "pussy", "penis", "penetration"):
|
for required in ("(naked:1.29)", "pussy", "penis", "penetration"):
|
||||||
_expect(required in metadata_tags, f"SDXL row tags lost structured explicit metadata tag: {required}")
|
_expect(required in metadata_tags, f"SDXL row tags lost structured explicit metadata tag: {required}")
|
||||||
|
outercourse_noise_row = _fixture_hardcore_row(
|
||||||
|
item="penis-licking outercourse position with tongue along the penis shaft",
|
||||||
|
pose="configured outercourse pose",
|
||||||
|
role_graph="Woman A bends low while her tongue runs along Man A's penis shaft.",
|
||||||
|
source_role_graph="Woman A bends low while her tongue runs along Man A's penis shaft.",
|
||||||
|
item_axis_values={
|
||||||
|
"position": "penis-licking outercourse position",
|
||||||
|
"outer_act": "tongue along the penis shaft",
|
||||||
|
},
|
||||||
|
action_family="outercourse",
|
||||||
|
position_family="outercourse",
|
||||||
|
position_key="penis_licking",
|
||||||
|
position_keys=["penis_licking"],
|
||||||
|
)
|
||||||
|
outercourse_noise_tags = sdxl_tag_routes.row_core_tags_result(
|
||||||
|
sdxl_tag_routes.SDXLRowTagRequest(outercourse_noise_row, 1.29),
|
||||||
|
deps,
|
||||||
|
).as_text()
|
||||||
|
_expect("outercourse" in outercourse_noise_tags, "SDXL outercourse row lost matching family tag")
|
||||||
|
_expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key")
|
||||||
|
_expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag")
|
||||||
|
_expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag")
|
||||||
|
|
||||||
pair = pb.build_insta_of_pair(
|
pair = pb.build_insta_of_pair(
|
||||||
row_number=1,
|
row_number=1,
|
||||||
|
|||||||
Reference in New Issue
Block a user