# File: ComfyUI-Ethanfel-Prompt-Bui…/krea2_tuning_report.py
# (repository file-viewer header: 427 lines, 18 KiB, Python)

from __future__ import annotations
from collections import Counter
from pathlib import Path
import sys
from typing import Any
try:
from . import krea2_eval_log, krea2_pose_variant_catalog
except ImportError: # Allows local smoke tests from the repository root.
import krea2_eval_log
import krea2_pose_variant_catalog
def _coverage_state(status: str, accepted_count: int) -> str:
    """Map a variant's catalog status and accepted-evidence count to a coverage label.

    Args:
        status: Status string read from the variant's catalog entry.
        accepted_count: Number of eval-log entries whose result is "accepted".

    Returns:
        "proven_with_evidence" or "proven_missing_evidence" for "proven"
        variants (depending on whether ``accepted_count`` is positive),
        "needs_stronger_control" for "unstable" variants,
        "needs_fixed_seed_tests" for "candidate" variants with exactly zero
        accepted entries, and "tracked" for every other combination.
    """
    if status == "proven":
        return "proven_with_evidence" if accepted_count > 0 else "proven_missing_evidence"
    if status == "unstable":
        return "needs_stronger_control"
    if status == "candidate" and accepted_count == 0:
        return "needs_fixed_seed_tests"
    # Candidates that already have accepted evidence, and unknown statuses, fall through.
    return "tracked"
def _latest_evidence(entries: list[dict[str, Any]], *, result: str | None = None) -> dict[str, Any]:
    """Summarise the last entry of ``entries``, optionally restricted to one result.

    "Latest" means last in list order; this function does not sort, so the
    caller's ordering decides which entry is picked.

    Args:
        entries: Eval-log entries for a single variant.
        result: When given, only entries whose ``"result"`` equals this value
            are considered.

    Returns:
        A dict with the keys ``id``, ``seed``, ``generator_seed``, ``result``,
        ``decision``, ``baseline_prompt_summary``, ``candidate_prompt_summary``,
        ``observation``, ``needs_expansion`` and ``commit``; or an empty dict
        when no entry qualifies. Falsy text fields are normalised to ``""``
        and ``needs_expansion`` is coerced to ``bool``; the two seed fields
        are passed through untouched.
    """
    if result is None:
        matching = entries
    else:
        matching = [item for item in entries if item.get("result") == result]
    if not matching:
        return {}

    latest = matching[-1]

    def text(field: str) -> Any:
        # Collapse missing / None / empty values to an empty string.
        return latest.get(field) or ""

    return {
        "id": text("id"),
        "seed": latest.get("seed"),
        "generator_seed": latest.get("generator_seed"),
        "result": text("result"),
        "decision": text("decision"),
        "baseline_prompt_summary": text("baseline_prompt_summary"),
        "candidate_prompt_summary": text("candidate_prompt_summary"),
        "observation": text("observation"),
        "needs_expansion": bool(latest.get("needs_expansion")),
        "commit": text("commit"),
    }
def coverage_rows() -> list[dict[str, Any]]:
    """Build one coverage row per catalog variant, joined with its eval-log entries.

    For every variant returned by ``krea2_pose_variant_catalog.variants()`` the
    matching entries are fetched with ``krea2_eval_log.entries_for_variant``.

    Returns:
        A list of dicts in catalog order. Each row carries the variant's
        descriptive fields (``key``, ``family``, ``action_family``, ``status``,
        ``difficulty``, ``priority``, ``control_requirement``), the derived
        ``coverage_state``, accepted/total evidence counts, summaries of the
        latest and latest-accepted entries, the number of reference images and
        the variant's ``guide_section``.
    """

    def describe(variant: Any) -> dict[str, Any]:
        variant_key = str(variant.get("key") or "")
        log_entries = krea2_eval_log.entries_for_variant(variant_key)
        accepted_total = len([item for item in log_entries if item.get("result") == "accepted"])
        variant_status = str(variant.get("status") or "")
        return {
            "key": variant_key,
            "family": variant.get("family") or "",
            "action_family": variant.get("action_family") or "",
            "status": variant_status,
            "difficulty": variant.get("difficulty") or "",
            "priority": variant.get("priority") or "",
            "control_requirement": variant.get("control_requirement") or "",
            "coverage_state": _coverage_state(variant_status, accepted_total),
            "accepted_evidence_count": accepted_total,
            "total_evidence_count": len(log_entries),
            "latest_evidence": _latest_evidence(log_entries),
            "latest_accepted_evidence": _latest_evidence(log_entries, result="accepted"),
            "reference_count": len(variant.get("reference_images") or []),
            "guide_section": (variant.get("evidence") or {}).get("guide_section", ""),
        }

    return [describe(variant) for variant in krea2_pose_variant_catalog.variants()]
def coverage_summary() -> dict[str, Any]:
    """Aggregate ``coverage_rows()`` into counts and follow-up key lists.

    Returns:
        A dict with:
        ``variant_count`` – number of rows;
        ``status_counts`` / ``coverage_state_counts`` – occurrences per value;
        ``variants_without_accepted_evidence`` – keys whose accepted count is 0;
        ``next_test_candidates`` – keys whose coverage state is
        "needs_fixed_seed_tests" or "proven_missing_evidence";
        ``stronger_control_cases`` – keys whose coverage state is
        "needs_stronger_control".
        All key lists keep the order of ``coverage_rows()``.
    """
    rows = coverage_rows()
    retest_states = {"needs_fixed_seed_tests", "proven_missing_evidence"}

    by_status: Counter[Any] = Counter()
    by_state: Counter[Any] = Counter()
    lacking_accepted: list[str] = []
    retest_keys: list[str] = []
    control_keys: list[str] = []

    for row in rows:
        state = row.get("coverage_state")
        label = str(row.get("key"))
        by_status[row.get("status")] += 1
        by_state[state] += 1
        if int(row.get("accepted_evidence_count") or 0) == 0:
            lacking_accepted.append(label)
        if state in retest_states:
            retest_keys.append(label)
        if state == "needs_stronger_control":
            control_keys.append(label)

    return {
        "variant_count": len(rows),
        "status_counts": dict(by_status),
        "coverage_state_counts": dict(by_state),
        "variants_without_accepted_evidence": lacking_accepted,
        "next_test_candidates": retest_keys,
        "stronger_control_cases": control_keys,
    }
def _catalog_atlas_root() -> Path:
    """Return the catalog's ``atlas_root`` value as a ``Path``.

    A missing or falsy ``atlas_root`` becomes ``Path("")``, which pathlib
    treats as the current directory.
    """
    configured_root = krea2_pose_variant_catalog.load_catalog().get("atlas_root") or ""
    return Path(str(configured_root))
def _mapped_atlas_folders() -> dict[str, list[str]]:
    """Index catalog variants by the atlas folder names they declare.

    Returns:
        A dict from folder name (stringified ``atlas_folders`` item) to the
        keys of every variant listing that folder, in catalog order. Items
        that stringify to an empty string are ignored.
    """
    folder_index: dict[str, list[str]] = {}
    for variant in krea2_pose_variant_catalog.variants():
        variant_key = str(variant.get("key") or "")
        declared = (str(raw) for raw in variant.get("atlas_folders") or [])
        for name in declared:
            if name:
                folder_index.setdefault(name, []).append(variant_key)
    return folder_index
def _is_background_or_control_folder(folder_name: str) -> bool:
    """Tell whether a folder name denotes a background/control folder.

    The check is case-insensitive and matches the exact names "bg" and
    "woman" as well as any name ending in "_control", "_bg" or "_control_bg".
    """
    lowered = folder_name.lower()
    # "_control_bg" is already covered by "_bg"; it is listed for explicitness.
    return lowered in {"bg", "woman"} or lowered.endswith(("_control", "_bg", "_control_bg"))
def _sample_pngs(folder: Path, limit: int) -> list[str]:
    """Return up to ``limit`` ``*.png`` paths directly inside ``folder``.

    Paths are returned as strings, ordered by lower-cased file name. An empty
    list is returned when ``folder`` is not a directory or ``limit`` is not
    positive.
    """
    if not folder.is_dir() or limit <= 0:
        return []
    ordered = sorted(folder.glob("*.png"), key=lambda candidate: candidate.name.lower())
    return list(map(str, ordered[:limit]))
def atlas_folder_rows(atlas_root: str | Path | None = None) -> list[dict[str, Any]]:
    """List the pose folders under the atlas root that are relevant for coverage.

    A sub-directory is reported when it is not a background/control folder,
    directly contains at least one ``*.png`` file, and is either referenced by
    a catalog variant or has a sibling ``<name>_control`` directory.

    Args:
        atlas_root: Directory to scan; the catalog's atlas root is used when
            ``None``.

    Returns:
        Rows sorted by lower-cased folder name, each with ``folder``,
        ``image_count``, ``mapped``, ``variant_keys`` and ``control_folder``
        (the control directory path, or ``""`` when it does not exist).
        An empty list is returned when the root is not a directory.
    """
    if atlas_root is None:
        root = _catalog_atlas_root()
    else:
        root = Path(atlas_root)
    if not root.is_dir():
        return []

    folder_to_keys = _mapped_atlas_folders()
    rows: list[dict[str, Any]] = []
    for entry in sorted(root.iterdir(), key=lambda item: item.name.lower()):
        name = entry.name
        if not entry.is_dir() or _is_background_or_control_folder(name):
            continue
        png_total = len(list(entry.glob("*.png")))
        if png_total <= 0:
            continue
        control_dir = root / f"{name}_control"
        has_control = control_dir.is_dir()
        keys = folder_to_keys.get(name, [])
        # Folders that are neither catalogued nor paired with a control set are ignored.
        if not (keys or has_control):
            continue
        rows.append(
            {
                "folder": name,
                "image_count": png_total,
                "mapped": bool(keys),
                "variant_keys": list(keys),
                "control_folder": str(control_dir) if has_control else "",
            }
        )
    return rows
def atlas_coverage_summary(atlas_root: str | Path | None = None) -> dict[str, Any]:
    """Count mapped versus unmapped pose folders reported by ``atlas_folder_rows``.

    Args:
        atlas_root: Forwarded unchanged to ``atlas_folder_rows``.

    Returns:
        A dict with ``pose_folder_count``, ``mapped_folder_count``,
        ``unmapped_folder_count`` and ``unmapped_folders`` (folder names in
        row order).
    """
    rows = atlas_folder_rows(atlas_root=atlas_root)
    unmapped_names: list[str] = []
    for row in rows:
        if not row.get("mapped"):
            unmapped_names.append(str(row.get("folder")))
    total = len(rows)
    missing = len(unmapped_names)
    return {
        "pose_folder_count": total,
        "mapped_folder_count": total - missing,
        "unmapped_folder_count": missing,
        "unmapped_folders": unmapped_names,
    }
def _suggested_variant_key(folder_name: str) -> str:
    """Derive a placeholder variant key from an atlas folder name.

    The folder named "ready" (case-insensitive) maps to a fixed key. Any other
    name is lower-cased, every non-alphanumeric character becomes "_", and
    leading, trailing and repeated underscores are removed before wrapping the
    result as ``pov_<slug>_candidate``. A name with no alphanumeric characters
    yields "pov_unmapped_candidate".
    """
    lowered = folder_name.lower()
    if lowered == "ready":
        return "pov_ejaculation_aftermath_open_thigh_candidate"
    underscored = "".join(ch if ch.isalnum() else "_" for ch in lowered)
    # Dropping empty pieces both trims the ends and collapses underscore runs.
    slug = "_".join(piece for piece in underscored.split("_") if piece)
    if not slug:
        return "pov_unmapped_candidate"
    return f"pov_{slug}_candidate"
def atlas_gap_plans(atlas_root: str | Path | None = None, sample_limit: int = 3) -> list[dict[str, Any]]:
    """Describe every unmapped atlas pose folder as a candidate for a new variant.

    Args:
        atlas_root: Directory to scan; the catalog's atlas root is used when
            ``None``.
        sample_limit: Maximum number of pose images and of control images to
            list per folder.

    Returns:
        One dict per unmapped row of ``atlas_folder_rows`` with ``folder``,
        ``suggested_variant_key``, ``image_count``, ``sample_images`` and
        ``control_images``. ``control_images`` is empty when the row records
        no control folder.
    """
    root = Path(atlas_root) if atlas_root is not None else _catalog_atlas_root()
    plans: list[dict[str, Any]] = []
    for row in atlas_folder_rows(atlas_root=root):
        if row.get("mapped"):
            continue
        folder_name = str(row.get("folder") or "")
        control_value = str(row.get("control_folder") or "")
        # Path("") is the current directory, so an absent control folder must
        # not be handed to _sample_pngs: it would list PNGs from the CWD.
        control_images = _sample_pngs(Path(control_value), sample_limit) if control_value else []
        plans.append(
            {
                "folder": folder_name,
                "suggested_variant_key": _suggested_variant_key(folder_name),
                "image_count": int(row.get("image_count") or 0),
                "sample_images": _sample_pngs(root / folder_name, sample_limit),
                "control_images": control_images,
            }
        )
    return plans
def next_test_plans() -> list[dict[str, Any]]:
    """Expand each ``next_test_candidates`` key into a test-plan dict.

    Keys come from ``coverage_summary()``; catalog details come from
    ``krea2_pose_variant_catalog.get_variant``. Keys for which the catalog
    returns nothing are skipped.

    Returns:
        Plans in candidate order, each holding the variant's descriptive
        fields, its coverage state, canonical geometry, prompt/avoid cues,
        stringified reference paths, generator hook, guide section and notes.
    """
    row_lookup: dict[str, dict[str, Any]] = {}
    for row in coverage_rows():
        row_lookup[str(row.get("key"))] = row

    plans: list[dict[str, Any]] = []
    for variant_key in coverage_summary()["next_test_candidates"]:
        variant = krea2_pose_variant_catalog.get_variant(variant_key)
        if variant:
            coverage = row_lookup.get(variant_key, {})
            evidence_meta = variant.get("evidence") or {}
            references = krea2_pose_variant_catalog.reference_paths(variant_key)
            plans.append(
                {
                    "key": variant_key,
                    "family": variant.get("family") or "",
                    "action_family": variant.get("action_family") or "",
                    "status": variant.get("status") or "",
                    "coverage_state": coverage.get("coverage_state") or "",
                    "canonical_geometry": variant.get("canonical_geometry") or "",
                    "prompt_cues": list(variant.get("prompt_cues") or []),
                    "avoid_cues": list(variant.get("avoid_cues") or []),
                    "reference_paths": [str(reference) for reference in references],
                    "generator_hook": variant.get("generator_hook") or {},
                    "guide_section": evidence_meta.get("guide_section") or "",
                    "notes": evidence_meta.get("notes") or "",
                }
            )
    return plans
def guide_expansion_plans() -> list[dict[str, Any]]:
    """List variants whose latest accepted evidence calls for wider testing.

    A variant qualifies when the decision on its latest accepted entry is
    "prompt_guide_rule" or "needs_more_tests", or is
    "provisional_generator_patch" with ``needs_expansion`` set. Variants the
    catalog cannot resolve are skipped.

    Returns:
        Plans in ``coverage_rows()`` order. Each has the fixed ``target``
        "multi_seed_multi_woman_matrix", the id/seed/decision of the latest
        accepted entry, evidence counts, and the variant's descriptive fields,
        cues, reference paths, generator hook, guide section and notes.
    """
    always_expand = {"prompt_guide_rule", "needs_more_tests"}
    plans: list[dict[str, Any]] = []
    for row in coverage_rows():
        accepted = row.get("latest_accepted_evidence") or {}
        decision = str(accepted.get("decision") or "")
        flagged_patch = decision == "provisional_generator_patch" and accepted.get("needs_expansion")
        if not (decision in always_expand or flagged_patch):
            continue

        variant_key = str(row.get("key") or "")
        variant = krea2_pose_variant_catalog.get_variant(variant_key)
        if not variant:
            continue

        evidence_meta = variant.get("evidence") or {}
        references = krea2_pose_variant_catalog.reference_paths(variant_key)
        plans.append(
            {
                "key": variant_key,
                "family": variant.get("family") or "",
                "action_family": variant.get("action_family") or "",
                "status": variant.get("status") or "",
                "coverage_state": row.get("coverage_state") or "",
                "target": "multi_seed_multi_woman_matrix",
                "latest_accepted_id": accepted.get("id") or "",
                "latest_accepted_seed": accepted.get("seed"),
                "latest_accepted_decision": decision,
                "accepted_evidence_count": row.get("accepted_evidence_count") or 0,
                "total_evidence_count": row.get("total_evidence_count") or 0,
                "canonical_geometry": variant.get("canonical_geometry") or "",
                "prompt_cues": list(variant.get("prompt_cues") or []),
                "avoid_cues": list(variant.get("avoid_cues") or []),
                "reference_paths": [str(reference) for reference in references],
                "generator_hook": variant.get("generator_hook") or {},
                "guide_section": evidence_meta.get("guide_section") or "",
                "notes": evidence_meta.get("notes") or "",
            }
        )
    return plans
def next_eval_template_commands(*, seed_token: str = "<fixed_seed>") -> list[dict[str, str]]:
    """Build a ``krea2_record_eval.py --print-template`` command per next-test plan.

    Args:
        seed_token: Text placed after ``--seed`` in every command.

    Returns:
        Dicts with ``key`` and ``command``, in ``next_test_plans()`` order.
        Plans without a key are omitted.
    """
    plan_keys = (str(plan.get("key") or "") for plan in next_test_plans())
    return [
        {
            "key": variant_key,
            "command": f"python tools/krea2_record_eval.py --print-template --variant-key {variant_key} --seed {seed_token}",
        }
        for variant_key in plan_keys
        if variant_key
    ]
def _evidence_headline(evidence: dict[str, Any]) -> str:
    """Format one evidence summary as ``id (result, seed…, decision, commit …)``."""
    seed = evidence.get("seed")
    seed_text = f"seed {seed}" if isinstance(seed, int) else "seed unknown"
    generator_seed = evidence.get("generator_seed")
    generator_seed_text = f", generator seed {generator_seed}" if isinstance(generator_seed, int) else ""
    commit = evidence.get("commit") or "uncommitted"
    return (
        f"{evidence.get('id') or 'unnamed'} ({evidence.get('result') or 'unknown'}, "
        f"{seed_text}{generator_seed_text}, {evidence.get('decision') or 'unknown'}, commit {commit})"
    )


def markdown_report(atlas_root: str | Path | None = None) -> str:
    """Render the full coverage report as Markdown text.

    Sections, each emitted only when it has content (the coverage table header
    is always present): coverage table, latest evidence, next fixed-seed tests,
    eval entry template commands, stronger control cases, guide/fragile
    evidence expansion, next test plans, atlas folder coverage and atlas gap
    plans.

    Args:
        atlas_root: Atlas directory used for the two atlas sections; forwarded
            to ``atlas_coverage_summary`` and ``atlas_gap_plans``.

    Returns:
        The report lines joined with ``"\\n"`` (no trailing newline).
    """
    # Computed once and reused by the table, latest-evidence and
    # stronger-control sections; each call rescans the catalog and eval log.
    rows = coverage_rows()

    lines = [
        "# Krea2 Pose Variant Coverage",
        "",
        "| Variant | Status | Evidence | State |",
        "| --- | --- | ---: | --- |",
    ]
    for row in rows:
        lines.append(
            f"| {row['key']} | {row['status']} | {row['accepted_evidence_count']}/{row['total_evidence_count']} | {row['coverage_state']} |"
        )

    evidence_rows = [row for row in rows if row.get("latest_evidence")]
    if evidence_rows:
        lines.extend(["", "## Latest Evidence", ""])
        for row in evidence_rows:
            evidence = row.get("latest_evidence") or {}
            lines.append(f"- {row['key']}: {_evidence_headline(evidence)}")
            if evidence.get("candidate_prompt_summary"):
                lines.append(f" Candidate: {evidence['candidate_prompt_summary']}")
            if evidence.get("observation"):
                lines.append(f" Observation: {evidence['observation']}")
            accepted = row.get("latest_accepted_evidence") or {}
            # Only repeat the accepted entry when it differs from the latest one.
            if accepted and accepted.get("id") != evidence.get("id"):
                lines.append(f" Latest accepted: {_evidence_headline(accepted)}")
                if accepted.get("candidate_prompt_summary"):
                    lines.append(f" Accepted candidate: {accepted['candidate_prompt_summary']}")
                if accepted.get("observation"):
                    lines.append(f" Accepted observation: {accepted['observation']}")

    summary = coverage_summary()
    if summary["next_test_candidates"]:
        lines.extend(
            [
                "",
                "## Next Fixed-Seed Tests",
                "",
                *[f"- {key}" for key in summary["next_test_candidates"]],
            ]
        )

    template_commands = next_eval_template_commands()
    if template_commands:
        lines.extend(["", "## Eval Entry Template Commands", ""])
        for command in template_commands:
            lines.append(f"- {command['key']}: `{command['command']}`")

    stronger_control_rows = [row for row in rows if row.get("coverage_state") == "needs_stronger_control"]
    if stronger_control_rows:
        lines.extend(["", "## Stronger Control Cases", ""])
        for row in stronger_control_rows:
            difficulty = row.get("difficulty") or "unrated"
            priority = row.get("priority") or "unprioritized"
            control_requirement = row.get("control_requirement") or "control_needed"
            lines.append(
                f"- {row['key']}: {difficulty}, {priority} priority, {control_requirement}"
            )

    expansion_plans = guide_expansion_plans()
    if expansion_plans:
        lines.extend(["", "## Guide/Fragile Evidence Expansion", ""])
        for plan in expansion_plans:
            seed = plan.get("latest_accepted_seed")
            seed_text = f"seed {seed}" if isinstance(seed, int) else "seed unknown"
            lines.append(
                f"- {plan['key']}: {plan['target']} after {plan['latest_accepted_decision']} "
                f"({plan['latest_accepted_id']}, {seed_text})"
            )

    plans = next_test_plans()
    if plans:
        lines.extend(["", "## Next Test Plans"])
        for plan in plans:
            lines.extend(
                [
                    "",
                    f"### {plan['key']}",
                    "",
                    f"- Geometry: {plan['canonical_geometry']}",
                    f"- References: {', '.join(plan['reference_paths']) or 'none'}",
                    "- Prompt cues:",
                    *[f" - {cue}" for cue in plan["prompt_cues"]],
                    "- Avoid cues:",
                    *[f" - {cue}" for cue in plan["avoid_cues"]],
                ]
            )

    atlas_summary = atlas_coverage_summary(atlas_root=atlas_root)
    if atlas_summary["pose_folder_count"]:
        unmapped = atlas_summary["unmapped_folders"]
        lines.extend(
            [
                "",
                "## Atlas Folder Coverage",
                "",
                f"- Pose folders: {atlas_summary['pose_folder_count']}",
                f"- Mapped folders: {atlas_summary['mapped_folder_count']}",
                f"- Unmapped folders: {atlas_summary['unmapped_folder_count']}",
            ]
        )
        if unmapped:
            lines.extend(["", "Unmapped atlas folders:", *[f"- {folder}" for folder in unmapped]])

    gap_plans = atlas_gap_plans(atlas_root=atlas_root)
    if gap_plans:
        lines.extend(["", "## Atlas Gap Plans"])
        for plan in gap_plans:
            sample_images = plan["sample_images"]
            control_images = plan["control_images"]
            lines.extend(
                [
                    "",
                    f"### {plan['folder']}",
                    "",
                    f"- Suggested key: {plan['suggested_variant_key']}",
                    f"- Pose images: {plan['image_count']}",
                    f"- Samples: {', '.join(sample_images) or 'none'}",
                    f"- Controls: {', '.join(control_images) or 'none'}",
                ]
            )

    return "\n".join(lines)
def main(argv: list[str] | None = None) -> int:
    """Print the Markdown coverage report to stdout and return exit status 0.

    ``argv`` is accepted so the function can be called like a CLI entry point,
    but it is not used.
    """
    del argv  # accepted but unused
    report = markdown_report()
    print(report)
    return 0
# Script entry point: run the report and use main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))