From 80e7e6e1565b30a9f4212e29a4cd7b6ede899e43 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 27 Jun 2026 18:57:40 +0200 Subject: [PATCH] Validate formatter route traces in simulation --- tools/prompt_route_simulation.py | 81 ++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/tools/prompt_route_simulation.py b/tools/prompt_route_simulation.py index bc1f45f..1f53185 100644 --- a/tools/prompt_route_simulation.py +++ b/tools/prompt_route_simulation.py @@ -386,11 +386,72 @@ def _formatter_expectation_issues( return issues +def _trace_dict(formatter_name: str, payload: dict[str, Any]) -> tuple[dict[str, Any], str]: + trace_text = str(payload.get("route_trace_json") or "") + if not trace_text: + return {}, f"{formatter_name}: missing_route_trace" + try: + trace = json.loads(trace_text) + except json.JSONDecodeError as exc: + return {}, f"{formatter_name}: invalid_route_trace:{exc}" + if not isinstance(trace, dict): + return {}, f"{formatter_name}: route_trace_not_object" + return trace, "" + + +def _formatter_trace_issues( + name: str, + formats: dict[str, Any], + *, + target: str, +) -> list[str]: + expected_formatters = { + "krea": "krea2", + "sdxl": "sdxl", + "caption": "caption", + } + issues: list[str] = [] + for formatter_name, expected_formatter in expected_formatters.items(): + payload = formats[formatter_name] + trace, error = _trace_dict(f"{name}.{formatter_name}", payload) + if error: + issues.append(error) + continue + method = str(payload.get("method") or "") + branch = str(trace.get("branch") or "") + if trace.get("formatter") != expected_formatter: + issues.append(f"{name}.{formatter_name}: trace_formatter_mismatch:{trace.get('formatter')} != {expected_formatter}") + if trace.get("method") != method: + issues.append(f"{name}.{formatter_name}: trace_method_mismatch:{trace.get('method')} != {method}") + if trace.get("target") != target: + issues.append(f"{name}.{formatter_name}: trace_target_mismatch:{trace.get('target')} != {target}") + if trace.get("input_hint") != "metadata_json": + issues.append(f"{name}.{formatter_name}: trace_input_hint_mismatch:{trace.get('input_hint')}") + if branch in ("", "fallback", "text"): + issues.append(f"{name}.{formatter_name}: trace_branch_not_metadata:{branch}") + if "metadata" not in method: + issues.append(f"{name}.{formatter_name}: trace_method_not_metadata:{method}") + if "insta_of_pair" in method: + if formatter_name in ("krea", "sdxl"): + if branch != "insta_of_pair": + issues.append(f"{name}.{formatter_name}: trace_pair_branch_mismatch:{branch}") + if trace.get("selected_side") != target: + issues.append(f"{name}.{formatter_name}: trace_selected_side_mismatch:{trace.get('selected_side')} != {target}") + elif "metadata(insta_of_pair)" not in method: + issues.append(f"{name}.{formatter_name}: trace_caption_pair_method_mismatch:{method}") + elif formatter_name == "krea" and not branch.startswith("metadata("): + issues.append(f"{name}.{formatter_name}: trace_krea_metadata_branch_mismatch:{branch}") + elif formatter_name in ("sdxl", "caption") and branch != "metadata": + issues.append(f"{name}.{formatter_name}: trace_metadata_branch_mismatch:{branch}") + return issues + + def _formatter_issues( name: str, formats: dict[str, Any], *, row: dict[str, Any] | None = None, + target: str, expected_terms: dict[str, tuple[str, ...]] | None = None, is_pov: bool = False, ) -> list[str]: @@ -416,6 +477,7 @@ def _formatter_issues( ): if "metadata" not in str(method or ""): issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}") + issues.extend(_formatter_trace_issues(name, formats, target=target)) for label, value in ( (f"{name}.krea_negative", krea.get("negative_prompt")), @@ -508,7 +570,14 @@ def _case_report( is_pov: bool = False, ) -> dict[str, Any]: formats = _format_metadata(metadata, target) - issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov) + issues = _formatter_issues( + name, + formats, + row=metadata, + target=target, + expected_terms=expected_terms, + is_pov=is_pov, + ) issues.extend(_route_metadata_issues(name, metadata)) issues.extend(_route_expectation_issues(name, metadata, expected_route)) if target == "softcore": @@ -539,11 +608,17 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) -> hard_row = dict(pair.get("hardcore_row") or {}) soft_formats = _format_metadata(pair, "softcore") hard_formats = _format_metadata(pair, "hardcore") - soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row) + soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row, target="softcore") soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row)) soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt"))) hard_is_pov = bool(hard_row.get("pov_character_labels")) - hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov) + hard_issues = _formatter_issues( + f"{name}.hardcore", + hard_formats, + row=hard_row, + target="hardcore", + is_pov=hard_is_pov, + ) hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row)) reports = [ {