Validate formatter route traces in simulation

This commit is contained in:
2026-06-27 18:57:40 +02:00
parent 7778a5f31f
commit 80e7e6e156
+78 -3
View File
@@ -386,11 +386,72 @@ def _formatter_expectation_issues(
return issues return issues
def _trace_dict(formatter_name: str, payload: dict[str, Any]) -> tuple[dict[str, Any], str]:
trace_text = str(payload.get("route_trace_json") or "")
if not trace_text:
return {}, f"{formatter_name}: missing_route_trace"
try:
trace = json.loads(trace_text)
except json.JSONDecodeError as exc:
return {}, f"{formatter_name}: invalid_route_trace:{exc}"
if not isinstance(trace, dict):
return {}, f"{formatter_name}: route_trace_not_object"
return trace, ""
def _formatter_trace_issues(
name: str,
formats: dict[str, Any],
*,
target: str,
) -> list[str]:
expected_formatters = {
"krea": "krea2",
"sdxl": "sdxl",
"caption": "caption",
}
issues: list[str] = []
for formatter_name, expected_formatter in expected_formatters.items():
payload = formats[formatter_name]
trace, error = _trace_dict(f"{name}.{formatter_name}", payload)
if error:
issues.append(error)
continue
method = str(payload.get("method") or "")
branch = str(trace.get("branch") or "")
if trace.get("formatter") != expected_formatter:
issues.append(f"{name}.{formatter_name}: trace_formatter_mismatch:{trace.get('formatter')} != {expected_formatter}")
if trace.get("method") != method:
issues.append(f"{name}.{formatter_name}: trace_method_mismatch:{trace.get('method')} != {method}")
if trace.get("target") != target:
issues.append(f"{name}.{formatter_name}: trace_target_mismatch:{trace.get('target')} != {target}")
if trace.get("input_hint") != "metadata_json":
issues.append(f"{name}.{formatter_name}: trace_input_hint_mismatch:{trace.get('input_hint')}")
if branch in ("", "fallback", "text"):
issues.append(f"{name}.{formatter_name}: trace_branch_not_metadata:{branch}")
if "metadata" not in method:
issues.append(f"{name}.{formatter_name}: trace_method_not_metadata:{method}")
if "insta_of_pair" in method:
if formatter_name in ("krea", "sdxl"):
if branch != "insta_of_pair":
issues.append(f"{name}.{formatter_name}: trace_pair_branch_mismatch:{branch}")
if trace.get("selected_side") != target:
issues.append(f"{name}.{formatter_name}: trace_selected_side_mismatch:{trace.get('selected_side')} != {target}")
elif "metadata(insta_of_pair)" not in method:
issues.append(f"{name}.{formatter_name}: trace_caption_pair_method_mismatch:{method}")
elif formatter_name == "krea" and not branch.startswith("metadata("):
issues.append(f"{name}.{formatter_name}: trace_krea_metadata_branch_mismatch:{branch}")
elif formatter_name in ("sdxl", "caption") and branch != "metadata":
issues.append(f"{name}.{formatter_name}: trace_metadata_branch_mismatch:{branch}")
return issues
def _formatter_issues( def _formatter_issues(
name: str, name: str,
formats: dict[str, Any], formats: dict[str, Any],
*, *,
row: dict[str, Any] | None = None, row: dict[str, Any] | None = None,
target: str,
expected_terms: dict[str, tuple[str, ...]] | None = None, expected_terms: dict[str, tuple[str, ...]] | None = None,
is_pov: bool = False, is_pov: bool = False,
) -> list[str]: ) -> list[str]:
@@ -416,6 +477,7 @@ def _formatter_issues(
): ):
if "metadata" not in str(method or ""): if "metadata" not in str(method or ""):
issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}") issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}")
issues.extend(_formatter_trace_issues(name, formats, target=target))
for label, value in ( for label, value in (
(f"{name}.krea_negative", krea.get("negative_prompt")), (f"{name}.krea_negative", krea.get("negative_prompt")),
@@ -508,7 +570,14 @@ def _case_report(
is_pov: bool = False, is_pov: bool = False,
) -> dict[str, Any]: ) -> dict[str, Any]:
formats = _format_metadata(metadata, target) formats = _format_metadata(metadata, target)
issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov) issues = _formatter_issues(
name,
formats,
row=metadata,
target=target,
expected_terms=expected_terms,
is_pov=is_pov,
)
issues.extend(_route_metadata_issues(name, metadata)) issues.extend(_route_metadata_issues(name, metadata))
issues.extend(_route_expectation_issues(name, metadata, expected_route)) issues.extend(_route_expectation_issues(name, metadata, expected_route))
if target == "softcore": if target == "softcore":
@@ -539,11 +608,17 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) ->
hard_row = dict(pair.get("hardcore_row") or {}) hard_row = dict(pair.get("hardcore_row") or {})
soft_formats = _format_metadata(pair, "softcore") soft_formats = _format_metadata(pair, "softcore")
hard_formats = _format_metadata(pair, "hardcore") hard_formats = _format_metadata(pair, "hardcore")
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row) soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row, target="softcore")
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row)) soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt"))) soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
hard_is_pov = bool(hard_row.get("pov_character_labels")) hard_is_pov = bool(hard_row.get("pov_character_labels"))
hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov) hard_issues = _formatter_issues(
f"{name}.hardcore",
hard_formats,
row=hard_row,
target="hardcore",
is_pov=hard_is_pov,
)
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row)) hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
reports = [ reports = [
{ {