#!/usr/bin/env python3
"""Prepare repeatable SxCP prompt-probe batches without opening the MCP bridge.

The live bridge call remains `tools/sxcp_mcp_client.py`. This helper keeps the
batch plan and image-presence checklist deterministic so prompt-axis probes are
less dependent on hand-copied commands.
"""

from __future__ import annotations

import argparse
from datetime import date
import json
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Any


# Directory one level above the directory containing this file; used as the
# working directory when the MCP helper is spawned.
ROOT = Path(__file__).resolve().parents[1]
APPROVED_PYTHON = "/media/p5/miniforge3/bin/python"
MCP_HELPER = "tools/sxcp_mcp_client.py"
DEFAULT_OUT_CHANNEL = "sxcp_eval_out"
DEFAULT_IN_CHANNEL = "sxcp_eval_in"
NEGATIVE_OUT_CHANNEL = "sxcp_eval_negative_out"
PROMPT_ORDERS = {"subject_first", "geometry_only", "prompt_order_test"}

# Keys that must not appear in a batch object or in any probe object.
_NEGATIVE_KEYS = ("negative", "negative_prompt", "negative_text", "negative_channel")


class BatchError(ValueError):
    """Raised when a batch plan, result file, or MCP payload fails validation."""


def _is_int(value: Any) -> bool:
    """Return True for real integers only.

    ``bool`` is a subclass of ``int``, so JSON ``true``/``false`` would pass a
    plain ``isinstance(value, int)`` check; it is rejected here.
    """
    return isinstance(value, int) and not isinstance(value, bool)


def _load_json(path: Path) -> dict[str, Any]:
    """Read *path* as UTF-8 JSON and require the top-level value to be an object.

    Raises:
        BatchError: if the parsed document is not a JSON object.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    if not isinstance(data, dict):
        raise BatchError("batch JSON must contain one object")
    return data


def _load_json_list(path: Path) -> list[Any]:
    """Read *path* as UTF-8 JSON and require the top-level value to be a list.

    Raises:
        BatchError: if the parsed document is not a JSON list.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    if not isinstance(data, list):
        raise BatchError("mock pulls JSON must contain a list")
    return data


def _json_arg(value: dict[str, Any]) -> str:
    """Serialize *value* as compact ASCII JSON, shell-quoted for a printed command."""
    return shlex.quote(json.dumps(value, ensure_ascii=True, separators=(",", ":")))


def _text(value: Any) -> str:
    """Return ``str(value)`` stripped of surrounding whitespace; ``None`` becomes ""."""
    return "" if value is None else str(value).strip()


def _validate_no_negative_channel(value: Any, *, field: str) -> None:
    """Reject any value that equals or mentions the negative output channel.

    Args:
        value: Value to inspect (normalized with ``_text``).
        field: Field label used in the error message.

    Raises:
        BatchError: if the normalized text is, or contains, NEGATIVE_OUT_CHANNEL.
    """
    text = _text(value)
    # The exact-match case is reported separately so the message says "use"
    # rather than "mention".
    if text == NEGATIVE_OUT_CHANNEL:
        raise BatchError(f"{field} must not use {NEGATIVE_OUT_CHANNEL}")
    if NEGATIVE_OUT_CHANNEL in text:
        raise BatchError(f"{field} must not mention {NEGATIVE_OUT_CHANNEL}")


def _validate_probe(raw: Any, index: int) -> dict[str, str]:
    """Validate one raw probe entry and return its normalized form.

    Args:
        raw: The probe value taken from the batch ``probes`` list.
        index: Position of the probe in that list (used in error messages).

    Returns:
        A dict with the stripped ``id``, ``prompt_order`` and ``text``.
        ``prompt_order`` defaults to "subject_first" when missing or empty.

    Raises:
        BatchError: if the probe is not an object, carries a negative-prompt
            key, lacks an id or text, names an unknown prompt order, or its
            text mentions the negative output channel.
    """
    if not isinstance(raw, dict):
        raise BatchError(f"probes[{index}] must be an object")
    for forbidden in _NEGATIVE_KEYS:
        if forbidden in raw:
            raise BatchError(f"probes[{index}] must not contain {forbidden}")
    probe_id = _text(raw.get("id"))
    if not probe_id:
        raise BatchError(f"probes[{index}].id is required")
    prompt_order = _text(raw.get("prompt_order") or "subject_first")
    if prompt_order not in PROMPT_ORDERS:
        raise BatchError(f"probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}")
    text = _text(raw.get("text"))
    if not text:
        raise BatchError(f"probes[{index}].text is required")
    _validate_no_negative_channel(text, field=f"probes[{index}].text")
    return {"id": probe_id, "prompt_order": prompt_order, "text": text}


def _validate_image_path(value: Any, *, field: str) -> str:
    """Require *value* to be a non-empty absolute path ending in ``.png``.

    Only the path text is checked; the file system is not consulted.

    Returns:
        The stripped path text.

    Raises:
        BatchError: if the path is empty, relative, or not a ``.png`` path.
    """
    path_text = _text(value)
    if not path_text:
        raise BatchError(f"{field} is required")
    path = Path(path_text)
    if not path.is_absolute():
        raise BatchError(f"{field} must be absolute")
    if path.suffix.lower() != ".png":
        raise BatchError(f"{field} must reference a PNG artifact")
    return path_text


def load_batch(path: Path) -> dict[str, Any]:
    """Load and validate a prompt batch JSON file.

    Args:
        path: Location of the batch JSON file.

    Returns:
        A dict with ``seed``, ``channel_out``, ``channel_in`` and the list of
        normalized ``probes``. Channels fall back to the module defaults when
        missing or empty.

    Raises:
        BatchError: if the document is not an object, contains a
            negative-prompt key, has a non-integer (or boolean) seed, uses the
            negative output channel, or has an empty/invalid probe list.
    """
    batch = _load_json(path)
    for forbidden in _NEGATIVE_KEYS:
        if forbidden in batch:
            raise BatchError(f"batch must not contain {forbidden}")
    seed = batch.get("seed")
    # JSON true/false must not be accepted as a seed (bool is an int subclass).
    if not _is_int(seed):
        raise BatchError("seed must be an integer sampler seed")
    channel_out = _text(batch.get("channel_out") or DEFAULT_OUT_CHANNEL)
    channel_in = _text(batch.get("channel_in") or DEFAULT_IN_CHANNEL)
    _validate_no_negative_channel(channel_out, field="channel_out")
    _validate_no_negative_channel(channel_in, field="channel_in")
    probes_raw = batch.get("probes")
    if not isinstance(probes_raw, list) or not probes_raw:
        raise BatchError("probes must be a non-empty list")
    probes = [_validate_probe(raw, index) for index, raw in enumerate(probes_raw)]
    return {
        "seed": seed,
        "channel_out": channel_out,
        "channel_in": channel_in,
        "probes": probes,
    }


def load_results(path: Path) -> dict[str, Any]:
    """Load a (possibly partially filled) result JSON file.

    ``turn`` and ``returned_seed`` may be null here; ``validate_results``
    enforces that they are filled in. ``image_path`` is only normalized to
    text at this stage.

    Args:
        path: Location of the result JSON file.

    Returns:
        A dict with ``seed``, ``channel_in`` and the normalized ``probes``.

    Raises:
        BatchError: if the document shape is wrong, the seed is not an
            integer, the channel mentions the negative output channel, or a
            probe has a missing id, unknown prompt order, or a non-integer
            ``turn``/``returned_seed``.
    """
    data = _load_json(path)
    seed = data.get("seed")
    # Same rule as load_batch: a boolean is not a sampler seed.
    if not _is_int(seed):
        raise BatchError("result seed must be an integer sampler seed")
    channel_in = _text(data.get("channel_in") or DEFAULT_IN_CHANNEL)
    _validate_no_negative_channel(channel_in, field="channel_in")
    probes_raw = data.get("probes")
    if not isinstance(probes_raw, list) or not probes_raw:
        raise BatchError("result probes must be a non-empty list")
    probes: list[dict[str, Any]] = []
    for index, raw in enumerate(probes_raw):
        if not isinstance(raw, dict):
            raise BatchError(f"result probes[{index}] must be an object")
        probe_id = _text(raw.get("id"))
        if not probe_id:
            raise BatchError(f"result probes[{index}].id is required")
        prompt_order = _text(raw.get("prompt_order") or "subject_first")
        if prompt_order not in PROMPT_ORDERS:
            raise BatchError(
                f"result probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}"
            )
        turn = raw.get("turn")
        if turn is not None and not _is_int(turn):
            raise BatchError(f"result probes[{index}].turn must be an integer when present")
        returned_seed = raw.get("returned_seed")
        if returned_seed is not None and not _is_int(returned_seed):
            raise BatchError(
                f"result probes[{index}].returned_seed must be an integer when present"
            )
        image_path = _text(raw.get("image_path"))
        probes.append(
            {
                "id": probe_id,
                "prompt_order": prompt_order,
                "turn": turn,
                "image_path": image_path,
                "returned_seed": returned_seed,
            }
        )
    return {"seed": seed, "channel_in": channel_in, "probes": probes}


def print_push_commands(batch: dict[str, Any]) -> None:
    """Print, for each probe, a header comment plus the push and pull commands.

    The commands invoke MCP_HELPER through APPROVED_PYTHON with shell-quoted
    JSON arguments; nothing is executed here.
    """
    total = len(batch["probes"])
    for index, probe in enumerate(batch["probes"], start=1):
        prompt_order = probe["prompt_order"]
        caveat = (
            "geometry-only: pose-axis discovery; not subject/look-controlled"
            if prompt_order == "geometry_only"
            else prompt_order
        )
        print(f"# {index}/{total} {probe['id']} ({caveat})")
        push_args = {
            "channel": batch["channel_out"],
            "seed": batch["seed"],
            "text": probe["text"],
        }
        print(
            f"{APPROVED_PYTHON} {MCP_HELPER} call-tool comfy_push "
            f"--arguments-json {_json_arg(push_args)}"
        )
        pull_args = {"channel": batch["channel_in"]}
        print(
            f"{APPROVED_PYTHON} {MCP_HELPER} call-tool comfy_pull "
            f"--arguments-json {_json_arg(pull_args)}"
        )


def print_result_template(batch: dict[str, Any]) -> None:
    """Print an unfilled result JSON template mirroring the batch probes."""
    template = {
        "seed": batch["seed"],
        "channel_in": batch["channel_in"],
        "probes": [
            {
                "id": probe["id"],
                "prompt_order": probe["prompt_order"],
                "turn": None,
                "image_path": "",
                "returned_seed": None,
            }
            for probe in batch["probes"]
        ],
    }
    print(json.dumps(template, ensure_ascii=True, indent=2))


def _probe_by_id(probes: list[dict[str, Any]], probe_id: str, *, label: str) -> dict[str, Any]:
    """Return the first probe whose ``id`` equals *probe_id*.

    Raises:
        BatchError: if no probe matches; *label* prefixes the message.
    """
    for probe in probes:
        if probe.get("id") == probe_id:
            return probe
    raise BatchError(f"{label} {probe_id!r} was not found")


def validate_results(batch: dict[str, Any], results: dict[str, Any]) -> None:
    """Check that filled results are consistent with their batch.

    Verifies that the seeds match, probe ids match in order, prompt orders
    match, every probe has an integer turn, an absolute PNG image path and a
    returned seed equal to the batch seed, and that turns are unique and
    ascending in batch order.

    Raises:
        BatchError: on the first inconsistency found.
    """
    if results["seed"] != batch["seed"]:
        raise BatchError(
            f"result seed {results['seed']} does not match batch seed {batch['seed']}"
        )
    batch_probe_ids = [probe["id"] for probe in batch["probes"]]
    result_probe_ids = [probe["id"] for probe in results["probes"]]
    if result_probe_ids != batch_probe_ids:
        raise BatchError(
            "result probe ids must match batch probe ids in order: "
            f"expected {batch_probe_ids}, got {result_probe_ids}"
        )
    turns: list[int] = []
    # The id lists are equal at this point, so zip pairs probes one-to-one.
    for index, (batch_probe, result_probe) in enumerate(zip(batch["probes"], results["probes"])):
        expected_order = batch_probe["prompt_order"]
        if result_probe["prompt_order"] != expected_order:
            raise BatchError(
                f"result probes[{index}].prompt_order must match batch prompt_order {expected_order!r}"
            )
        turn = result_probe.get("turn")
        if not _is_int(turn):
            raise BatchError(f"result probes[{index}].turn is required and must be an integer")
        turns.append(turn)
        _validate_image_path(
            result_probe.get("image_path"),
            field=f"result probes[{index}].image_path",
        )
        returned_seed = result_probe.get("returned_seed")
        if returned_seed != batch["seed"]:
            raise BatchError(
                f"result probes[{index}].returned_seed must match batch seed {batch['seed']}"
            )
    if len(set(turns)) != len(turns):
        raise BatchError("result probe turns must be unique")
    if turns != sorted(turns):
        raise BatchError("result probe turns must be in batch order")


def _payload_from_mcp_response(data: Any) -> dict[str, Any]:
    """Extract the payload dict from an MCP response.

    A dict that already carries ``turn``, ``image_path`` or ``seed`` is
    returned unchanged. Otherwise the ``content`` list is scanned and the
    first item whose ``text`` parses to a JSON object is returned.

    Raises:
        BatchError: if *data* is not an object, has no ``content`` list, a
            non-empty ``text`` is not valid JSON, or no item yields an object.
    """
    if isinstance(data, dict) and ("turn" in data or "image_path" in data or "seed" in data):
        return data
    if not isinstance(data, dict):
        raise BatchError("MCP response must be an object")
    content = data.get("content")
    if not isinstance(content, list):
        raise BatchError("MCP response did not contain content")
    for item in content:
        if not isinstance(item, dict):
            continue
        text = _text(item.get("text"))
        if not text:
            continue
        try:
            payload = json.loads(text)
        except json.JSONDecodeError as exc:
            raise BatchError(f"MCP content text was not JSON: {exc}") from exc
        if isinstance(payload, dict):
            return payload
    raise BatchError("MCP response did not contain a JSON payload")


def load_mock_pulls(path: Path) -> list[dict[str, Any]]:
    """Load a JSON list of mocked pull responses and unwrap each payload."""
    return [_payload_from_mcp_response(item) for item in _load_json_list(path)]


def _call_mcp_tool(tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
    """Run the MCP helper for one tool call and return its unwrapped payload.

    The helper is spawned with APPROVED_PYTHON from ROOT, passing *arguments*
    as compact JSON.

    Raises:
        BatchError: if the helper exits non-zero, prints non-JSON output, or
            the response does not contain a payload.
    """
    result = subprocess.run(
        [
            APPROVED_PYTHON,
            MCP_HELPER,
            "call-tool",
            tool_name,
            "--arguments-json",
            json.dumps(arguments, ensure_ascii=True, separators=(",", ":")),
        ],
        cwd=ROOT,
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        # Prefer stderr, then stdout, then the bare exit code.
        detail = result.stderr.strip() or result.stdout.strip() or f"exit {result.returncode}"
        raise BatchError(f"MCP {tool_name} failed: {detail}")
    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        raise BatchError(f"MCP {tool_name} returned non-JSON output: {exc}") from exc
    return _payload_from_mcp_response(data)


def _result_probe_from_payload(
    probe: dict[str, Any],
    payload: dict[str, Any],
    *,
    expected_seed: int,
    previous_turn: int,
) -> dict[str, Any] | None:
    """Convert a pull payload into a result probe, or None if it is stale.

    Args:
        probe: The batch probe the payload is expected to answer.
        payload: Unwrapped pull payload.
        expected_seed: Seed the payload must report.
        previous_turn: Highest turn already consumed; payloads at or below it
            are treated as stale.

    Returns:
        The result probe dict, or ``None`` when ``payload["turn"]`` is not
        newer than *previous_turn*.

    Raises:
        BatchError: if the turn or seed is not an integer, the image path is
            invalid, or the returned seed differs from *expected_seed*.
    """
    turn = payload.get("turn")
    if not _is_int(turn):
        raise BatchError(f"pull result for {probe['id']} must contain an integer turn")
    # Stale payloads are skipped before any further validation.
    if turn <= previous_turn:
        return None
    image_path = _validate_image_path(
        payload.get("image_path"),
        field=f"pull result for {probe['id']}.image_path",
    )
    returned_seed = payload.get("seed")
    if not _is_int(returned_seed):
        raise BatchError(f"pull result for {probe['id']} must contain an integer seed")
    if returned_seed != expected_seed:
        raise BatchError(
            f"pull result for {probe['id']} returned seed {returned_seed}, expected {expected_seed}"
        )
    return {
        "id": probe["id"],
        "prompt_order": probe["prompt_order"],
        "turn": turn,
        "image_path": image_path,
        "returned_seed": returned_seed,
    }


def run_batch(
    batch: dict[str, Any],
    *,
    result_path: Path,
    previous_turn: int,
    max_polls: int,
    poll_interval: float,
    run_live: bool = False,
    mock_pulls: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Push each probe (live mode) and collect one fresh pull result per probe.

    Args:
        batch: Validated batch as returned by ``load_batch``.
        result_path: File the filled result JSON is written to.
        previous_turn: Pulls at or below this turn are ignored before the
            first probe; afterwards the last accepted turn is used.
        max_polls: Maximum pull attempts per probe (must be positive).
        poll_interval: Seconds slept between live pull attempts
            (must be non-negative).
        run_live: Call the MCP helper for pushes and pulls.
        mock_pulls: Pre-unwrapped payloads consumed in order instead of live
            pulls; mutually exclusive with *run_live*.

    Returns:
        The validated results dict that was written to *result_path*.

    Raises:
        BatchError: on invalid arguments, exhausted mock pulls, no fresh
            result within *max_polls*, or any payload/result validation error.
    """
    if max_polls <= 0:
        raise BatchError("max_polls must be positive")
    if poll_interval < 0:
        raise BatchError("poll_interval must be non-negative")
    if run_live and mock_pulls is not None:
        raise BatchError("--run and --mock-pulls-json cannot be used together")
    if not run_live and mock_pulls is None:
        raise BatchError("run_batch requires live mode or mock pulls")

    mock_index = 0
    result_probes: list[dict[str, Any]] = []
    current_turn = previous_turn
    for probe in batch["probes"]:
        if run_live:
            _call_mcp_tool(
                "comfy_push",
                {"channel": batch["channel_out"], "seed": batch["seed"], "text": probe["text"]},
            )
        for poll_index in range(max_polls):
            if mock_pulls is not None:
                if mock_index >= len(mock_pulls):
                    raise BatchError(f"mock pulls exhausted while waiting for {probe['id']}")
                payload = mock_pulls[mock_index]
                mock_index += 1
            else:
                payload = _call_mcp_tool("comfy_pull", {"channel": batch["channel_in"]})
            result_probe = _result_probe_from_payload(
                probe,
                payload,
                expected_seed=batch["seed"],
                previous_turn=current_turn,
            )
            if result_probe is not None:
                result_probes.append(result_probe)
                current_turn = result_probe["turn"]
                break
            # Only live polling waits, and not after the final attempt.
            if run_live and poll_index < max_polls - 1:
                time.sleep(poll_interval)
        else:
            # The poll loop finished without a break: nothing fresh arrived.
            raise BatchError(f"no new result for {probe['id']} after {max_polls} polls")

    results = {"seed": batch["seed"], "channel_in": batch["channel_in"], "probes": result_probes}
    validate_results(batch, results)
    result_path.write_text(json.dumps(results, ensure_ascii=True, indent=2) + "\n", encoding="utf-8")
    return results


def _entry_id_slug(value: str) -> str:
    """Lower-case *value*, turning each non-alphanumeric run into one hyphen.

    Leading and trailing hyphens are stripped; an empty result falls back to
    "sxcp-batch".
    """
    chars = [char.lower() if char.isalnum() else "-" for char in value]
    slug = "".join(chars).strip("-")
    while "--" in slug:
        slug = slug.replace("--", "-")
    return slug or "sxcp-batch"


def eval_entry_draft(
    batch: dict[str, Any],
    results: dict[str, Any],
    *,
    variant_key: str,
    entry_id: str,
    baseline_image: str,
    candidate_id: str,
    source: str,
    result: str,
    decision: str,
    entry_date: str,
    allow_geometry_only: bool = False,
) -> dict[str, Any]:
    """Build a draft eval entry for one candidate probe of a validated batch.

    Args:
        batch: Validated batch as returned by ``load_batch``.
        results: Results as returned by ``load_results`` or ``run_batch``.
        variant_key: Value stored under ``variant_key``; also used for the
            generated id.
        entry_id: Explicit entry id; when empty an id is generated from the
            variant key, seed and candidate id.
        baseline_image: Absolute PNG path of the baseline image.
        candidate_id: Id of the probe used as candidate.
        source: Value stored under ``source``.
        result: Value stored under ``result``.
        decision: Value stored under ``decision``.
        entry_date: Value stored under ``date``.
        allow_geometry_only: Permit drafting from a ``geometry_only`` probe.

    Returns:
        The draft entry dict with placeholder summaries and ``commit`` set to
        "pending".

    Raises:
        BatchError: if the results fail validation, the candidate is unknown,
            an image path is invalid, the returned seed mismatches, or the
            candidate is geometry-only without *allow_geometry_only*.
    """
    validate_results(batch, results)
    batch_probe = _probe_by_id(batch["probes"], candidate_id, label="candidate probe")
    result_probe = _probe_by_id(results["probes"], candidate_id, label="candidate result")
    candidate_image = _validate_image_path(
        result_probe.get("image_path"), field="candidate image_path"
    )
    baseline = _validate_image_path(baseline_image, field="baseline_image")
    prompt_order = batch_probe["prompt_order"]
    turn = result_probe.get("turn")
    returned_seed = result_probe.get("returned_seed")
    if returned_seed is not None and returned_seed != batch["seed"]:
        raise BatchError(
            f"candidate returned_seed {returned_seed} does not match batch seed {batch['seed']}"
        )
    if prompt_order == "geometry_only" and not allow_geometry_only:
        raise BatchError(
            "candidate prompt_order is geometry_only; rerun with --allow-geometry-only "
            "to draft non-controlled prompt-axis evidence"
        )
    if prompt_order == "subject_first":
        order_note = "subject/look-controlled candidate"
    elif prompt_order == "geometry_only":
        order_note = (
            "geometry-only prompt-order probe; do not treat as subject/look-controlled evidence"
        )
    else:
        order_note = "prompt-order sensitivity probe"
    entry = {
        "id": entry_id
        or f"{_entry_id_slug(variant_key)}-{batch['seed']}-{_entry_id_slug(candidate_id)}",
        "date": entry_date,
        "variant_key": variant_key,
        "seed": batch["seed"],
        "source": source,
        "result": result,
        "decision": decision,
        "baseline_prompt_summary": f"Replace with the same-seed baseline summary for {variant_key}.",
        "candidate_prompt_summary": (
            f"Batch candidate {candidate_id!r} used prompt_order={prompt_order!r}; "
            f"replace with the pose-axis change and controlled variables."
        ),
        "observation": (
            f"Replace with image comparison for candidate {candidate_id!r}"
            f"{f' on turn {turn}' if turn is not None else ''}. Prompt-order note: {order_note}."
        ),
        "baseline_image": baseline,
        "candidate_image": candidate_image,
        "commit": "pending",
    }
    return entry


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with one subparser per command."""
    parser = argparse.ArgumentParser(description=__doc__)
    subparsers = parser.add_subparsers(dest="command", required=True)

    for command in ("validate", "print-push-commands", "print-result-template"):
        subparser = subparsers.add_parser(command)
        subparser.add_argument(
            "--batch-json", required=True, help="Path to the prompt batch JSON file."
        )

    validate_results_parser = subparsers.add_parser("validate-results")
    validate_results_parser.add_argument(
        "--batch-json", required=True, help="Path to the prompt batch JSON file."
    )
    validate_results_parser.add_argument(
        "--result-json", required=True, help="Path to a filled result template JSON file."
    )

    run_parser = subparsers.add_parser("run-batch")
    run_parser.add_argument(
        "--batch-json", required=True, help="Path to the prompt batch JSON file."
    )
    run_parser.add_argument(
        "--result-json",
        required=True,
        help="Path where the filled result JSON should be written.",
    )
    run_parser.add_argument(
        "--run",
        action="store_true",
        help="Call the live MCP helper. Omit for dry-run or mock mode.",
    )
    run_parser.add_argument(
        "--mock-pulls-json",
        help="Path to a JSON list of mocked sxcp_eval_in payloads for local testing.",
    )
    run_parser.add_argument(
        "--previous-turn",
        type=int,
        default=0,
        help="Ignore pulls at or below this turn before the first probe.",
    )
    run_parser.add_argument(
        "--max-polls", type=int, default=60, help="Maximum pull attempts per probe."
    )
    run_parser.add_argument(
        "--poll-interval",
        type=float,
        default=2.0,
        help="Seconds to wait between live pull attempts.",
    )

    draft_parser = subparsers.add_parser("print-eval-entry-draft")
    draft_parser.add_argument(
        "--batch-json", required=True, help="Path to the prompt batch JSON file."
    )
    draft_parser.add_argument(
        "--result-json", required=True, help="Path to a filled result template JSON file."
    )
    draft_parser.add_argument(
        "--variant-key", required=True, help="Catalog variant key for the eval entry."
    )
    draft_parser.add_argument(
        "--entry-id", default="", help="Durable eval entry id. Defaults to a generated id."
    )
    draft_parser.add_argument(
        "--baseline-image", required=True, help="Absolute PNG path for the baseline image."
    )
    draft_parser.add_argument(
        "--candidate-id", required=True, help="Probe id to use as the candidate image."
    )
    draft_parser.add_argument(
        "--source", default="sxcp_eval_mcp_batch", help="Source label for the eval entry."
    )
    draft_parser.add_argument(
        "--result", default="inconclusive", help="Eval result. Default: inconclusive."
    )
    draft_parser.add_argument(
        "--decision",
        default="needs_more_tests",
        help="Eval decision. Default: needs_more_tests.",
    )
    draft_parser.add_argument(
        "--date", default=date.today().isoformat(), help="Eval entry date."
    )
    draft_parser.add_argument(
        "--allow-geometry-only",
        action="store_true",
        help="Allow drafting an entry from a geometry_only probe. Use only for non-controlled prompt-axis evidence.",
    )
    return parser


def main(argv: list[str] | None = None) -> int:
    """Run the command-line interface.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 on success, 1 when loading, validation or the batch run fails (the
        error is printed to stderr).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Every subcommand takes --batch-json, so the batch is loaded up front.
    # Exceptions are caught broadly here because this is the CLI boundary:
    # any failure becomes a one-line message and exit status 1.
    try:
        batch = load_batch(Path(args.batch_json))
    except Exception as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1

    if args.command == "validate":
        print(f"validated: {len(batch['probes'])} probes, seed {batch['seed']}")
        return 0

    if args.command == "print-push-commands":
        print_push_commands(batch)
        return 0

    if args.command == "print-result-template":
        print_result_template(batch)
        return 0

    if args.command == "validate-results":
        try:
            results = load_results(Path(args.result_json))
            validate_results(batch, results)
        except Exception as exc:
            print(f"error: {exc}", file=sys.stderr)
            return 1
        print(f"validated results: {len(batch['probes'])} probes, seed {batch['seed']}")
        return 0

    if args.command == "run-batch":
        try:
            # Neither --run nor mock pulls: only show what would be pushed.
            if not args.run and not args.mock_pulls_json:
                print(f"dry-run: {len(batch['probes'])} probes, seed {batch['seed']}")
                print_push_commands(batch)
                return 0
            mock_pulls = (
                load_mock_pulls(Path(args.mock_pulls_json)) if args.mock_pulls_json else None
            )
            results = run_batch(
                batch,
                result_path=Path(args.result_json),
                previous_turn=args.previous_turn,
                max_polls=args.max_polls,
                poll_interval=args.poll_interval,
                run_live=args.run,
                mock_pulls=mock_pulls,
            )
        except Exception as exc:
            print(f"error: {exc}", file=sys.stderr)
            return 1
        print(
            f"recorded results: {len(results['probes'])} probes, "
            f"seed {results['seed']} -> {args.result_json}"
        )
        return 0

    if args.command == "print-eval-entry-draft":
        try:
            results = load_results(Path(args.result_json))
            entry = eval_entry_draft(
                batch,
                results,
                variant_key=args.variant_key,
                entry_id=args.entry_id,
                baseline_image=args.baseline_image,
                candidate_id=args.candidate_id,
                source=args.source,
                result=args.result,
                decision=args.decision,
                entry_date=args.date,
                allow_geometry_only=args.allow_geometry_only,
            )
        except Exception as exc:
            print(f"error: {exc}", file=sys.stderr)
            return 1
        print(json.dumps(entry, ensure_ascii=True, indent=2))
        return 0

    # parser.error() exits the process; the return only guards against a
    # patched parser that does not.
    parser.error(f"unknown command: {args.command}")
    return 2


if __name__ == "__main__":
    raise SystemExit(main())