"""Coverage reporting for the Krea2 pose-variant catalog.

Joins catalog variants (``krea2_pose_variant_catalog``) with eval-log evidence
(``krea2_eval_log``) and an optional on-disk pose atlas, and renders the result
as summary dicts or a Markdown report.
"""

from __future__ import annotations

from collections import Counter
from pathlib import Path
from typing import Any

try:
    from . import krea2_eval_log, krea2_pose_variant_catalog
except ImportError:  # Allows local smoke tests from the repository root.
    import krea2_eval_log
    import krea2_pose_variant_catalog


# Coverage states that put a variant on the "next fixed-seed test" list.
_NEXT_TEST_STATES = frozenset({"needs_fixed_seed_tests", "proven_missing_evidence"})

# Atlas folders matching these (case-insensitive) are treated as background/control
# material rather than pose folders. "_control_bg" is already covered by "_bg".
_BACKGROUND_FOLDER_NAMES = frozenset({"bg", "woman"})
_BACKGROUND_FOLDER_SUFFIXES = ("_control", "_bg")


def _coverage_state(status: str, accepted_count: int) -> str:
    """Classify a variant from its catalog ``status`` and accepted-evidence count.

    Returns one of ``proven_with_evidence``, ``proven_missing_evidence``,
    ``needs_fixed_seed_tests``, ``needs_stronger_control`` or ``tracked``; the
    latter is the fallback for every other combination (e.g. a ``candidate``
    that already has accepted evidence).
    """
    if status == "proven" and accepted_count > 0:
        return "proven_with_evidence"
    if status == "proven":
        return "proven_missing_evidence"
    if status == "candidate" and accepted_count == 0:
        return "needs_fixed_seed_tests"
    if status == "unstable":
        return "needs_stronger_control"
    return "tracked"


def coverage_rows() -> list[dict[str, Any]]:
    """Return one coverage row per catalog variant, joined with its eval-log evidence.

    Each row holds the variant's key, family, action family and status, the derived
    ``coverage_state``, the accepted/total evidence counts, the number of reference
    images, and the ``guide_section`` stored under the variant's ``evidence`` mapping
    (normalized to ``""`` when missing or null).
    """
    rows: list[dict[str, Any]] = []
    for variant in krea2_pose_variant_catalog.variants():
        key = str(variant.get("key") or "")
        evidence = krea2_eval_log.entries_for_variant(key)
        accepted_count = sum(1 for entry in evidence if entry.get("result") == "accepted")
        status = str(variant.get("status") or "")
        rows.append(
            {
                "key": key,
                "family": variant.get("family") or "",
                "action_family": variant.get("action_family") or "",
                "status": status,
                "coverage_state": _coverage_state(status, accepted_count),
                "accepted_evidence_count": accepted_count,
                "total_evidence_count": len(evidence),
                "reference_count": len(variant.get("reference_images") or []),
                # `or ""` (not a .get default) so an explicit null also becomes "",
                # matching how next_test_plans() reads the same field.
                "guide_section": (variant.get("evidence") or {}).get("guide_section") or "",
            }
        )
    return rows


def _next_test_candidates(rows: list[dict[str, Any]]) -> list[str]:
    """Return keys of rows whose coverage state calls for a fixed-seed test, in row order."""
    return [str(row.get("key")) for row in rows if row.get("coverage_state") in _NEXT_TEST_STATES]


def _summarize_rows(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate precomputed coverage rows into the ``coverage_summary()`` shape."""
    return {
        "variant_count": len(rows),
        "status_counts": dict(Counter(row.get("status") for row in rows)),
        "coverage_state_counts": dict(Counter(row.get("coverage_state") for row in rows)),
        "variants_without_accepted_evidence": [
            str(row.get("key"))
            for row in rows
            if int(row.get("accepted_evidence_count") or 0) == 0
        ],
        "next_test_candidates": _next_test_candidates(rows),
    }


def coverage_summary() -> dict[str, Any]:
    """Summarize catalog coverage: counts per status/state and keys still needing work."""
    return _summarize_rows(coverage_rows())


def _catalog_atlas_root() -> Path | None:
    """Return the catalog's configured ``atlas_root``, or None when it is unset.

    None (rather than ``Path("")``) is deliberate: ``Path("")`` is the current
    working directory, so an unset root would otherwise make atlas scans walk
    whatever directory the process happens to run in.
    """
    catalog = krea2_pose_variant_catalog.load_catalog()
    atlas_root = str(catalog.get("atlas_root") or "")
    return Path(atlas_root) if atlas_root else None


def _mapped_atlas_folders() -> dict[str, list[str]]:
    """Map each atlas folder name listed by a catalog variant to the keys that list it.

    Empty folder names are skipped; a folder listed by several variants maps to all
    of their keys in catalog order.
    """
    mapped: dict[str, list[str]] = {}
    for variant in krea2_pose_variant_catalog.variants():
        key = str(variant.get("key") or "")
        for folder in variant.get("atlas_folders") or []:
            folder_name = str(folder)
            if not folder_name:
                continue
            mapped.setdefault(folder_name, []).append(key)
    return mapped


def _is_background_or_control_folder(folder_name: str) -> bool:
    """Return True for atlas folders that hold background/control images, not poses.

    Matching is case-insensitive: the exact names ``bg`` / ``woman`` or any name
    ending in ``_control`` or ``_bg`` (which includes ``_control_bg``).
    """
    lower = folder_name.lower()
    return lower in _BACKGROUND_FOLDER_NAMES or lower.endswith(_BACKGROUND_FOLDER_SUFFIXES)


def atlas_folder_rows(atlas_root: str | Path | None = None) -> list[dict[str, Any]]:
    """Describe each pose folder directly under the atlas root.

    Args:
        atlas_root: Directory to scan; defaults to the catalog's ``atlas_root``.

    Returns:
        Rows sorted by case-insensitive folder name. Background/control folders,
        folders without any ``*.png`` entries, and folders that are neither mapped
        by a catalog variant nor paired with a ``<name>_control`` sibling are
        skipped. Returns ``[]`` when no root is configured or it is not a directory.
    """
    root = Path(atlas_root) if atlas_root is not None else _catalog_atlas_root()
    if root is None or not root.is_dir():
        return []
    mapped = _mapped_atlas_folders()
    rows: list[dict[str, Any]] = []
    for folder in sorted(root.iterdir(), key=lambda path: path.name.lower()):
        if not folder.is_dir():
            continue
        folder_name = folder.name
        if _is_background_or_control_folder(folder_name):
            continue
        image_count = sum(1 for _ in folder.glob("*.png"))
        if image_count <= 0:
            continue
        control_folder = root / f"{folder_name}_control"
        variant_keys = mapped.get(folder_name, [])
        if not variant_keys and not control_folder.is_dir():
            continue
        rows.append(
            {
                "folder": folder_name,
                "image_count": image_count,
                "mapped": bool(variant_keys),
                "variant_keys": list(variant_keys),
                "control_folder": str(control_folder) if control_folder.is_dir() else "",
            }
        )
    return rows


def atlas_coverage_summary(atlas_root: str | Path | None = None) -> dict[str, Any]:
    """Count atlas pose folders and list the ones no catalog variant maps to."""
    rows = atlas_folder_rows(atlas_root=atlas_root)
    unmapped = [str(row.get("folder")) for row in rows if not row.get("mapped")]
    return {
        "pose_folder_count": len(rows),
        "mapped_folder_count": len(rows) - len(unmapped),
        "unmapped_folder_count": len(unmapped),
        "unmapped_folders": unmapped,
    }


def _plans_for_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build test plans for next-test candidate rows, skipping keys the catalog cannot resolve."""
    plans: list[dict[str, Any]] = []
    for row in rows:
        if row.get("coverage_state") not in _NEXT_TEST_STATES:
            continue
        key = str(row.get("key"))
        variant = krea2_pose_variant_catalog.get_variant(key)
        if not variant:
            continue
        evidence = variant.get("evidence") or {}
        plans.append(
            {
                "key": key,
                "family": variant.get("family") or "",
                "action_family": variant.get("action_family") or "",
                "status": variant.get("status") or "",
                "coverage_state": row.get("coverage_state") or "",
                "canonical_geometry": variant.get("canonical_geometry") or "",
                "prompt_cues": list(variant.get("prompt_cues") or []),
                "avoid_cues": list(variant.get("avoid_cues") or []),
                "reference_paths": [
                    str(path) for path in krea2_pose_variant_catalog.reference_paths(key)
                ],
                "generator_hook": variant.get("generator_hook") or {},
                "guide_section": evidence.get("guide_section") or "",
                "notes": evidence.get("notes") or "",
            }
        )
    return plans


def next_test_plans() -> list[dict[str, Any]]:
    """Return a test plan (geometry, cues, references, hooks) for every next-test candidate."""
    return _plans_for_rows(coverage_rows())


def markdown_report(atlas_root: str | Path | None = None) -> str:
    """Render the coverage table, next tests, test plans and atlas coverage as Markdown.

    Coverage rows are computed once and shared by every section, so the table, the
    candidate list and the plans all reflect the same snapshot of the eval log.
    """
    rows = coverage_rows()
    summary = _summarize_rows(rows)
    plans = _plans_for_rows(rows)

    lines = [
        "# Krea2 Pose Variant Coverage",
        "",
        "| Variant | Status | Evidence | State |",
        "| --- | --- | ---: | --- |",
    ]
    for row in rows:
        lines.append(
            f"| {row['key']} | {row['status']} | {row['accepted_evidence_count']}/{row['total_evidence_count']} | {row['coverage_state']} |"
        )

    if summary["next_test_candidates"]:
        lines.extend(
            [
                "",
                "## Next Fixed-Seed Tests",
                "",
                *[f"- {key}" for key in summary["next_test_candidates"]],
            ]
        )

    if plans:
        lines.extend(["", "## Next Test Plans"])
        for plan in plans:
            lines.extend(
                [
                    "",
                    f"### {plan['key']}",
                    "",
                    f"- Geometry: {plan['canonical_geometry']}",
                    f"- References: {', '.join(plan['reference_paths']) or 'none'}",
                    "- Prompt cues:",
                    *[f"  - {cue}" for cue in plan["prompt_cues"]],
                    "- Avoid cues:",
                    *[f"  - {cue}" for cue in plan["avoid_cues"]],
                ]
            )

    atlas_summary = atlas_coverage_summary(atlas_root=atlas_root)
    if atlas_summary["pose_folder_count"]:
        unmapped = atlas_summary["unmapped_folders"]
        lines.extend(
            [
                "",
                "## Atlas Folder Coverage",
                "",
                f"- Pose folders: {atlas_summary['pose_folder_count']}",
                f"- Mapped folders: {atlas_summary['mapped_folder_count']}",
                f"- Unmapped folders: {atlas_summary['unmapped_folder_count']}",
            ]
        )
        if unmapped:
            lines.extend(["", "Unmapped atlas folders:", *[f"- {folder}" for folder in unmapped]])
    return "\n".join(lines)