"""Coverage reporting for Krea2 pose variants.

Joins the pose-variant catalog with the evaluation log to report, per variant,
how much accepted evidence exists and which variants should be tested next.
"""

from __future__ import annotations

from collections import Counter
from typing import Any

try:
    from . import krea2_eval_log, krea2_pose_variant_catalog
except ImportError:
    # Allows local smoke tests from the repository root.
    import krea2_eval_log
    import krea2_pose_variant_catalog

# Coverage states that put a variant on the "next fixed-seed test" list.
_NEXT_TEST_STATES = frozenset({"needs_fixed_seed_tests", "proven_missing_evidence"})


def _coverage_state(status: str, accepted_count: int) -> str:
    """Classify a variant from its catalog status and accepted-evidence count.

    Returns one of ``"proven_with_evidence"``, ``"proven_missing_evidence"``,
    ``"needs_fixed_seed_tests"``, ``"needs_stronger_control"`` or ``"tracked"``.
    """
    if status == "proven" and accepted_count > 0:
        return "proven_with_evidence"
    if status == "proven":
        return "proven_missing_evidence"
    if status == "candidate" and accepted_count == 0:
        return "needs_fixed_seed_tests"
    if status == "unstable":
        return "needs_stronger_control"
    # Everything else (including candidates that already have accepted
    # evidence) is simply tracked.
    return "tracked"


def coverage_rows() -> list[dict[str, Any]]:
    """Return one coverage row per catalog variant, in catalog order.

    Each row combines the variant's catalog fields with the evaluation-log
    entries returned by ``krea2_eval_log.entries_for_variant`` for its key.
    """
    rows: list[dict[str, Any]] = []
    for variant in krea2_pose_variant_catalog.variants():
        key = str(variant.get("key") or "")
        # Materialize so len() works even if the log returns a lazy iterable.
        evidence = list(krea2_eval_log.entries_for_variant(key))
        accepted_count = sum(1 for entry in evidence if entry.get("result") == "accepted")
        status = str(variant.get("status") or "")
        rows.append(
            {
                "key": key,
                "family": variant.get("family") or "",
                "action_family": variant.get("action_family") or "",
                "status": status,
                "coverage_state": _coverage_state(status, accepted_count),
                "accepted_evidence_count": accepted_count,
                "total_evidence_count": len(evidence),
                "reference_count": len(variant.get("reference_images") or []),
                # `or ""` (not a .get default) so an explicit None also becomes "",
                # matching how next_test_plans() reads the same field.
                "guide_section": (variant.get("evidence") or {}).get("guide_section") or "",
            }
        )
    return rows


def _next_test_candidates(rows: list[dict[str, Any]]) -> list[str]:
    """Return keys of *rows* whose coverage state calls for another fixed-seed test."""
    return [str(row["key"]) for row in rows if row["coverage_state"] in _NEXT_TEST_STATES]


def _summarize(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate already-computed coverage *rows* into the summary dict."""
    return {
        "variant_count": len(rows),
        "status_counts": dict(Counter(row["status"] for row in rows)),
        "coverage_state_counts": dict(Counter(row["coverage_state"] for row in rows)),
        "variants_without_accepted_evidence": [
            str(row["key"]) for row in rows if row["accepted_evidence_count"] == 0
        ],
        "next_test_candidates": _next_test_candidates(rows),
    }


def coverage_summary() -> dict[str, Any]:
    """Summarize coverage across all catalog variants.

    Returns a dict with the variant count, per-status and per-coverage-state
    counts, the keys with no accepted evidence, and the next test candidates
    (variants in the ``needs_fixed_seed_tests`` or ``proven_missing_evidence``
    state).
    """
    return _summarize(coverage_rows())


def _build_plans(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build a test plan for each next-test candidate found in *rows*.

    Candidates whose key is not found by ``get_variant`` are skipped.
    """
    rows_by_key = {str(row["key"]): row for row in rows}
    plans: list[dict[str, Any]] = []
    for key in _next_test_candidates(rows):
        variant = krea2_pose_variant_catalog.get_variant(key)
        if not variant:
            continue
        row = rows_by_key.get(key, {})
        evidence = variant.get("evidence") or {}
        plans.append(
            {
                "key": key,
                "family": variant.get("family") or "",
                "action_family": variant.get("action_family") or "",
                "status": variant.get("status") or "",
                "coverage_state": row.get("coverage_state") or "",
                "canonical_geometry": variant.get("canonical_geometry") or "",
                "prompt_cues": list(variant.get("prompt_cues") or []),
                "avoid_cues": list(variant.get("avoid_cues") or []),
                "reference_paths": [
                    str(path) for path in krea2_pose_variant_catalog.reference_paths(key)
                ],
                "generator_hook": variant.get("generator_hook") or {},
                "guide_section": evidence.get("guide_section") or "",
                "notes": evidence.get("notes") or "",
            }
        )
    return plans


def next_test_plans() -> list[dict[str, Any]]:
    """Return a test plan (cues, references, geometry, notes) per next-test candidate.

    Rows are computed once so the candidate list and the coverage state stored
    in each plan come from the same snapshot of the evaluation log.
    """
    return _build_plans(coverage_rows())


def _md_cell(value: Any) -> str:
    """Render *value* for a Markdown table cell, escaping ``|`` so it cannot split the cell."""
    return str(value).replace("|", "\\|")


def markdown_report() -> str:
    """Render the coverage table, next-test list and test plans as Markdown."""
    # One snapshot feeds every section so the table, list and plans agree.
    rows = coverage_rows()
    summary = _summarize(rows)
    plans = _build_plans(rows)

    lines = [
        "# Krea2 Pose Variant Coverage",
        "",
        "| Variant | Status | Evidence | State |",
        "| --- | --- | ---: | --- |",
    ]
    for row in rows:
        evidence = f"{row['accepted_evidence_count']}/{row['total_evidence_count']}"
        lines.append(
            f"| {_md_cell(row['key'])} | {_md_cell(row['status'])} | {evidence} "
            f"| {_md_cell(row['coverage_state'])} |"
        )

    if summary["next_test_candidates"]:
        lines.extend(
            [
                "",
                "## Next Fixed-Seed Tests",
                "",
                *[f"- {key}" for key in summary["next_test_candidates"]],
            ]
        )

    if plans:
        lines.extend(["", "## Next Test Plans"])
        for plan in plans:
            lines.extend(
                [
                    "",
                    f"### {plan['key']}",
                    "",
                    f"- Geometry: {plan['canonical_geometry']}",
                    f"- References: {', '.join(plan['reference_paths']) or 'none'}",
                    "- Prompt cues:",
                    *[f"  - {cue}" for cue in plan["prompt_cues"]],
                    "- Avoid cues:",
                    *[f"  - {cue}" for cue in plan["avoid_cues"]],
                ]
            )
    return "\n".join(lines)