Validate formatter route traces in simulation
This commit is contained in:
@@ -386,11 +386,72 @@ def _formatter_expectation_issues(
|
||||
return issues
|
||||
|
||||
|
||||
def _trace_dict(formatter_name: str, payload: dict[str, Any]) -> tuple[dict[str, Any], str]:
|
||||
trace_text = str(payload.get("route_trace_json") or "")
|
||||
if not trace_text:
|
||||
return {}, f"{formatter_name}: missing_route_trace"
|
||||
try:
|
||||
trace = json.loads(trace_text)
|
||||
except json.JSONDecodeError as exc:
|
||||
return {}, f"{formatter_name}: invalid_route_trace:{exc}"
|
||||
if not isinstance(trace, dict):
|
||||
return {}, f"{formatter_name}: route_trace_not_object"
|
||||
return trace, ""
|
||||
|
||||
|
||||
def _formatter_trace_issues(
|
||||
name: str,
|
||||
formats: dict[str, Any],
|
||||
*,
|
||||
target: str,
|
||||
) -> list[str]:
|
||||
expected_formatters = {
|
||||
"krea": "krea2",
|
||||
"sdxl": "sdxl",
|
||||
"caption": "caption",
|
||||
}
|
||||
issues: list[str] = []
|
||||
for formatter_name, expected_formatter in expected_formatters.items():
|
||||
payload = formats[formatter_name]
|
||||
trace, error = _trace_dict(f"{name}.{formatter_name}", payload)
|
||||
if error:
|
||||
issues.append(error)
|
||||
continue
|
||||
method = str(payload.get("method") or "")
|
||||
branch = str(trace.get("branch") or "")
|
||||
if trace.get("formatter") != expected_formatter:
|
||||
issues.append(f"{name}.{formatter_name}: trace_formatter_mismatch:{trace.get('formatter')} != {expected_formatter}")
|
||||
if trace.get("method") != method:
|
||||
issues.append(f"{name}.{formatter_name}: trace_method_mismatch:{trace.get('method')} != {method}")
|
||||
if trace.get("target") != target:
|
||||
issues.append(f"{name}.{formatter_name}: trace_target_mismatch:{trace.get('target')} != {target}")
|
||||
if trace.get("input_hint") != "metadata_json":
|
||||
issues.append(f"{name}.{formatter_name}: trace_input_hint_mismatch:{trace.get('input_hint')}")
|
||||
if branch in ("", "fallback", "text"):
|
||||
issues.append(f"{name}.{formatter_name}: trace_branch_not_metadata:{branch}")
|
||||
if "metadata" not in method:
|
||||
issues.append(f"{name}.{formatter_name}: trace_method_not_metadata:{method}")
|
||||
if "insta_of_pair" in method:
|
||||
if formatter_name in ("krea", "sdxl"):
|
||||
if branch != "insta_of_pair":
|
||||
issues.append(f"{name}.{formatter_name}: trace_pair_branch_mismatch:{branch}")
|
||||
if trace.get("selected_side") != target:
|
||||
issues.append(f"{name}.{formatter_name}: trace_selected_side_mismatch:{trace.get('selected_side')} != {target}")
|
||||
elif "metadata(insta_of_pair)" not in method:
|
||||
issues.append(f"{name}.{formatter_name}: trace_caption_pair_method_mismatch:{method}")
|
||||
elif formatter_name == "krea" and not branch.startswith("metadata("):
|
||||
issues.append(f"{name}.{formatter_name}: trace_krea_metadata_branch_mismatch:{branch}")
|
||||
elif formatter_name in ("sdxl", "caption") and branch != "metadata":
|
||||
issues.append(f"{name}.{formatter_name}: trace_metadata_branch_mismatch:{branch}")
|
||||
return issues
|
||||
|
||||
|
||||
def _formatter_issues(
|
||||
name: str,
|
||||
formats: dict[str, Any],
|
||||
*,
|
||||
row: dict[str, Any] | None = None,
|
||||
target: str,
|
||||
expected_terms: dict[str, tuple[str, ...]] | None = None,
|
||||
is_pov: bool = False,
|
||||
) -> list[str]:
|
||||
@@ -416,6 +477,7 @@ def _formatter_issues(
|
||||
):
|
||||
if "metadata" not in str(method or ""):
|
||||
issues.append(f"{name}.{formatter_name}: not_metadata_route:{method}")
|
||||
issues.extend(_formatter_trace_issues(name, formats, target=target))
|
||||
|
||||
for label, value in (
|
||||
(f"{name}.krea_negative", krea.get("negative_prompt")),
|
||||
@@ -508,7 +570,14 @@ def _case_report(
|
||||
is_pov: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
formats = _format_metadata(metadata, target)
|
||||
issues = _formatter_issues(name, formats, row=metadata, expected_terms=expected_terms, is_pov=is_pov)
|
||||
issues = _formatter_issues(
|
||||
name,
|
||||
formats,
|
||||
row=metadata,
|
||||
target=target,
|
||||
expected_terms=expected_terms,
|
||||
is_pov=is_pov,
|
||||
)
|
||||
issues.extend(_route_metadata_issues(name, metadata))
|
||||
issues.extend(_route_expectation_issues(name, metadata, expected_route))
|
||||
if target == "softcore":
|
||||
@@ -539,11 +608,17 @@ def _pair_reports(name: str, pair: dict[str, Any], *, include_prompts: bool) ->
|
||||
hard_row = dict(pair.get("hardcore_row") or {})
|
||||
soft_formats = _format_metadata(pair, "softcore")
|
||||
hard_formats = _format_metadata(pair, "hardcore")
|
||||
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row)
|
||||
soft_issues = _formatter_issues(f"{name}.softcore", soft_formats, row=soft_row, target="softcore")
|
||||
soft_issues.extend(_route_metadata_issues(f"{name}.softcore", soft_row))
|
||||
soft_issues.extend(_softcore_issues(f"{name}.softcore.krea_prompt", soft_formats["krea"].get("krea_prompt")))
|
||||
hard_is_pov = bool(hard_row.get("pov_character_labels"))
|
||||
hard_issues = _formatter_issues(f"{name}.hardcore", hard_formats, row=hard_row, is_pov=hard_is_pov)
|
||||
hard_issues = _formatter_issues(
|
||||
f"{name}.hardcore",
|
||||
hard_formats,
|
||||
row=hard_row,
|
||||
target="hardcore",
|
||||
is_pov=hard_is_pov,
|
||||
)
|
||||
hard_issues.extend(_route_metadata_issues(f"{name}.hardcore", hard_row))
|
||||
reports = [
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user