Expand route simulation coverage
This commit is contained in:
@@ -24,6 +24,7 @@ import caption_naturalizer # noqa: E402
|
||||
import krea_formatter # noqa: E402
|
||||
import prompt_builder as pb # noqa: E402
|
||||
import sdxl_formatter # noqa: E402
|
||||
import sdxl_tag_policy # noqa: E402
|
||||
|
||||
|
||||
TRIGGER = "sxcppnl7"
|
||||
@@ -52,14 +53,6 @@ HARDCORE_NOISE_TERMS = (
|
||||
"the scene contains",
|
||||
)
|
||||
|
||||
INCOMPATIBLE_SDXL_TAGS = {
|
||||
"penetration": ("oral sex", "outercourse", "anal sex", "manual stimulation"),
|
||||
"oral": ("penetrative sex", "penetration", "anal sex", "outercourse"),
|
||||
"outercourse": ("penetrative sex", "penetration", "oral sex", "anal sex"),
|
||||
"manual": ("penetrative sex", "penetration", "oral sex", "anal sex"),
|
||||
"anal": ("oral sex", "outercourse"),
|
||||
}
|
||||
|
||||
|
||||
def _json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=True, sort_keys=True)
|
||||
@@ -178,6 +171,82 @@ def _insta_options() -> str:
|
||||
)
|
||||
|
||||
|
||||
HARDCORE_ROUTE_CASES = (
|
||||
{
|
||||
"name": "hardcore.single.oral",
|
||||
"subcategory": "Oral sex",
|
||||
"focus": "oral_only",
|
||||
"family": "oral",
|
||||
"expected_route": {"action_family": "oral", "position_family": "oral"},
|
||||
"expected_terms": {
|
||||
"krea": ("mouth",),
|
||||
"sdxl": ("oral sex",),
|
||||
"caption": ("oral action",),
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.manual",
|
||||
"subcategory": "Manual stimulation",
|
||||
"focus": "manual_only",
|
||||
"family": "manual",
|
||||
"expected_route": {"position_family": "manual"},
|
||||
"expected_terms": {
|
||||
"krea": ("hand",),
|
||||
"sdxl": ("manual stimulation",),
|
||||
"caption": ("manual action",),
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.outercourse",
|
||||
"subcategory": "Outercourse and genital teasing",
|
||||
"focus": "outercourse_only",
|
||||
"family": "outercourse",
|
||||
"expected_route": {"action_family": "outercourse", "position_family": "outercourse"},
|
||||
"expected_terms": {
|
||||
"krea": ("penis",),
|
||||
"sdxl": ("outercourse",),
|
||||
"caption": ("non-penetrative action",),
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.foreplay",
|
||||
"subcategory": "Foreplay and teasing",
|
||||
"focus": "foreplay_only",
|
||||
"family": "foreplay",
|
||||
"expected_route": {"action_family": "foreplay", "position_family": "foreplay"},
|
||||
"expected_terms": {
|
||||
"krea": ("clothing",),
|
||||
"sdxl": ("foreplay",),
|
||||
"caption": ("foreplay action",),
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.anal",
|
||||
"subcategory": "Anal and double penetration",
|
||||
"focus": "anal_only",
|
||||
"family": "anal",
|
||||
"expected_route": {"position_family": "anal"},
|
||||
"expected_terms": {
|
||||
"krea": ("anal",),
|
||||
"sdxl": ("anal sex",),
|
||||
"caption": ("anal action",),
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "hardcore.single.climax",
|
||||
"subcategory": "Cumshot and climax",
|
||||
"focus": "climax_only",
|
||||
"family": "climax",
|
||||
"expected_route": {"action_family": "climax", "position_family": "climax"},
|
||||
"expected_terms": {
|
||||
"krea": ("ejaculation",),
|
||||
"sdxl": ("climax", "semen"),
|
||||
"caption": ("climax action",),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _format_metadata(metadata: dict[str, Any], target: str) -> dict[str, Any]:
|
||||
metadata_json = _json(metadata)
|
||||
krea = krea_formatter.format_krea2_prompt(
|
||||
@@ -243,11 +312,36 @@ def _text_issues(label: str, value: Any, *, min_len: int = 8) -> list[str]:
|
||||
return issues
|
||||
|
||||
|
||||
def _contains_all(text: str, required: tuple[str, ...]) -> bool:
|
||||
lower = text.lower()
|
||||
return all(term.lower() in lower for term in required)
|
||||
|
||||
|
||||
def _formatter_expectation_issues(
|
||||
name: str,
|
||||
formats: dict[str, Any],
|
||||
expected_terms: dict[str, tuple[str, ...]] | None,
|
||||
) -> list[str]:
|
||||
if not expected_terms:
|
||||
return []
|
||||
prompts = {
|
||||
"krea": str(formats["krea"].get("krea_prompt") or ""),
|
||||
"sdxl": str(formats["sdxl"].get("sdxl_prompt") or ""),
|
||||
"caption": str(formats["caption"].get("natural_caption") or ""),
|
||||
}
|
||||
issues: list[str] = []
|
||||
for formatter_name, required in expected_terms.items():
|
||||
if required and not _contains_all(prompts.get(formatter_name, ""), required):
|
||||
issues.append(f"{name}.{formatter_name}: missing_route_terms:{required}")
|
||||
return issues
|
||||
|
||||
|
||||
def _formatter_issues(
|
||||
name: str,
|
||||
formats: dict[str, Any],
|
||||
*,
|
||||
row: dict[str, Any] | None = None,
|
||||
expected_terms: dict[str, tuple[str, ...]] | None = None,
|
||||
is_pov: bool = False,
|
||||
) -> list[str]:
|
||||
issues: list[str] = []
|
||||
@@ -289,11 +383,13 @@ def _formatter_issues(
|
||||
if noise in lower_krea:
|
||||
issues.append(f"{name}.krea_prompt: hardcore_noise:{noise}")
|
||||
if isinstance(row, dict):
|
||||
family = str(row.get("action_family") or "").strip()
|
||||
sdxl_lower = f", {sdxl_prompt.lower()}, "
|
||||
for tag in INCOMPATIBLE_SDXL_TAGS.get(family, ()):
|
||||
if f", {tag}, " in sdxl_lower:
|
||||
issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{family}:{tag}")
|
||||
for scope, family in (("action", row.get("action_family")), ("position", row.get("position_family"))):
|
||||
route_key = f"{scope}:{str(family or '').strip()}"
|
||||
for tag in sdxl_tag_policy.INCOMPATIBLE_ROUTE_TAGS.get(route_key, ()):
|
||||
if f", {tag}, " in sdxl_lower:
|
||||
issues.append(f"{name}.sdxl_prompt: incompatible_family_tag:{route_key}:{tag}")
|
||||
issues.extend(_formatter_expectation_issues(name, formats, expected_terms))
|
||||
if is_pov:
|
||||
if "viewer" not in lower_krea or "first-person" not in lower_krea:
|
||||
issues.append(f"{name}.krea_prompt: pov_wording_missing")
|
||||
@@ -335,17 +431,36 @@ def _route_metadata_issues(name: str, row: dict[str, Any]) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _route_expectation_issues(name: str, row: dict[str, Any], expected_route: dict[str, Any] | None) -> list[str]:
|
||||
if not expected_route:
|
||||
return []
|
||||
issues: list[str] = []
|
||||
for key in ("action_family", "position_family", "position_key"):
|
||||
expected = expected_route.get(key)
|
||||
if expected and row.get(key) != expected:
|
||||
issues.append(f"{name}: {key}_mismatch:{row.get(key)} != {expected}")
|
||||
for key, expected_values in (("position_keys", expected_route.get("position_keys") or ()),):
|
||||
current = set(str(value) for value in (row.get(key) or []))
|
||||
for value in expected_values:
|
||||
if str(value) not in current:
|
||||
issues.append(f"{name}: missing_{key}:{value}")
|
||||
return issues
|
||||
|
||||
|
||||
def _case_report(
|
||||
name: str,
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
target: str,
|
||||
include_prompts: bool,
|
||||
expected_route: dict[str, Any] | None = None,
|
||||
expected_terms: dict[str, tuple[str, ...]] | None = None,
|
||||
is_pov: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
formats = _format_metadata(metadata, target)
|
||||
issues = _formatter_issues(name, formats, row=metadata, is_pov=is_pov)
|
||||
issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov)
|
||||
issues.extend(_route_metadata_issues(name, metadata))
|
||||
issues.extend(_route_expectation_issues(name, metadata, expected_route))
|
||||
if target == "softcore":
|
||||
issues.extend(_softcore_issues(f"{name}.krea_prompt", formats["krea"].get("krea_prompt")))
|
||||
report = {
|
||||
@@ -435,6 +550,36 @@ def _regular_single_case(seed: int) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def _hardcore_single_case(seed: int, subcategory: str, focus: str, family: str) -> dict[str, Any]:
|
||||
return pb.build_prompt(
|
||||
category="Hardcore sexual poses",
|
||||
subcategory=subcategory,
|
||||
row_number=1,
|
||||
start_index=1,
|
||||
seed=seed,
|
||||
clothing="random",
|
||||
ethnicity="any",
|
||||
poses="random",
|
||||
backside_bias=0.0,
|
||||
figure="random",
|
||||
no_plus_women=False,
|
||||
no_black=False,
|
||||
minimal_clothing_ratio=-1,
|
||||
standard_pose_ratio=-1,
|
||||
trigger=TRIGGER,
|
||||
prepend_trigger_to_prompt=True,
|
||||
extra_positive="",
|
||||
extra_negative="",
|
||||
seed_config=pb.build_seed_lock_config_json(base_seed=seed),
|
||||
women_count=1,
|
||||
men_count=1,
|
||||
character_cast=_character_cast(),
|
||||
hardcore_position_config=_position_filter(focus, family, []),
|
||||
location_config=_coworking_location_config(),
|
||||
camera_config=_orbit_camera(horizontal_angle=35, vertical_angle=0, zoom=6.5),
|
||||
)
|
||||
|
||||
|
||||
def _insta_pair_case(seed: int, *, pov: bool, position: str, focus: str, family: str) -> dict[str, Any]:
|
||||
return pb.build_insta_of_pair(
|
||||
row_number=1,
|
||||
@@ -539,6 +684,23 @@ def run_simulation(seed: int = 3901, *, include_prompts: bool = False) -> dict[s
|
||||
cases: list[dict[str, Any]] = []
|
||||
regular = _regular_single_case(seed)
|
||||
cases.append(_case_report("regular.single.casual", regular, target="single", include_prompts=include_prompts))
|
||||
for offset, route_case in enumerate(HARDCORE_ROUTE_CASES, start=10):
|
||||
row = _hardcore_single_case(
|
||||
seed + offset,
|
||||
str(route_case["subcategory"]),
|
||||
str(route_case["focus"]),
|
||||
str(route_case["family"]),
|
||||
)
|
||||
cases.append(
|
||||
_case_report(
|
||||
str(route_case["name"]),
|
||||
row,
|
||||
target="single",
|
||||
include_prompts=include_prompts,
|
||||
expected_route=route_case.get("expected_route"),
|
||||
expected_terms=route_case.get("expected_terms"),
|
||||
)
|
||||
)
|
||||
penetration_pair = _insta_pair_case(seed + 1, pov=False, position="doggy", focus="penetration_only", family="penetration")
|
||||
cases.extend(_pair_reports("insta_pair.penetration", penetration_pair, include_prompts=include_prompts))
|
||||
pov_pair = _insta_pair_case(seed + 2, pov=True, position="penis_licking", focus="outercourse_only", family="outercourse")
|
||||
|
||||
+31
-1
@@ -4776,6 +4776,27 @@ def smoke_sdxl_tag_routes() -> None:
|
||||
_expect("penis licking" in outercourse_noise_tags, "SDXL outercourse row lost specific position key")
|
||||
_expect("oral sex" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible oral tag")
|
||||
_expect("penetration" not in outercourse_noise_tags, "SDXL outercourse row kept incompatible penetration tag")
|
||||
stale_hardcore_pose_row = _fixture_hardcore_row(
|
||||
item="oral contact with mouth on the visible genitals in side-lying oral position",
|
||||
pose="kneeling and balancing a cucumber upright on an open palm held overhead",
|
||||
role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.",
|
||||
source_role_graph="Woman A lies on her side while Man A's mouth is pressed to her pussy.",
|
||||
item_axis_values={
|
||||
"position": "side-lying oral position",
|
||||
"oral_act": "oral contact with mouth on the visible genitals",
|
||||
},
|
||||
action_family="oral",
|
||||
position_family="oral",
|
||||
position_key="side_lying",
|
||||
position_keys=["side_lying"],
|
||||
)
|
||||
stale_hardcore_pose_tags = sdxl_tag_routes.row_core_tags_result(
|
||||
sdxl_tag_routes.SDXLRowTagRequest(stale_hardcore_pose_row, 1.29),
|
||||
deps,
|
||||
).as_text()
|
||||
_expect("oral sex" in stale_hardcore_pose_tags, "SDXL hardcore route lost oral family tag")
|
||||
_expect("side lying" in stale_hardcore_pose_tags, "SDXL hardcore route lost structured position key")
|
||||
_expect("cucumber" not in stale_hardcore_pose_tags, "SDXL hardcore route leaked generic stale pose text")
|
||||
|
||||
pair = pb.build_insta_of_pair(
|
||||
row_number=1,
|
||||
@@ -7812,10 +7833,19 @@ def smoke_seed_config_policy() -> None:
|
||||
def smoke_prompt_route_simulation_policy() -> None:
|
||||
report = prompt_route_simulation.run_simulation(seed=3901, include_prompts=False)
|
||||
summary = report.get("summary") or {}
|
||||
_expect(summary.get("cases") == 5, "Prompt route simulation case count changed unexpectedly")
|
||||
_expect(summary.get("cases") == 11, "Prompt route simulation case count changed unexpectedly")
|
||||
_expect(summary.get("axis_checks") == 1, "Prompt route simulation lost axis check coverage")
|
||||
_expect(summary.get("issues") == 0, f"Prompt route simulation reported issues: {report.get('issues')}")
|
||||
cases = {case.get("name"): case for case in report.get("cases") or []}
|
||||
for route_name in (
|
||||
"hardcore.single.oral",
|
||||
"hardcore.single.manual",
|
||||
"hardcore.single.outercourse",
|
||||
"hardcore.single.foreplay",
|
||||
"hardcore.single.anal",
|
||||
"hardcore.single.climax",
|
||||
):
|
||||
_expect(route_name in cases, f"Prompt route simulation lost route family case {route_name}")
|
||||
pov_hard = cases.get("insta_pair.pov_outercourse.hardcore") or {}
|
||||
pov_summary = pov_hard.get("summary") or {}
|
||||
_expect(
|
||||
|
||||
Reference in New Issue
Block a user