Expand route simulation coverage

This commit is contained in:
2026-06-27 18:49:01 +02:00
parent cac4fe47cd
commit f91953f12b
3 changed files with 215 additions and 15 deletions
+9 -1
View File
@@ -96,6 +96,14 @@ def _row_explicit_signal_text(
return " ".join(deps.clean(value) for value in values if deps.clean(value)) return " ".join(deps.clean(value) for value in values if deps.clean(value))
def _uses_hardcore_action_route(row: dict[str, Any]) -> bool:
return (
str(row.get("category_slug") or "").strip() == "hardcore_sexual_poses"
or bool(str(row.get("action_family") or "").strip())
or bool(str(row.get("position_family") or "").strip())
)
def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute: def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependencies) -> SDXLTagRoute:
row = request.row row = request.row
tags: list[str] = [] tags: list[str] = []
@@ -117,7 +125,7 @@ def row_core_tags_result(request: SDXLRowTagRequest, deps: SDXLTagRouteDependenc
item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean( item = deps.row_value(row, "item", ("Sexual scene", "Sexual pose", "Erotic outfit", "Clothing")) or deps.clean(
row.get("custom_item") row.get("custom_item")
) )
pose = deps.row_value(row, "pose", ("Sexual pose", "Pose")) pose = "" if _uses_hardcore_action_route(row) else deps.row_value(row, "pose", ("Sexual pose", "Pose"))
role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph")) role_graph = deps.clean(row.get("source_role_graph") or row.get("role_graph"))
scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene")) scene = deps.row_value(row, "scene_text", ("Setting", "Scene")) or deps.clean(row.get("scene"))
expression = deps.row_value(row, "character_expression_text") or deps.row_value( expression = deps.row_value(row, "character_expression_text") or deps.row_value(
+174 -12
View File
@@ -24,6 +24,7 @@ import caption_naturalizer # noqa: E402
import krea_formatter # noqa: E402 import krea_formatter # noqa: E402
import prompt_builder as pb # noqa: E402 import prompt_builder as pb # noqa: E402
import sdxl_formatter # noqa: E402 import sdxl_formatter # noqa: E402
import sdxl_tag_policy # noqa: E402
TRIGGER = "sxcppnl7" TRIGGER = "sxcppnl7"
@@ -52,14 +53,6 @@ HARDCORE_NOISE_TERMS = (
"the scene contains", "the scene contains",
) )
# Maps a row's action family to the SDXL tags that must not appear in the
# SDXL prompt generated for that family.
INCOMPATIBLE_SDXL_TAGS = {
    "penetration": (
        "oral sex",
        "outercourse",
        "anal sex",
        "manual stimulation",
    ),
    "oral": (
        "penetrative sex",
        "penetration",
        "anal sex",
        "outercourse",
    ),
    "outercourse": (
        "penetrative sex",
        "penetration",
        "oral sex",
        "anal sex",
    ),
    "manual": (
        "penetrative sex",
        "penetration",
        "oral sex",
        "anal sex",
    ),
    "anal": (
        "oral sex",
        "outercourse",
    ),
}
def _json(value: Any) -> str: def _json(value: Any) -> str:
return json.dumps(value, ensure_ascii=True, sort_keys=True) return json.dumps(value, ensure_ascii=True, sort_keys=True)
@@ -178,6 +171,82 @@ def _insta_options() -> str:
) )
# One entry per hardcore route family exercised by the simulation.
#   name            - case name used in the report
#   subcategory     - subcategory passed to the prompt builder
#   focus / family  - arguments for the hardcore position filter
#   expected_route  - route metadata values the built row must carry
#   expected_terms  - per-formatter substrings ("krea", "sdxl", "caption")
#                     that must all appear in that formatter's prompt
# NOTE(review): the "manual" and "anal" cases pin only `position_family`,
# not `action_family` - confirm this is intentional.
HARDCORE_ROUTE_CASES = (
    {
        "name": "hardcore.single.oral",
        "subcategory": "Oral sex",
        "focus": "oral_only",
        "family": "oral",
        "expected_route": {
            "action_family": "oral",
            "position_family": "oral",
        },
        "expected_terms": {
            "krea": ("mouth",),
            "sdxl": ("oral sex",),
            "caption": ("oral action",),
        },
    },
    {
        "name": "hardcore.single.manual",
        "subcategory": "Manual stimulation",
        "focus": "manual_only",
        "family": "manual",
        "expected_route": {
            "position_family": "manual",
        },
        "expected_terms": {
            "krea": ("hand",),
            "sdxl": ("manual stimulation",),
            "caption": ("manual action",),
        },
    },
    {
        "name": "hardcore.single.outercourse",
        "subcategory": "Outercourse and genital teasing",
        "focus": "outercourse_only",
        "family": "outercourse",
        "expected_route": {
            "action_family": "outercourse",
            "position_family": "outercourse",
        },
        "expected_terms": {
            "krea": ("penis",),
            "sdxl": ("outercourse",),
            "caption": ("non-penetrative action",),
        },
    },
    {
        "name": "hardcore.single.foreplay",
        "subcategory": "Foreplay and teasing",
        "focus": "foreplay_only",
        "family": "foreplay",
        "expected_route": {
            "action_family": "foreplay",
            "position_family": "foreplay",
        },
        "expected_terms": {
            "krea": ("clothing",),
            "sdxl": ("foreplay",),
            "caption": ("foreplay action",),
        },
    },
    {
        "name": "hardcore.single.anal",
        "subcategory": "Anal and double penetration",
        "focus": "anal_only",
        "family": "anal",
        "expected_route": {
            "position_family": "anal",
        },
        "expected_terms": {
            "krea": ("anal",),
            "sdxl": ("anal sex",),
            "caption": ("anal action",),
        },
    },
    {
        "name": "hardcore.single.climax",
        "subcategory": "Cumshot and climax",
        "focus": "climax_only",
        "family": "climax",
        "expected_route": {
            "action_family": "climax",
            "position_family": "climax",
        },
        "expected_terms": {
            "krea": ("ejaculation",),
            "sdxl": ("climax", "semen"),
            "caption": ("climax action",),
        },
    },
)
def _format_metadata(metadata: dict[str, Any], target: str) -> dict[str, Any]: def _format_metadata(metadata: dict[str, Any], target: str) -> dict[str, Any]:
metadata_json = _json(metadata) metadata_json = _json(metadata)
krea = krea_formatter.format_krea2_prompt( krea = krea_formatter.format_krea2_prompt(
@@ -243,11 +312,36 @@ def _text_issues(label: str, value: Any, *, min_len: int = 8) -> list[str]:
return issues return issues
def _contains_all(text: str, required: tuple[str, ...]) -> bool:
lower = text.lower()
return all(term.lower() in lower for term in required)
def _formatter_expectation_issues(
name: str,
formats: dict[str, Any],
expected_terms: dict[str, tuple[str, ...]] | None,
) -> list[str]:
if not expected_terms:
return []
prompts = {
"krea": str(formats["krea"].get("krea_prompt") or ""),
"sdxl": str(formats["sdxl"].get("sdxl_prompt") or ""),
"caption": str(formats["caption"].get("natural_caption") or ""),
}
issues: list[str] = []
for formatter_name, required in expected_terms.items():
if required and not _contains_all(prompts.get(formatter_name, ""), required):
issues.append(f"{name}.{formatter_name}: missing_route_terms:{required}")
return issues
def _formatter_issues( def _formatter_issues(
name: str, name: str,
formats: dict[str, Any], formats: dict[str, Any],
*, *,
row: dict[str, Any] | None = None, row: dict[str, Any] | None = None,
expected_terms: dict[str, tuple[str, ...]] | None = None,
is_pov: bool = False, is_pov: bool = False,
) -> list[str]: ) -> list[str]:
issues: list[str] = [] issues: list[str] = []
@@ -289,11 +383,13 @@ def _formatter_issues(
if noise in lower_krea: if noise in lower_krea:
issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}") issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}")
if isinstance(row, dict): if isinstance(row, dict):
family = str(row.get("action_family") or "").strip()
sdxl_lower = f", {sdxl_prompt.lower()}, " sdxl_lower = f", {sdxl_prompt.lower()}, "
for tag in INCOMPATIBLE_SDXL_TAGS.get(family, ()): for scope, family in (("action", row.get("action_family")), ("position", row.get("position_family"))):
route_key = f"{scope}:{str(family or '').strip()}"
for tag in sdxl_tag_policy.INCOMPATIBLE_ROUTE_TAGS.get(route_key, ()):
if f", {tag}, " in sdxl_lower: if f", {tag}, " in sdxl_lower:
issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{family}:{tag}") issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{route_key}:{tag}")
issues.extend(_formatter_expectation_issues(name, formats, expected_terms))
if is_pov: if is_pov:
if "viewer" not in lower_krea or "first-person" not in lower_krea: if "viewer" not in lower_krea or "first-person" not in lower_krea:
issues.append(f"{name}.krea_prompt: pov_wording_missing") issues.append(f"{name}.krea_prompt: pov_wording_missing")
@@ -335,17 +431,36 @@ def _route_metadata_issues(name: str, row: dict[str, Any]) -> list[str]:
return [] return []
def _route_expectation_issues(name: str, row: dict[str, Any], expected_route: dict[str, Any] | None) -> list[str]:
if not expected_route:
return []
issues: list[str] = []
for key in ("action_family", "position_family", "position_key"):
expected = expected_route.get(key)
if expected and row.get(key) != expected:
issues.append(f"{name}: {key}_mismatch:{row.get(key)} != {expected}")
for key, expected_values in (("position_keys", expected_route.get("position_keys") or ()),):
current = set(str(value) for value in (row.get(key) or []))
for value in expected_values:
if str(value) not in current:
issues.append(f"{name}: missing_{key}:{value}")
return issues
def _case_report( def _case_report(
name: str, name: str,
metadata: dict[str, Any], metadata: dict[str, Any],
*, *,
target: str, target: str,
include_prompts: bool, include_prompts: bool,
expected_route: dict[str, Any] | None = None,
expected_terms: dict[str, tuple[str, ...]] | None = None,
is_pov: bool = False, is_pov: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
formats = _format_metadata(metadata, target) formats = _format_metadata(metadata, target)
issues = _formatter_issues(name, formats, row=metadata, is_pov=is_pov) issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov)
issues.extend(_route_metadata_issues(name, metadata)) issues.extend(_route_metadata_issues(name, metadata))
issues.extend(_route_expectation_issues(name, metadata, expected_route))
if target == "softcore": if target == "softcore":
issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt"))) issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt")))
report = { report = {
@@ -435,6 +550,36 @@ def _regular_single_case(seed: int) -> dict[str, Any]:
) )
def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str) -> dict[str, Any]:
    """Build one "Hardcore sexual poses" prompt row for a route case.

    Args:
        seed: Used as the prompt seed and as the base seed of the
            seed-lock config.
        subcategory: Subcategory forwarded to ``pb.build_prompt``.
        focus: First argument of ``_position_filter``.
        family: Second argument of ``_position_filter``; the third is an
            empty list.

    Returns:
        The result of ``pb.build_prompt`` for a one-woman, one-man row.
    """
    return pb.build_prompt(
        category="Hardcore sexual poses",
        subcategory=subcategory,
        row_number=1,
        start_index=1,
        seed=seed,
        clothing="random",
        ethnicity="any",
        poses="random",
        backside_bias=0.0,
        figure="random",
        no_plus_women=False,
        no_black=False,
        minimal_clothing_ratio=-1,
        standard_pose_ratio=-1,
        trigger=TRIGGER,
        prepend_trigger_to_prompt=True,
        extra_positive="",
        extra_negative="",
        seed_config=pb.build_seed_lock_config_json(base_seed=seed),
        women_count=1,
        men_count=1,
        character_cast=_character_cast(),
        hardcore_position_config=_position_filter(focus, family, []),
        location_config=_coworking_location_config(),
        camera_config=_orbit_camera(horizontal_angle=35, vertical_angle=0, zoom=6.5),
    )
def _insta_pair_case(seed: int, *, pov: bool, position: str, focus: str, family: str) -> dict[str, Any]: def _insta_pair_case(seed: int, *, pov: bool, position: str, focus: str, family: str) -> dict[str, Any]:
return pb.build_insta_of_pair( return pb.build_insta_of_pair(
row_number=1, row_number=1,
@@ -539,6 +684,23 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
cases: list[dict[str, Any]] = [] cases: list[dict[str, Any]] = []
regular = _regular_single_case(seed) regular = _regular_single_case(seed)
cases.append(_case_report("regular.single.casual", regular, target="single", include_prompts=include_prompts)) cases.append(_case_report("regular.single.casual", regular, target="single", include_prompts=include_prompts))
for offset, route_case in enumerate(HARDCORE_ROUTE_CASES, start=10):
row = _hardcore_single_case(
seed + offset,
str(route_case["subcategory"]),
str(route_case["focus"]),
str(route_case["family"]),
)
cases.append(
_case_report(
str(route_case["name"]),
row,
target="single",
include_prompts=include_prompts,
expected_route=route_case.get("expected_route"),
expected_terms=route_case.get("expected_terms"),
)
)
penetration_pair = _insta_pair_case(seed + 1, pov=False, position="doggy", focus="penetration_only", family="penetration") penetration_pair = _insta_pair_case(seed + 1, pov=False, position="doggy", focus="penetration_only", family="penetration")
cases.extend(_pair_reports("insta_pair.penetration", penetration_pair, include_prompts=include_prompts)) cases.extend(_pair_reports("insta_pair.penetration", penetration_pair, include_prompts=include_prompts))
pov_pair = _insta_pair_case(seed + 2, pov=True, position="penis_licking", focus="outercourse_only", family="outercourse") pov_pair = _insta_pair_case(seed + 2, pov=True, position="penis_licking", focus="outercourse_only", family="outercourse")
+31 -1
View File
@@ -4776,6 +4776,27 @@ def smoke_sdxl_tag_routes() -> None:
_expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key") _expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key")
_expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag") _expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag")
_expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag") _expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag")
stale_hardcore_pose_row = _fixture_hardcore_row(
item="oral contact with mouth on the visible genitals in side-lying oral position",
pose="kneeling and balancing a cucumber upright on an open palm held overhead",
role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.",
source_role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.",
item_axis_values={
"position": "side-lying oral position",
"oral_act": "oral contact with mouth on the visible genitals",
},
action_family="oral",
position_family="oral",
position_key="side_lying",
position_keys=["side_lying"],
)
stale_hardcore_pose_tags = sdxl_tag_routes.row_core_tags_result(
sdxl_tag_routes.SDXLRowTagRequest(stale_hardcore_pose_row, 1.29),
deps,
).as_text()
_expect("oral sex" in stale_hardcore_pose_tags, "SDXL hardcore route lost oral family tag")
_expect("side lying" in stale_hardcore_pose_tags, "SDXL hardcore route lost structured position key")
_expect("cucumber" not in stale_hardcore_pose_tags, "SDXL hardcore route leaked generic stale pose text")
pair = pb.build_insta_of_pair( pair = pb.build_insta_of_pair(
row_number=1, row_number=1,
@@ -7812,10 +7833,19 @@ def smoke_seed_config_policy() -> None:
def smoke_prompt_route_simulation_policy() -> None: def smoke_prompt_route_simulation_policy() -> None:
report = prompt_route_simulation.run_simulation(seed=3901, include_prompts=False) report = prompt_route_simulation.run_simulation(seed=3901, include_prompts=False)
summary = report.get("summary") or {} summary = report.get("summary") or {}
_expect(summary.get("cases") == 5, "Prompt route simulation case count changed unexpectedly") _expect(summary.get("cases") == 11, "Prompt route simulation case count changed unexpectedly")
_expect(summary.get("axis_checks") == 1, "Prompt route simulation lost axis check coverage") _expect(summary.get("axis_checks") == 1, "Prompt route simulation lost axis check coverage")
_expect(summary.get("issues") == 0, f"Prompt route simulation reported issues: {report.get('issues')}") _expect(summary.get("issues") == 0, f"Prompt route simulation reported issues: {report.get('issues')}")
cases = {case.get("name"): case for case in report.get("cases") or []} cases = {case.get("name"): case for case in report.get("cases") or []}
for route_name in (
"hardcore.single.oral",
"hardcore.single.manual",
"hardcore.single.outercourse",
"hardcore.single.foreplay",
"hardcore.single.anal",
"hardcore.single.climax",
):
_expect(route_name in cases, f"Prompt route simulation lost route family case {route_name}")
pov_hard = cases.get("insta_pair.pov_outercourse.hardcore") or {} pov_hard = cases.get("insta_pair.pov_outercourse.hardcore") or {}
pov_summary = pov_hard.get("summary") or {} pov_summary = pov_hard.get("summary") or {}
_expect( _expect(