Validate formatter route traces in simulation
This commit is contained in:
@@ -386,11 +386,72 @@ def _formatter_expectation_issues(
|
|||||||
return issues
|
return issues
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_dict(formatter_name: str, payload: dict[str, Any]) -> tuple[dict[str, Any], str]:
|
||||||
|
trace_text = str(payload.get("route_trace_json") or "")
|
||||||
|
if not trace_text:
|
||||||
|
return {}, f"{formatter_name}: missing_route_trace"
|
||||||
|
try:
|
||||||
|
trace = json.loads(trace_text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
return {}, f"{formatter_name}: invalid_route_trace:{exc}"
|
||||||
|
if not isinstance(trace, dict):
|
||||||
|
return {}, f"{formatter_name}: route_trace_not_object"
|
||||||
|
return trace, ""
|
||||||
|
|
||||||
|
|
||||||
|
def _formatter_trace_issues(
|
||||||
|
name: str,
|
||||||
|
formats: dict[str, Any],
|
||||||
|
*,
|
||||||
|
target: str,
|
||||||
|
) -> list[str]:
|
||||||
|
expected_formatters = {
|
||||||
|
"krea": "krea2",
|
||||||
|
"sdxl": "sdxl",
|
||||||
|
"caption": "caption",
|
||||||
|
}
|
||||||
|
issues: list[str] = []
|
||||||
|
for formatter_name, expected_formatter in expected_formatters.items():
|
||||||
|
payload = formats[formatter_name]
|
||||||
|
trace, error = _trace_dict(f"{name}.{formatter_name}", payload)
|
||||||
|
if error:
|
||||||
|
issues.append(error)
|
||||||
|
continue
|
||||||
|
method = str(payload.get("method") or "")
|
||||||
|
branch = str(trace.get("branch") or "")
|
||||||
|
if trace.get("formatter") != expected_formatter:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_formatter_mismatch:{trace.get('formatter')} != {expected_formatter}")
|
||||||
|
if trace.get("method") != method:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_method_mismatch:{trace.get('method')} != {method}")
|
||||||
|
if trace.get("target") != target:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_target_mismatch:{trace.get('target')} != {target}")
|
||||||
|
if trace.get("input_hint") != "metadata_json":
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_input_hint_mismatch:{trace.get('input_hint')}")
|
||||||
|
if branch in ("", "fallback", "text"):
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_branch_not_metadata:{branch}")
|
||||||
|
if "metadata" not in method:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_method_not_metadata:{method}")
|
||||||
|
if "insta_of_pair" in method:
|
||||||
|
if formatter_name in ("krea", "sdxl"):
|
||||||
|
if branch != "insta_of_pair":
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_pair_branch_mismatch:{branch}")
|
||||||
|
if trace.get("selected_side") != target:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_selected_side_mismatch:{trace.get('selected_side')} != {target}")
|
||||||
|
elif "metadata(insta_of_pair)" not in method:
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_caption_pair_method_mismatch:{method}")
|
||||||
|
elif formatter_name == "krea" and not branch.startswith("metadata("):
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_krea_metadata_branch_mismatch:{branch}")
|
||||||
|
elif formatter_name in ("sdxl", "caption") and branch != "metadata":
|
||||||
|
issues.append(f"{name}.{formatter_name}: trace_metadata_branch_mismatch:{branch}")
|
||||||
|
return issues
|
||||||
|
|
||||||
|
|
||||||
def _formatter_issues(
|
def _formatter_issues(
|
||||||
name: str,
|
name: str,
|
||||||
formats: dict[str, Any],
|
formats: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
row: dict[str, Any] | None = None,
|
row: dict[str, Any] | None = None,
|
||||||
|
target: str,
|
||||||
expected_terms: dict[str, tuple[str, ...]] | None = None,
|
expected_terms: dict[str, tuple[str, ...]] | None = None,
|
||||||
is_pov: bool = False,
|
is_pov: bool = False,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
@@ -416,6 +477,7 @@ def _formatter_issues(
|
|||||||
):
|
):
|
||||||
if "metadata" not in str(method or ""):
|
if "metadata" not in str(method or ""):
|
||||||
issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}")
|
issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}")
|
||||||
|
issues.extend(_formatter_trace_issues(name, formats, target=target))
|
||||||
|
|
||||||
for label, value in (
|
for label, value in (
|
||||||
(f"{name}.krea_negative", krea.get("negative_prompt")),
|
(f"{name}.krea_negative", krea.get("negative_prompt")),
|
||||||
@@ -508,7 +570,14 @@ def _case_report(
|
|||||||
is_pov: bool = False,
|
is_pov: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
formats = _format_metadata(metadata, target)
|
formats = _format_metadata(metadata, target)
|
||||||
issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov)
|
issues = _formatter_issues(
|
||||||
|
name,
|
||||||
|
formats,
|
||||||
|
row=metadata,
|
||||||
|
target=target,
|
||||||
|
expected_terms=expected_terms,
|
||||||
|
is_pov=is_pov,
|
||||||
|
)
|
||||||
issues.extend(_route_metadata_issues(name, metadata))
|
issues.extend(_route_metadata_issues(name, metadata))
|
||||||
issues.extend(_route_expectation_issues(name, metadata, expected_route))
|
issues.extend(_route_expectation_issues(name, metadata, expected_route))
|
||||||
if target == "softcore":
|
if target == "softcore":
|
||||||
@@ -539,11 +608,17 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) ->
|
|||||||
hard_row = dict(pair.get("hardcore_row") or {})
|
hard_row = dict(pair.get("hardcore_row") or {})
|
||||||
soft_formats = _format_metadata(pair, "softcore")
|
soft_formats = _format_metadata(pair, "softcore")
|
||||||
hard_formats = _format_metadata(pair, "hardcore")
|
hard_formats = _format_metadata(pair, "hardcore")
|
||||||
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row)
|
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row, target="softcore")
|
||||||
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
|
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
|
||||||
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
|
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
|
||||||
hard_is_pov = bool(hard_row.get("pov_character_labels"))
|
hard_is_pov = bool(hard_row.get("pov_character_labels"))
|
||||||
hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov)
|
hard_issues = _formatter_issues(
|
||||||
|
f"{name}.hardcore",
|
||||||
|
hard_formats,
|
||||||
|
row=hard_row,
|
||||||
|
target="hardcore",
|
||||||
|
is_pov=hard_is_pov,
|
||||||
|
)
|
||||||
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
|
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
|
||||||
reports = [
|
reports = [
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user