Files
ComfyUI-Ethanfel-Prompt-Bui…/tools/sxcp_prompt_batch.py

571 lines
23 KiB
Python

#!/usr/bin/env python3
"""Prepare repeatable SxCP prompt-probe batches and record their results.

Most subcommands work offline: they validate a batch JSON file, print the
`tools/sxcp_mcp_client.py` commands to run by hand, emit or check a result
template, or draft an eval entry. Only `run-batch --run` calls the live MCP
bridge, through that same helper. Keeping the batch plan and image-presence
checklist deterministic makes prompt-axis probes less dependent on
hand-copied commands.
"""
from __future__ import annotations
import argparse
from datetime import date
import json
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Any
# Repository root: one level above the directory containing this script.
ROOT = Path(__file__).resolve().parents[1]

# Interpreter and helper script used for every MCP bridge call (run from ROOT).
APPROVED_PYTHON = "/media/p5/miniforge3/bin/python"
MCP_HELPER = "tools/sxcp_mcp_client.py"

# Bridge channels. Batches may override the in/out channels, but never with
# (or mentioning) the negative output channel.
DEFAULT_OUT_CHANNEL = "sxcp_eval_out"
DEFAULT_IN_CHANNEL = "sxcp_eval_in"
NEGATIVE_OUT_CHANNEL = "sxcp_eval_negative_out"

# Allowed values for a probe's prompt_order ("subject_first" is the default).
PROMPT_ORDERS = {"geometry_only", "prompt_order_test", "subject_first"}

# Optional per-probe keys copied verbatim from the batch JSON into each probe.
PROBE_METADATA_FIELDS = (
    "variant_key",
    "source_entry_id",
    "source_stem",
    "cue_axes",
    "seed_metadata",
    "evidence",
    "matrix_evidence",
    "selection",
    "prompt_source",
    "reference_images",
    "notes",
)

# Optional batch-level keys copied verbatim into the loaded batch.
BATCH_METADATA_FIELDS = (
    "subject_id",
    "variant_key",
    "source_entry_id",
    "source_stem",
    "source_prompt_sha256",
    "selection",
)
class BatchError(ValueError):
    """Validation or bridge failure raised by the batch helpers in this module."""
def _load_json(path: Path) -> dict[str, Any]:
with path.open("r", encoding="utf-8") as handle:
data = json.load(handle)
if not isinstance(data, dict):
raise BatchError("batch JSON must contain one object")
return data
def _load_json_list(path: Path) -> list[Any]:
with path.open("r", encoding="utf-8") as handle:
data = json.load(handle)
if not isinstance(data, list):
raise BatchError("mock pulls JSON must contain a list")
return data
def _json_arg(value: dict[str, Any]) -> str:
return shlex.quote(json.dumps(value, ensure_ascii=True, separators=(",", ":")))
def _text(value: Any) -> str:
return "" if value is None else str(value).strip()
def _validate_no_negative_channel(value: Any, *, field: str) -> None:
    """Reject a value that equals, or merely contains, the negative output channel name.

    An exact match is reported as "must not use" and a substring match as
    "must not mention", so a misconfigured channel is distinguishable from text
    that happens to reference the channel.

    Raises:
        BatchError: if ``NEGATIVE_OUT_CHANNEL`` appears in the stripped value.
    """
    normalized = _text(value)
    if NEGATIVE_OUT_CHANNEL not in normalized:
        return
    verb = "use" if normalized == NEGATIVE_OUT_CHANNEL else "mention"
    raise BatchError(f"{field} must not {verb} {NEGATIVE_OUT_CHANNEL}")
def _validate_probe(raw: Any, index: int) -> dict[str, Any]:
    """Validate one element of a batch's ``probes`` list and return a normalized copy.

    The result always holds ``id``, ``prompt_order`` (default ``"subject_first"``)
    and ``text`` as stripped strings. It also holds every key from
    ``PROBE_METADATA_FIELDS`` that is present in *raw*, with its value copied
    unchanged, so values are not limited to strings.

    Args:
        raw: Decoded JSON value of ``probes[index]``.
        index: Position of the probe; used only in error messages.

    Raises:
        BatchError: if *raw* is not an object, carries a negative-prompt key, has an
            empty id or text, uses an unknown prompt_order, or its text mentions
            the negative output channel.
    """
    if not isinstance(raw, dict):
        raise BatchError(f"probes[{index}] must be an object")
    # Negative prompts are out of scope for these probes; refuse them outright.
    for forbidden in ("negative", "negative_prompt", "negative_text", "negative_channel"):
        if forbidden in raw:
            raise BatchError(f"probes[{index}] must not contain {forbidden}")
    probe_id = _text(raw.get("id"))
    if not probe_id:
        raise BatchError(f"probes[{index}].id is required")
    prompt_order = _text(raw.get("prompt_order") or "subject_first")
    if prompt_order not in PROMPT_ORDERS:
        raise BatchError(f"probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}")
    text = _text(raw.get("text"))
    if not text:
        raise BatchError(f"probes[{index}].text is required")
    _validate_no_negative_channel(text, field=f"probes[{index}].text")
    probe: dict[str, Any] = {"id": probe_id, "prompt_order": prompt_order, "text": text}
    # Metadata is passed through as-is (not validated or deep-copied).
    probe.update({field: raw[field] for field in PROBE_METADATA_FIELDS if field in raw})
    return probe
def _validate_image_path(value: Any, *, field: str) -> str:
    """Return *value* as stripped text once it is known to be an absolute ``.png`` path.

    Only the string form is checked; the file is neither opened nor required to exist.

    Raises:
        BatchError: if the path is empty, relative, or lacks a ``.png`` suffix
            (compared case-insensitively).
    """
    path_str = _text(value)
    if path_str:
        candidate = Path(path_str)
        if not candidate.is_absolute():
            problem = "must be absolute"
        elif candidate.suffix.lower() != ".png":
            problem = "must reference a PNG artifact"
        else:
            return path_str
    else:
        problem = "is required"
    raise BatchError(f"{field} {problem}")
def load_batch(path: Path) -> dict[str, Any]:
    """Load and validate a prompt batch JSON file.

    Returns:
        A dict with ``seed`` (int), ``channel_out`` / ``channel_in`` (defaulting to
        ``DEFAULT_OUT_CHANNEL`` / ``DEFAULT_IN_CHANNEL``), ``probes`` (normalized
        by ``_validate_probe``), plus any ``BATCH_METADATA_FIELDS`` keys present
        in the file, copied unchanged.

    Raises:
        BatchError: if the file is not a JSON object, contains a negative-prompt
            key, has a non-integer (or boolean) seed, targets the negative output
            channel, or has no probes or an invalid probe.
    """
    batch = _load_json(path)
    for forbidden in ("negative", "negative_prompt", "negative_text", "negative_channel"):
        if forbidden in batch:
            raise BatchError(f"batch must not contain {forbidden}")
    seed = batch.get("seed")
    # bool subclasses int; reject it so `"seed": true` is not silently read as seed 1.
    if not isinstance(seed, int) or isinstance(seed, bool):
        raise BatchError("seed must be an integer sampler seed")
    channel_out = _text(batch.get("channel_out") or DEFAULT_OUT_CHANNEL)
    channel_in = _text(batch.get("channel_in") or DEFAULT_IN_CHANNEL)
    _validate_no_negative_channel(channel_out, field="channel_out")
    _validate_no_negative_channel(channel_in, field="channel_in")
    probes_raw = batch.get("probes")
    if not isinstance(probes_raw, list) or not probes_raw:
        raise BatchError("probes must be a non-empty list")
    probes = [_validate_probe(raw, index) for index, raw in enumerate(probes_raw)]
    loaded: dict[str, Any] = {
        "seed": seed,
        "channel_out": channel_out,
        "channel_in": channel_in,
        "probes": probes,
    }
    for field in BATCH_METADATA_FIELDS:
        if field in batch:
            loaded[field] = batch[field]
    return loaded
def load_results(path: Path) -> dict[str, Any]:
    """Load a filled result JSON file and normalize it for ``validate_results``.

    Only structural checks are done here:
    - the seed is an integer;
    - ``channel_in`` (default ``DEFAULT_IN_CHANNEL``) is not the negative output channel;
    - ``probes`` is a non-empty list;
    - each probe has an id, a known prompt_order (default ``"subject_first"``), and
      optional integer ``turn`` / ``returned_seed`` values.

    ``image_path`` is only stripped here, not validated.

    Returns:
        ``{"seed", "channel_in", "probes"}``, where each probe has exactly the keys
        ``id``, ``prompt_order``, ``turn``, ``image_path`` and ``returned_seed``.

    Raises:
        BatchError: on any of the structural violations above.
    """
    data = _load_json(path)
    seed = data.get("seed")
    # bool subclasses int; reject it here just as turn/returned_seed are below.
    if not isinstance(seed, int) or isinstance(seed, bool):
        raise BatchError("result seed must be an integer sampler seed")
    channel_in = _text(data.get("channel_in") or DEFAULT_IN_CHANNEL)
    _validate_no_negative_channel(channel_in, field="channel_in")
    probes_raw = data.get("probes")
    if not isinstance(probes_raw, list) or not probes_raw:
        raise BatchError("result probes must be a non-empty list")
    probes: list[dict[str, Any]] = []
    for index, raw in enumerate(probes_raw):
        if not isinstance(raw, dict):
            raise BatchError(f"result probes[{index}] must be an object")
        probe_id = _text(raw.get("id"))
        if not probe_id:
            raise BatchError(f"result probes[{index}].id is required")
        prompt_order = _text(raw.get("prompt_order") or "subject_first")
        if prompt_order not in PROMPT_ORDERS:
            raise BatchError(f"result probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}")
        turn = raw.get("turn")
        if turn is not None and (not isinstance(turn, int) or isinstance(turn, bool)):
            raise BatchError(f"result probes[{index}].turn must be an integer when present")
        returned_seed = raw.get("returned_seed")
        if returned_seed is not None and (not isinstance(returned_seed, int) or isinstance(returned_seed, bool)):
            raise BatchError(f"result probes[{index}].returned_seed must be an integer when present")
        probes.append(
            {
                "id": probe_id,
                "prompt_order": prompt_order,
                "turn": turn,
                "image_path": _text(raw.get("image_path")),
                "returned_seed": returned_seed,
            }
        )
    return {"seed": seed, "channel_in": channel_in, "probes": probes}
def print_push_commands(batch: dict[str, Any]) -> None:
    """Print copy-pasteable MCP helper commands for running *batch* by hand.

    For each probe this prints:
    - a ``# i/N id (note)`` header line, where geometry_only probes get an
      explicit caveat as the note;
    - a ``comfy_push`` call with the batch channel_out, seed and the probe text;
    - a ``comfy_pull`` call on the batch channel_in.

    Nothing is executed.
    """
    helper = f"{APPROVED_PYTHON} {MCP_HELPER} call-tool"
    probes = batch["probes"]
    for number, probe in enumerate(probes, start=1):
        order = probe["prompt_order"]
        if order == "geometry_only":
            note = "geometry-only: pose-axis discovery; not subject/look-controlled"
        else:
            note = order
        print(f"# {number}/{len(probes)} {probe['id']} ({note})")
        push_payload = {"channel": batch["channel_out"], "seed": batch["seed"], "text": probe["text"]}
        print(f"{helper} comfy_push --arguments-json {_json_arg(push_payload)}")
        pull_payload = {"channel": batch["channel_in"]}
        print(f"{helper} comfy_pull --arguments-json {_json_arg(pull_payload)}")
def print_result_template(batch: dict[str, Any]) -> None:
    """Print an unfilled result JSON skeleton for *batch* (2-space indented, ASCII-only).

    Each probe keeps its ``id`` and ``prompt_order``. ``turn`` and
    ``returned_seed`` start as ``null``, and ``image_path`` starts as an empty
    string.
    """

    def blank_row(probe: dict[str, Any]) -> dict[str, Any]:
        return {
            "id": probe["id"],
            "prompt_order": probe["prompt_order"],
            "turn": None,
            "image_path": "",
            "returned_seed": None,
        }

    skeleton = {
        "seed": batch["seed"],
        "channel_in": batch["channel_in"],
        "probes": list(map(blank_row, batch["probes"])),
    }
    print(json.dumps(skeleton, ensure_ascii=True, indent=2))
def _probe_by_id(probes: list[dict[str, Any]], probe_id: str, *, label: str) -> dict[str, Any]:
for probe in probes:
if probe.get("id") == probe_id:
return probe
raise BatchError(f"{label} {probe_id!r} was not found")
def validate_results(batch: dict[str, Any], results: dict[str, Any]) -> None:
    """Verify that *results* is a complete, ordered record of one run of *batch*.

    Checks run in this order:
    1. The result seed equals the batch seed.
    2. The result probe ids equal the batch probe ids, in the same order.
    3. For each probe: prompt_order matches, ``turn`` is a non-bool int,
       ``image_path`` is an absolute ``.png`` path, and ``returned_seed`` equals
       the batch seed.
    4. All turns are distinct and appear in ascending order.

    Raises:
        BatchError: describing the first failed check.
    """
    if results["seed"] != batch["seed"]:
        raise BatchError(f"result seed {results['seed']} does not match batch seed {batch['seed']}")
    planned = batch["probes"]
    recorded = results["probes"]
    expected_ids = [entry["id"] for entry in planned]
    actual_ids = [entry["id"] for entry in recorded]
    if actual_ids != expected_ids:
        raise BatchError(
            "result probe ids must match batch probe ids in order: "
            f"expected {expected_ids}, got {actual_ids}"
        )
    observed_turns: list[int] = []
    for index, (plan, record) in enumerate(zip(planned, recorded)):
        wanted_order = plan["prompt_order"]
        if record["prompt_order"] != wanted_order:
            raise BatchError(f"result probes[{index}].prompt_order must match batch prompt_order {wanted_order!r}")
        turn = record.get("turn")
        if isinstance(turn, bool) or not isinstance(turn, int):
            raise BatchError(f"result probes[{index}].turn is required and must be an integer")
        observed_turns.append(turn)
        _validate_image_path(record.get("image_path"), field=f"result probes[{index}].image_path")
        if record.get("returned_seed") != batch["seed"]:
            raise BatchError(f"result probes[{index}].returned_seed must match batch seed {batch['seed']}")
    if len(observed_turns) != len(set(observed_turns)):
        raise BatchError("result probe turns must be unique")
    if observed_turns != sorted(observed_turns):
        raise BatchError("result probe turns must be in batch order")
def _payload_from_mcp_response(data: Any) -> dict[str, Any]:
if isinstance(data, dict) and ("turn" in data or "image_path" in data or "seed" in data):
return data
if not isinstance(data, dict):
raise BatchError("MCP response must be an object")
content = data.get("content")
if not isinstance(content, list):
raise BatchError("MCP response did not contain content")
for item in content:
if not isinstance(item, dict):
continue
text = _text(item.get("text"))
if not text:
continue
try:
payload = json.loads(text)
except json.JSONDecodeError as exc:
raise BatchError(f"MCP content text was not JSON: {exc}") from exc
if isinstance(payload, dict):
return payload
raise BatchError("MCP response did not contain a JSON payload")
def load_mock_pulls(path: Path) -> list[dict[str, Any]]:
    """Load a JSON list of recorded pull responses, normalizing each to a payload dict.

    Elements may be bare payloads or full MCP tool results; each is passed through
    ``_payload_from_mcp_response``.
    """
    recorded = _load_json_list(path)
    return list(map(_payload_from_mcp_response, recorded))
def _call_mcp_tool(
    tool_name: str,
    arguments: dict[str, Any],
    *,
    timeout: float | None = None,
) -> dict[str, Any]:
    """Run ``MCP_HELPER call-tool <tool_name>`` with ``APPROVED_PYTHON`` and return its payload.

    The helper runs with ``ROOT`` as its working directory. *arguments* are passed
    as compact ASCII JSON. The helper's stdout must be one JSON document, which is
    unwrapped with ``_payload_from_mcp_response``.

    Args:
        tool_name: MCP tool to invoke (this module uses ``"comfy_push"`` and
            ``"comfy_pull"``).
        arguments: Tool arguments.
        timeout: Optional limit in seconds for the helper process. ``None`` (the
            default) waits indefinitely, matching the previous behaviour.

    Raises:
        BatchError: if the helper cannot be started, exceeds *timeout*, exits
            non-zero, or prints non-JSON output.
    """
    command = [
        APPROVED_PYTHON,
        MCP_HELPER,
        "call-tool",
        tool_name,
        "--arguments-json",
        json.dumps(arguments, ensure_ascii=True, separators=(",", ":")),
    ]
    try:
        result = subprocess.run(
            command,
            cwd=ROOT,
            capture_output=True,
            text=True,
            check=False,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # subprocess.run kills the child before re-raising TimeoutExpired.
        raise BatchError(f"MCP {tool_name} timed out after {timeout} seconds") from exc
    except OSError as exc:
        # e.g. APPROVED_PYTHON does not exist on this machine.
        raise BatchError(f"MCP {tool_name} could not be started: {exc}") from exc
    if result.returncode != 0:
        detail = result.stderr.strip() or result.stdout.strip() or f"exit {result.returncode}"
        raise BatchError(f"MCP {tool_name} failed: {detail}")
    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        raise BatchError(f"MCP {tool_name} returned non-JSON output: {exc}") from exc
    return _payload_from_mcp_response(data)
def _result_probe_from_payload(
probe: dict[str, Any],
payload: dict[str, Any],
*,
expected_seed: int,
previous_turn: int,
) -> dict[str, Any] | None:
turn = payload.get("turn")
if not isinstance(turn, int) or isinstance(turn, bool):
raise BatchError(f"pull result for {probe['id']} must contain an integer turn")
if turn <= previous_turn:
return None
image_path = _validate_image_path(payload.get("image_path"), field=f"pull result for {probe['id']}.image_path")
returned_seed = payload.get("seed")
if not isinstance(returned_seed, int) or isinstance(returned_seed, bool):
raise BatchError(f"pull result for {probe['id']} must contain an integer seed")
if returned_seed != expected_seed:
raise BatchError(f"pull result for {probe['id']} returned seed {returned_seed}, expected {expected_seed}")
return {
"id": probe["id"],
"prompt_order": probe["prompt_order"],
"turn": turn,
"image_path": image_path,
"returned_seed": returned_seed,
}
def run_batch(
    batch: dict[str, Any],
    *,
    result_path: Path,
    previous_turn: int,
    max_polls: int,
    poll_interval: float,
    run_live: bool = False,
    mock_pulls: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Run every probe in *batch*, collect one new result per probe, and save the results.

    Exactly one payload source is used:
    - the live MCP helper (``run_live=True``), which also pushes each probe first;
    - the pre-recorded *mock_pulls* list, consumed in order across all probes.

    For each probe up to *max_polls* payloads are examined. A payload whose turn
    is not above the last accepted turn is skipped as stale; the first accepted
    turn must exceed *previous_turn*. In live mode the loop sleeps
    *poll_interval* seconds between attempts, but not after the final one.

    Returns:
        The validated results dict, also written to *result_path* as indented JSON.

    Raises:
        BatchError: for invalid arguments, exhausted mock pulls, a probe with no
            new result after *max_polls* polls, or failed result validation.
    """
    if max_polls <= 0:
        raise BatchError("max_polls must be positive")
    if poll_interval < 0:
        raise BatchError("poll_interval must be non-negative")
    if run_live and mock_pulls is not None:
        raise BatchError("--run and --mock-pulls-json cannot be used together")
    if not run_live and mock_pulls is None:
        raise BatchError("run_batch requires live mode or mock pulls")

    mock_cursor = 0
    last_turn = previous_turn
    recorded: list[dict[str, Any]] = []

    for probe in batch["probes"]:
        if run_live:
            push_arguments = {"channel": batch["channel_out"], "seed": batch["seed"], "text": probe["text"]}
            _call_mcp_tool("comfy_push", push_arguments)
        accepted = None
        for attempt in range(1, max_polls + 1):
            if mock_pulls is None:
                payload = _call_mcp_tool("comfy_pull", {"channel": batch["channel_in"]})
            else:
                if mock_cursor >= len(mock_pulls):
                    raise BatchError(f"mock pulls exhausted while waiting for {probe['id']}")
                payload = mock_pulls[mock_cursor]
                mock_cursor += 1
            accepted = _result_probe_from_payload(
                probe,
                payload,
                expected_seed=batch["seed"],
                previous_turn=last_turn,
            )
            if accepted is not None:
                break
            # Only live polling waits; there is no point sleeping after the last try.
            if run_live and attempt < max_polls:
                time.sleep(poll_interval)
        if accepted is None:
            raise BatchError(f"no new result for {probe['id']} after {max_polls} polls")
        recorded.append(accepted)
        last_turn = accepted["turn"]

    results = {"seed": batch["seed"], "channel_in": batch["channel_in"], "probes": recorded}
    validate_results(batch, results)
    serialized = json.dumps(results, ensure_ascii=True, indent=2)
    result_path.write_text(serialized + "\n", encoding="utf-8")
    return results
def _entry_id_slug(value: str) -> str:
chars = [char.lower() if char.isalnum() else "-" for char in value]
slug = "".join(chars).strip("-")
while "--" in slug:
slug = slug.replace("--", "-")
return slug or "sxcp-batch"
def eval_entry_draft(
    batch: dict[str, Any],
    results: dict[str, Any],
    *,
    variant_key: str,
    entry_id: str,
    baseline_image: str,
    candidate_id: str,
    source: str,
    result: str,
    decision: str,
    entry_date: str,
    allow_geometry_only: bool = False,
) -> dict[str, Any]:
    """Draft an eval entry comparing *baseline_image* with one batch probe's image.

    Steps:
    - *results* is validated against *batch* first.
    - The candidate is looked up by *candidate_id* in both.
    - The candidate's image and *baseline_image* must be absolute ``.png`` paths.
    - geometry_only candidates are refused unless *allow_geometry_only* is set.

    The summary and observation fields are placeholders meant to be replaced by
    hand, and ``commit`` is ``"pending"``. An empty *entry_id* is replaced by
    ``<variant-slug>-<seed>-<candidate-slug>``.

    Raises:
        BatchError: from result validation, candidate lookup, image-path checks,
            a seed mismatch, or a disallowed geometry_only candidate.
    """
    validate_results(batch, results)
    planned = _probe_by_id(batch["probes"], candidate_id, label="candidate probe")
    recorded = _probe_by_id(results["probes"], candidate_id, label="candidate result")
    candidate_image = _validate_image_path(recorded.get("image_path"), field="candidate image_path")
    baseline = _validate_image_path(baseline_image, field="baseline_image")
    prompt_order = planned["prompt_order"]
    turn = recorded.get("turn")
    seed = batch["seed"]
    returned_seed = recorded.get("returned_seed")
    # validate_results already requires this; kept as a defensive double-check.
    if returned_seed is not None and returned_seed != seed:
        raise BatchError(f"candidate returned_seed {returned_seed} does not match batch seed {seed}")
    if prompt_order == "geometry_only" and not allow_geometry_only:
        raise BatchError("candidate prompt_order is geometry_only; rerun with --allow-geometry-only to draft non-controlled prompt-axis evidence")

    notes_by_order = {
        "subject_first": "subject/look-controlled candidate",
        "geometry_only": "geometry-only prompt-order probe; do not treat as subject/look-controlled evidence",
    }
    order_note = notes_by_order.get(prompt_order, "prompt-order sensitivity probe")
    turn_suffix = "" if turn is None else f" on turn {turn}"
    resolved_id = entry_id or f"{_entry_id_slug(variant_key)}-{seed}-{_entry_id_slug(candidate_id)}"

    return {
        "id": resolved_id,
        "date": entry_date,
        "variant_key": variant_key,
        "seed": seed,
        "source": source,
        "result": result,
        "decision": decision,
        "baseline_prompt_summary": f"Replace with the same-seed baseline summary for {variant_key}.",
        "candidate_prompt_summary": (
            f"Batch candidate {candidate_id!r} used prompt_order={prompt_order!r}; "
            "replace with the pose-axis change and controlled variables."
        ),
        "observation": (
            f"Replace with image comparison for candidate {candidate_id!r}"
            f"{turn_suffix}. Prompt-order note: {order_note}."
        ),
        "baseline_image": baseline,
        "candidate_image": candidate_image,
        "commit": "pending",
    }
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    Every subcommand requires ``--batch-json``. The ``--date`` default is today's
    ISO date, computed when the parser is built.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    subparsers = parser.add_subparsers(dest="command", required=True)
    batch_help = "Path to the prompt batch JSON file."
    filled_help = "Path to a filled result template JSON file."

    def add_command(name: str) -> argparse.ArgumentParser:
        # Register a subcommand that carries the shared --batch-json option.
        sub = subparsers.add_parser(name)
        sub.add_argument("--batch-json", required=True, help=batch_help)
        return sub

    for name in ("validate", "print-push-commands", "print-result-template"):
        add_command(name)

    add_command("validate-results").add_argument("--result-json", required=True, help=filled_help)

    run_cmd = add_command("run-batch")
    run_cmd.add_argument("--result-json", required=True, help="Path where the filled result JSON should be written.")
    run_cmd.add_argument("--run", action="store_true", help="Call the live MCP helper. Omit for dry-run or mock mode.")
    run_cmd.add_argument("--mock-pulls-json", help="Path to a JSON list of mocked sxcp_eval_in payloads for local testing.")
    run_cmd.add_argument("--previous-turn", type=int, default=0, help="Ignore pulls at or below this turn before the first probe.")
    run_cmd.add_argument("--max-polls", type=int, default=60, help="Maximum pull attempts per probe.")
    run_cmd.add_argument("--poll-interval", type=float, default=2.0, help="Seconds to wait between live pull attempts.")

    draft_cmd = add_command("print-eval-entry-draft")
    draft_cmd.add_argument("--result-json", required=True, help=filled_help)
    draft_cmd.add_argument("--variant-key", required=True, help="Catalog variant key for the eval entry.")
    draft_cmd.add_argument("--entry-id", default="", help="Durable eval entry id. Defaults to a generated id.")
    draft_cmd.add_argument("--baseline-image", required=True, help="Absolute PNG path for the baseline image.")
    draft_cmd.add_argument("--candidate-id", required=True, help="Probe id to use as the candidate image.")
    draft_cmd.add_argument("--source", default="sxcp_eval_mcp_batch", help="Source label for the eval entry.")
    draft_cmd.add_argument("--result", default="inconclusive", help="Eval result. Default: inconclusive.")
    draft_cmd.add_argument("--decision", default="needs_more_tests", help="Eval decision. Default: needs_more_tests.")
    draft_cmd.add_argument("--date", default=date.today().isoformat(), help="Eval entry date.")
    draft_cmd.add_argument(
        "--allow-geometry-only",
        action="store_true",
        help="Allow drafting an entry from a geometry_only probe. Use only for non-controlled prompt-axis evidence.",
    )
    return parser
def main(argv: list[str] | None = None) -> int:
    """Run the CLI and return the process exit status.

    Every subcommand first loads ``--batch-json`` with ``load_batch``. Handled
    failures are printed as ``error: <message>`` on stderr and return 1; success
    returns 0.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    def fail(exc: Exception) -> int:
        # Report the error without a traceback and signal failure.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    try:
        batch = load_batch(Path(args.batch_json))
    except Exception as exc:
        return fail(exc)

    command = args.command
    summary = f"{len(batch['probes'])} probes, seed {batch['seed']}"
    printers = {
        "print-push-commands": print_push_commands,
        "print-result-template": print_result_template,
    }

    if command == "validate":
        print(f"validated: {summary}")
        return 0

    if command in printers:
        printers[command](batch)
        return 0

    if command == "validate-results":
        try:
            validate_results(batch, load_results(Path(args.result_json)))
        except Exception as exc:
            return fail(exc)
        print(f"validated results: {summary}")
        return 0

    if command == "run-batch":
        try:
            if not args.run and not args.mock_pulls_json:
                # Neither live nor mock mode: just show what would be run.
                print(f"dry-run: {summary}")
                print_push_commands(batch)
                return 0
            mock_pulls = None
            if args.mock_pulls_json:
                mock_pulls = load_mock_pulls(Path(args.mock_pulls_json))
            results = run_batch(
                batch,
                result_path=Path(args.result_json),
                previous_turn=args.previous_turn,
                max_polls=args.max_polls,
                poll_interval=args.poll_interval,
                run_live=args.run,
                mock_pulls=mock_pulls,
            )
        except Exception as exc:
            return fail(exc)
        print(f"recorded results: {len(results['probes'])} probes, seed {results['seed']} -> {args.result_json}")
        return 0

    if command == "print-eval-entry-draft":
        try:
            entry = eval_entry_draft(
                batch,
                load_results(Path(args.result_json)),
                variant_key=args.variant_key,
                entry_id=args.entry_id,
                baseline_image=args.baseline_image,
                candidate_id=args.candidate_id,
                source=args.source,
                result=args.result,
                decision=args.decision,
                entry_date=args.date,
                allow_geometry_only=args.allow_geometry_only,
            )
        except Exception as exc:
            return fail(exc)
        print(json.dumps(entry, ensure_ascii=True, indent=2))
        return 0

    # parser.error exits the process; the return is a fallback for static analysis.
    parser.error(f"unknown command: {args.command}")
    return 2


if __name__ == "__main__":
    raise SystemExit(main())