Add Krea2 POV routing and eval tooling
This commit is contained in:
@@ -30,6 +30,7 @@ def main() -> int:
|
||||
parser.add_argument("--print-template", action="store_true", help="Print a valid eval entry template instead of recording.")
|
||||
parser.add_argument("--variant-key", help="Catalog variant key for --print-template.")
|
||||
parser.add_argument("--seed", type=int, help="Fixed seed for --print-template.")
|
||||
parser.add_argument("--generator-seed", type=int, help="Optional SxCP generator/control seed for --print-template.")
|
||||
parser.add_argument("--source", default="sxcp_eval_mcp", help="Source label for --print-template.")
|
||||
parser.add_argument("--date", default=date.today().isoformat(), help="Date for --print-template.")
|
||||
parser.add_argument("--log-path", default=str(krea2_eval_log.DEFAULT_EVAL_LOG_PATH), help="Eval log path to update.")
|
||||
@@ -43,6 +44,7 @@ def main() -> int:
|
||||
entry = krea2_eval_log.entry_template(
|
||||
args.variant_key,
|
||||
seed=args.seed,
|
||||
generator_seed=args.generator_seed,
|
||||
source=args.source,
|
||||
date=args.date,
|
||||
)
|
||||
|
||||
+1653
-240
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Small CLI for one-off SxCP MCP bridge calls.
|
||||
|
||||
The repository smoke tests run with the system Python, so MCP dependencies are
|
||||
imported only after a network subcommand is selected. For live bridge calls, run
|
||||
this with the Python environment that has the `mcp` package installed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
DEFAULT_BRIDGE_URL = "http://192.168.1.12:9188/mcp"
|
||||
|
||||
|
||||
def _json_loads(value: str) -> dict[str, Any]:
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise argparse.ArgumentTypeError(str(exc)) from exc
|
||||
if not isinstance(parsed, dict):
|
||||
raise argparse.ArgumentTypeError("arguments JSON must decode to an object")
|
||||
return parsed
|
||||
|
||||
|
||||
def _json_default(value: Any) -> Any:
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump(mode="json")
|
||||
if hasattr(value, "__dict__"):
|
||||
return value.__dict__
|
||||
return str(value)
|
||||
|
||||
|
||||
async def _list_tools(bridge_url: str) -> int:
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
async with streamablehttp_client(bridge_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
tools = (await session.list_tools()).tools
|
||||
for tool in tools:
|
||||
print(tool.name)
|
||||
return 0
|
||||
|
||||
|
||||
async def _call_tool(bridge_url: str, tool_name: str, arguments: dict[str, Any]) -> int:
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
async with streamablehttp_client(bridge_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
result = await session.call_tool(tool_name, arguments)
|
||||
print(json.dumps(result, ensure_ascii=True, indent=2, default=_json_default))
|
||||
return 0
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--bridge-url",
|
||||
default=DEFAULT_BRIDGE_URL,
|
||||
help=f"MCP bridge URL. Default: {DEFAULT_BRIDGE_URL}",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
subparsers.add_parser("list-tools", help="List available MCP tool names.")
|
||||
|
||||
call_parser = subparsers.add_parser("call-tool", help="Call one MCP tool.")
|
||||
call_parser.add_argument("tool_name", help="Tool name to call.")
|
||||
call_parser.add_argument(
|
||||
"--arguments-json",
|
||||
type=_json_loads,
|
||||
default={},
|
||||
help="JSON object with tool arguments.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
try:
|
||||
import anyio
|
||||
except ImportError as exc:
|
||||
raise SystemExit(
|
||||
"sxcp_mcp_client requires the MCP Python environment for network calls; "
|
||||
"try /media/p5/miniforge3/bin/python."
|
||||
) from exc
|
||||
|
||||
if args.command == "list-tools":
|
||||
return anyio.run(_list_tools, args.bridge_url)
|
||||
if args.command == "call-tool":
|
||||
return anyio.run(_call_tool, args.bridge_url, args.tool_name, args.arguments_json)
|
||||
parser.error(f"unknown command: {args.command}")
|
||||
return 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,541 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare repeatable SxCP prompt-probe batches without opening the MCP bridge.
|
||||
|
||||
The live bridge call remains `tools/sxcp_mcp_client.py`. This helper keeps the
|
||||
batch plan and image-presence checklist deterministic so prompt-axis probes are
|
||||
less dependent on hand-copied commands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import date
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
APPROVED_PYTHON = "/media/p5/miniforge3/bin/python"
|
||||
MCP_HELPER = "tools/sxcp_mcp_client.py"
|
||||
DEFAULT_OUT_CHANNEL = "sxcp_eval_out"
|
||||
DEFAULT_IN_CHANNEL = "sxcp_eval_in"
|
||||
NEGATIVE_OUT_CHANNEL = "sxcp_eval_negative_out"
|
||||
PROMPT_ORDERS = {"subject_first", "geometry_only", "prompt_order_test"}
|
||||
|
||||
|
||||
class BatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _load_json(path: Path) -> dict[str, Any]:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
if not isinstance(data, dict):
|
||||
raise BatchError("batch JSON must contain one object")
|
||||
return data
|
||||
|
||||
|
||||
def _load_json_list(path: Path) -> list[Any]:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
if not isinstance(data, list):
|
||||
raise BatchError("mock pulls JSON must contain a list")
|
||||
return data
|
||||
|
||||
|
||||
def _json_arg(value: dict[str, Any]) -> str:
|
||||
return shlex.quote(json.dumps(value, ensure_ascii=True, separators=(",", ":")))
|
||||
|
||||
|
||||
def _text(value: Any) -> str:
|
||||
return "" if value is None else str(value).strip()
|
||||
|
||||
|
||||
def _validate_no_negative_channel(value: Any, *, field: str) -> None:
|
||||
text = _text(value)
|
||||
if text == NEGATIVE_OUT_CHANNEL:
|
||||
raise BatchError(f"{field} must not use {NEGATIVE_OUT_CHANNEL}")
|
||||
if NEGATIVE_OUT_CHANNEL in text:
|
||||
raise BatchError(f"{field} must not mention {NEGATIVE_OUT_CHANNEL}")
|
||||
|
||||
|
||||
def _validate_probe(raw: Any, index: int) -> dict[str, str]:
|
||||
if not isinstance(raw, dict):
|
||||
raise BatchError(f"probes[{index}] must be an object")
|
||||
for forbidden in ("negative", "negative_prompt", "negative_text", "negative_channel"):
|
||||
if forbidden in raw:
|
||||
raise BatchError(f"probes[{index}] must not contain {forbidden}")
|
||||
probe_id = _text(raw.get("id"))
|
||||
if not probe_id:
|
||||
raise BatchError(f"probes[{index}].id is required")
|
||||
prompt_order = _text(raw.get("prompt_order") or "subject_first")
|
||||
if prompt_order not in PROMPT_ORDERS:
|
||||
raise BatchError(f"probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}")
|
||||
text = _text(raw.get("text"))
|
||||
if not text:
|
||||
raise BatchError(f"probes[{index}].text is required")
|
||||
_validate_no_negative_channel(text, field=f"probes[{index}].text")
|
||||
return {"id": probe_id, "prompt_order": prompt_order, "text": text}
|
||||
|
||||
|
||||
def _validate_image_path(value: Any, *, field: str) -> str:
|
||||
path_text = _text(value)
|
||||
if not path_text:
|
||||
raise BatchError(f"{field} is required")
|
||||
path = Path(path_text)
|
||||
if not path.is_absolute():
|
||||
raise BatchError(f"{field} must be absolute")
|
||||
if path.suffix.lower() != ".png":
|
||||
raise BatchError(f"{field} must reference a PNG artifact")
|
||||
return path_text
|
||||
|
||||
|
||||
def load_batch(path: Path) -> dict[str, Any]:
|
||||
batch = _load_json(path)
|
||||
for forbidden in ("negative", "negative_prompt", "negative_text", "negative_channel"):
|
||||
if forbidden in batch:
|
||||
raise BatchError(f"batch must not contain {forbidden}")
|
||||
seed = batch.get("seed")
|
||||
if not isinstance(seed, int):
|
||||
raise BatchError("seed must be an integer sampler seed")
|
||||
channel_out = _text(batch.get("channel_out") or DEFAULT_OUT_CHANNEL)
|
||||
channel_in = _text(batch.get("channel_in") or DEFAULT_IN_CHANNEL)
|
||||
_validate_no_negative_channel(channel_out, field="channel_out")
|
||||
_validate_no_negative_channel(channel_in, field="channel_in")
|
||||
probes_raw = batch.get("probes")
|
||||
if not isinstance(probes_raw, list) or not probes_raw:
|
||||
raise BatchError("probes must be a non-empty list")
|
||||
probes = [_validate_probe(raw, index) for index, raw in enumerate(probes_raw)]
|
||||
return {
|
||||
"seed": seed,
|
||||
"channel_out": channel_out,
|
||||
"channel_in": channel_in,
|
||||
"probes": probes,
|
||||
}
|
||||
|
||||
|
||||
def load_results(path: Path) -> dict[str, Any]:
|
||||
data = _load_json(path)
|
||||
seed = data.get("seed")
|
||||
if not isinstance(seed, int):
|
||||
raise BatchError("result seed must be an integer sampler seed")
|
||||
channel_in = _text(data.get("channel_in") or DEFAULT_IN_CHANNEL)
|
||||
_validate_no_negative_channel(channel_in, field="channel_in")
|
||||
probes_raw = data.get("probes")
|
||||
if not isinstance(probes_raw, list) or not probes_raw:
|
||||
raise BatchError("result probes must be a non-empty list")
|
||||
probes: list[dict[str, Any]] = []
|
||||
for index, raw in enumerate(probes_raw):
|
||||
if not isinstance(raw, dict):
|
||||
raise BatchError(f"result probes[{index}] must be an object")
|
||||
probe_id = _text(raw.get("id"))
|
||||
if not probe_id:
|
||||
raise BatchError(f"result probes[{index}].id is required")
|
||||
prompt_order = _text(raw.get("prompt_order") or "subject_first")
|
||||
if prompt_order not in PROMPT_ORDERS:
|
||||
raise BatchError(f"result probes[{index}].prompt_order must be one of {sorted(PROMPT_ORDERS)}")
|
||||
turn = raw.get("turn")
|
||||
if turn is not None and (not isinstance(turn, int) or isinstance(turn, bool)):
|
||||
raise BatchError(f"result probes[{index}].turn must be an integer when present")
|
||||
returned_seed = raw.get("returned_seed")
|
||||
if returned_seed is not None and (not isinstance(returned_seed, int) or isinstance(returned_seed, bool)):
|
||||
raise BatchError(f"result probes[{index}].returned_seed must be an integer when present")
|
||||
image_path = _text(raw.get("image_path"))
|
||||
probes.append(
|
||||
{
|
||||
"id": probe_id,
|
||||
"prompt_order": prompt_order,
|
||||
"turn": turn,
|
||||
"image_path": image_path,
|
||||
"returned_seed": returned_seed,
|
||||
}
|
||||
)
|
||||
return {"seed": seed, "channel_in": channel_in, "probes": probes}
|
||||
|
||||
|
||||
def print_push_commands(batch: dict[str, Any]) -> None:
|
||||
for index, probe in enumerate(batch["probes"], start=1):
|
||||
prompt_order = probe["prompt_order"]
|
||||
caveat = "geometry-only: pose-axis discovery; not subject/look-controlled" if prompt_order == "geometry_only" else prompt_order
|
||||
print(f"# {index}/{len(batch['probes'])} {probe['id']} ({caveat})")
|
||||
push_args = {
|
||||
"channel": batch["channel_out"],
|
||||
"seed": batch["seed"],
|
||||
"text": probe["text"],
|
||||
}
|
||||
print(f"{APPROVED_PYTHON} {MCP_HELPER} call-tool comfy_push --arguments-json {_json_arg(push_args)}")
|
||||
pull_args = {"channel": batch["channel_in"]}
|
||||
print(f"{APPROVED_PYTHON} {MCP_HELPER} call-tool comfy_pull --arguments-json {_json_arg(pull_args)}")
|
||||
|
||||
|
||||
def print_result_template(batch: dict[str, Any]) -> None:
|
||||
template = {
|
||||
"seed": batch["seed"],
|
||||
"channel_in": batch["channel_in"],
|
||||
"probes": [
|
||||
{
|
||||
"id": probe["id"],
|
||||
"prompt_order": probe["prompt_order"],
|
||||
"turn": None,
|
||||
"image_path": "",
|
||||
"returned_seed": None,
|
||||
}
|
||||
for probe in batch["probes"]
|
||||
],
|
||||
}
|
||||
print(json.dumps(template, ensure_ascii=True, indent=2))
|
||||
|
||||
|
||||
def _probe_by_id(probes: list[dict[str, Any]], probe_id: str, *, label: str) -> dict[str, Any]:
|
||||
for probe in probes:
|
||||
if probe.get("id") == probe_id:
|
||||
return probe
|
||||
raise BatchError(f"{label} {probe_id!r} was not found")
|
||||
|
||||
|
||||
def validate_results(batch: dict[str, Any], results: dict[str, Any]) -> None:
|
||||
if results["seed"] != batch["seed"]:
|
||||
raise BatchError(f"result seed {results['seed']} does not match batch seed {batch['seed']}")
|
||||
batch_probe_ids = [probe["id"] for probe in batch["probes"]]
|
||||
result_probe_ids = [probe["id"] for probe in results["probes"]]
|
||||
if result_probe_ids != batch_probe_ids:
|
||||
raise BatchError(f"result probe ids must match batch probe ids in order: expected {batch_probe_ids}, got {result_probe_ids}")
|
||||
turns: list[int] = []
|
||||
for index, (batch_probe, result_probe) in enumerate(zip(batch["probes"], results["probes"])):
|
||||
expected_order = batch_probe["prompt_order"]
|
||||
if result_probe["prompt_order"] != expected_order:
|
||||
raise BatchError(
|
||||
f"result probes[{index}].prompt_order must match batch prompt_order {expected_order!r}"
|
||||
)
|
||||
turn = result_probe.get("turn")
|
||||
if not isinstance(turn, int) or isinstance(turn, bool):
|
||||
raise BatchError(f"result probes[{index}].turn is required and must be an integer")
|
||||
turns.append(turn)
|
||||
_validate_image_path(result_probe.get("image_path"), field=f"result probes[{index}].image_path")
|
||||
returned_seed = result_probe.get("returned_seed")
|
||||
if returned_seed != batch["seed"]:
|
||||
raise BatchError(
|
||||
f"result probes[{index}].returned_seed must match batch seed {batch['seed']}"
|
||||
)
|
||||
if len(set(turns)) != len(turns):
|
||||
raise BatchError("result probe turns must be unique")
|
||||
if turns != sorted(turns):
|
||||
raise BatchError("result probe turns must be in batch order")
|
||||
|
||||
|
||||
def _payload_from_mcp_response(data: Any) -> dict[str, Any]:
|
||||
if isinstance(data, dict) and ("turn" in data or "image_path" in data or "seed" in data):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise BatchError("MCP response must be an object")
|
||||
content = data.get("content")
|
||||
if not isinstance(content, list):
|
||||
raise BatchError("MCP response did not contain content")
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
text = _text(item.get("text"))
|
||||
if not text:
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise BatchError(f"MCP content text was not JSON: {exc}") from exc
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
raise BatchError("MCP response did not contain a JSON payload")
|
||||
|
||||
|
||||
def load_mock_pulls(path: Path) -> list[dict[str, Any]]:
|
||||
return [_payload_from_mcp_response(item) for item in _load_json_list(path)]
|
||||
|
||||
|
||||
def _call_mcp_tool(tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
result = subprocess.run(
|
||||
[
|
||||
APPROVED_PYTHON,
|
||||
MCP_HELPER,
|
||||
"call-tool",
|
||||
tool_name,
|
||||
"--arguments-json",
|
||||
json.dumps(arguments, ensure_ascii=True, separators=(",", ":")),
|
||||
],
|
||||
cwd=ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
detail = result.stderr.strip() or result.stdout.strip() or f"exit {result.returncode}"
|
||||
raise BatchError(f"MCP {tool_name} failed: {detail}")
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise BatchError(f"MCP {tool_name} returned non-JSON output: {exc}") from exc
|
||||
return _payload_from_mcp_response(data)
|
||||
|
||||
|
||||
def _result_probe_from_payload(
|
||||
probe: dict[str, Any],
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
expected_seed: int,
|
||||
previous_turn: int,
|
||||
) -> dict[str, Any] | None:
|
||||
turn = payload.get("turn")
|
||||
if not isinstance(turn, int) or isinstance(turn, bool):
|
||||
raise BatchError(f"pull result for {probe['id']} must contain an integer turn")
|
||||
if turn <= previous_turn:
|
||||
return None
|
||||
image_path = _validate_image_path(payload.get("image_path"), field=f"pull result for {probe['id']}.image_path")
|
||||
returned_seed = payload.get("seed")
|
||||
if not isinstance(returned_seed, int) or isinstance(returned_seed, bool):
|
||||
raise BatchError(f"pull result for {probe['id']} must contain an integer seed")
|
||||
if returned_seed != expected_seed:
|
||||
raise BatchError(f"pull result for {probe['id']} returned seed {returned_seed}, expected {expected_seed}")
|
||||
return {
|
||||
"id": probe["id"],
|
||||
"prompt_order": probe["prompt_order"],
|
||||
"turn": turn,
|
||||
"image_path": image_path,
|
||||
"returned_seed": returned_seed,
|
||||
}
|
||||
|
||||
|
||||
def run_batch(
|
||||
batch: dict[str, Any],
|
||||
*,
|
||||
result_path: Path,
|
||||
previous_turn: int,
|
||||
max_polls: int,
|
||||
poll_interval: float,
|
||||
run_live: bool = False,
|
||||
mock_pulls: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if max_polls <= 0:
|
||||
raise BatchError("max_polls must be positive")
|
||||
if poll_interval < 0:
|
||||
raise BatchError("poll_interval must be non-negative")
|
||||
if run_live and mock_pulls is not None:
|
||||
raise BatchError("--run and --mock-pulls-json cannot be used together")
|
||||
if not run_live and mock_pulls is None:
|
||||
raise BatchError("run_batch requires live mode or mock pulls")
|
||||
|
||||
mock_index = 0
|
||||
result_probes: list[dict[str, Any]] = []
|
||||
current_turn = previous_turn
|
||||
for probe in batch["probes"]:
|
||||
if run_live:
|
||||
_call_mcp_tool(
|
||||
"comfy_push",
|
||||
{"channel": batch["channel_out"], "seed": batch["seed"], "text": probe["text"]},
|
||||
)
|
||||
for poll_index in range(max_polls):
|
||||
if mock_pulls is not None:
|
||||
if mock_index >= len(mock_pulls):
|
||||
raise BatchError(f"mock pulls exhausted while waiting for {probe['id']}")
|
||||
payload = mock_pulls[mock_index]
|
||||
mock_index += 1
|
||||
else:
|
||||
payload = _call_mcp_tool("comfy_pull", {"channel": batch["channel_in"]})
|
||||
result_probe = _result_probe_from_payload(
|
||||
probe,
|
||||
payload,
|
||||
expected_seed=batch["seed"],
|
||||
previous_turn=current_turn,
|
||||
)
|
||||
if result_probe is not None:
|
||||
result_probes.append(result_probe)
|
||||
current_turn = result_probe["turn"]
|
||||
break
|
||||
if run_live and poll_index < max_polls - 1:
|
||||
time.sleep(poll_interval)
|
||||
else:
|
||||
raise BatchError(f"no new result for {probe['id']} after {max_polls} polls")
|
||||
|
||||
results = {"seed": batch["seed"], "channel_in": batch["channel_in"], "probes": result_probes}
|
||||
validate_results(batch, results)
|
||||
result_path.write_text(json.dumps(results, ensure_ascii=True, indent=2) + "\n", encoding="utf-8")
|
||||
return results
|
||||
|
||||
|
||||
def _entry_id_slug(value: str) -> str:
|
||||
chars = [char.lower() if char.isalnum() else "-" for char in value]
|
||||
slug = "".join(chars).strip("-")
|
||||
while "--" in slug:
|
||||
slug = slug.replace("--", "-")
|
||||
return slug or "sxcp-batch"
|
||||
|
||||
|
||||
def eval_entry_draft(
|
||||
batch: dict[str, Any],
|
||||
results: dict[str, Any],
|
||||
*,
|
||||
variant_key: str,
|
||||
entry_id: str,
|
||||
baseline_image: str,
|
||||
candidate_id: str,
|
||||
source: str,
|
||||
result: str,
|
||||
decision: str,
|
||||
entry_date: str,
|
||||
allow_geometry_only: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
validate_results(batch, results)
|
||||
batch_probe = _probe_by_id(batch["probes"], candidate_id, label="candidate probe")
|
||||
result_probe = _probe_by_id(results["probes"], candidate_id, label="candidate result")
|
||||
candidate_image = _validate_image_path(result_probe.get("image_path"), field="candidate image_path")
|
||||
baseline = _validate_image_path(baseline_image, field="baseline_image")
|
||||
prompt_order = batch_probe["prompt_order"]
|
||||
turn = result_probe.get("turn")
|
||||
returned_seed = result_probe.get("returned_seed")
|
||||
if returned_seed is not None and returned_seed != batch["seed"]:
|
||||
raise BatchError(f"candidate returned_seed {returned_seed} does not match batch seed {batch['seed']}")
|
||||
if prompt_order == "geometry_only" and not allow_geometry_only:
|
||||
raise BatchError("candidate prompt_order is geometry_only; rerun with --allow-geometry-only to draft non-controlled prompt-axis evidence")
|
||||
order_note = (
|
||||
"subject/look-controlled candidate"
|
||||
if prompt_order == "subject_first"
|
||||
else "geometry-only prompt-order probe; do not treat as subject/look-controlled evidence"
|
||||
if prompt_order == "geometry_only"
|
||||
else "prompt-order sensitivity probe"
|
||||
)
|
||||
entry = {
|
||||
"id": entry_id or f"{_entry_id_slug(variant_key)}-{batch['seed']}-{_entry_id_slug(candidate_id)}",
|
||||
"date": entry_date,
|
||||
"variant_key": variant_key,
|
||||
"seed": batch["seed"],
|
||||
"source": source,
|
||||
"result": result,
|
||||
"decision": decision,
|
||||
"baseline_prompt_summary": f"Replace with the same-seed baseline summary for {variant_key}.",
|
||||
"candidate_prompt_summary": (
|
||||
f"Batch candidate {candidate_id!r} used prompt_order={prompt_order!r}; "
|
||||
f"replace with the pose-axis change and controlled variables."
|
||||
),
|
||||
"observation": (
|
||||
f"Replace with image comparison for candidate {candidate_id!r}"
|
||||
f"{f' on turn {turn}' if turn is not None else ''}. Prompt-order note: {order_note}."
|
||||
),
|
||||
"baseline_image": baseline,
|
||||
"candidate_image": candidate_image,
|
||||
"commit": "pending",
|
||||
}
|
||||
return entry
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
for command in ("validate", "print-push-commands", "print-result-template"):
|
||||
subparser = subparsers.add_parser(command)
|
||||
subparser.add_argument("--batch-json", required=True, help="Path to the prompt batch JSON file.")
|
||||
validate_results_parser = subparsers.add_parser("validate-results")
|
||||
validate_results_parser.add_argument("--batch-json", required=True, help="Path to the prompt batch JSON file.")
|
||||
validate_results_parser.add_argument("--result-json", required=True, help="Path to a filled result template JSON file.")
|
||||
run_parser = subparsers.add_parser("run-batch")
|
||||
run_parser.add_argument("--batch-json", required=True, help="Path to the prompt batch JSON file.")
|
||||
run_parser.add_argument("--result-json", required=True, help="Path where the filled result JSON should be written.")
|
||||
run_parser.add_argument("--run", action="store_true", help="Call the live MCP helper. Omit for dry-run or mock mode.")
|
||||
run_parser.add_argument("--mock-pulls-json", help="Path to a JSON list of mocked sxcp_eval_in payloads for local testing.")
|
||||
run_parser.add_argument("--previous-turn", type=int, default=0, help="Ignore pulls at or below this turn before the first probe.")
|
||||
run_parser.add_argument("--max-polls", type=int, default=60, help="Maximum pull attempts per probe.")
|
||||
run_parser.add_argument("--poll-interval", type=float, default=2.0, help="Seconds to wait between live pull attempts.")
|
||||
draft_parser = subparsers.add_parser("print-eval-entry-draft")
|
||||
draft_parser.add_argument("--batch-json", required=True, help="Path to the prompt batch JSON file.")
|
||||
draft_parser.add_argument("--result-json", required=True, help="Path to a filled result template JSON file.")
|
||||
draft_parser.add_argument("--variant-key", required=True, help="Catalog variant key for the eval entry.")
|
||||
draft_parser.add_argument("--entry-id", default="", help="Durable eval entry id. Defaults to a generated id.")
|
||||
draft_parser.add_argument("--baseline-image", required=True, help="Absolute PNG path for the baseline image.")
|
||||
draft_parser.add_argument("--candidate-id", required=True, help="Probe id to use as the candidate image.")
|
||||
draft_parser.add_argument("--source", default="sxcp_eval_mcp_batch", help="Source label for the eval entry.")
|
||||
draft_parser.add_argument("--result", default="inconclusive", help="Eval result. Default: inconclusive.")
|
||||
draft_parser.add_argument("--decision", default="needs_more_tests", help="Eval decision. Default: needs_more_tests.")
|
||||
draft_parser.add_argument("--date", default=date.today().isoformat(), help="Eval entry date.")
|
||||
draft_parser.add_argument(
|
||||
"--allow-geometry-only",
|
||||
action="store_true",
|
||||
help="Allow drafting an entry from a geometry_only probe. Use only for non-controlled prompt-axis evidence.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
try:
|
||||
batch = load_batch(Path(args.batch_json))
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
if args.command == "validate":
|
||||
print(f"validated: {len(batch['probes'])} probes, seed {batch['seed']}")
|
||||
return 0
|
||||
if args.command == "print-push-commands":
|
||||
print_push_commands(batch)
|
||||
return 0
|
||||
if args.command == "print-result-template":
|
||||
print_result_template(batch)
|
||||
return 0
|
||||
if args.command == "validate-results":
|
||||
try:
|
||||
results = load_results(Path(args.result_json))
|
||||
validate_results(batch, results)
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
print(f"validated results: {len(batch['probes'])} probes, seed {batch['seed']}")
|
||||
return 0
|
||||
if args.command == "run-batch":
|
||||
try:
|
||||
if not args.run and not args.mock_pulls_json:
|
||||
print(f"dry-run: {len(batch['probes'])} probes, seed {batch['seed']}")
|
||||
print_push_commands(batch)
|
||||
return 0
|
||||
mock_pulls = load_mock_pulls(Path(args.mock_pulls_json)) if args.mock_pulls_json else None
|
||||
results = run_batch(
|
||||
batch,
|
||||
result_path=Path(args.result_json),
|
||||
previous_turn=args.previous_turn,
|
||||
max_polls=args.max_polls,
|
||||
poll_interval=args.poll_interval,
|
||||
run_live=args.run,
|
||||
mock_pulls=mock_pulls,
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
print(f"recorded results: {len(results['probes'])} probes, seed {results['seed']} -> {args.result_json}")
|
||||
return 0
|
||||
if args.command == "print-eval-entry-draft":
|
||||
try:
|
||||
results = load_results(Path(args.result_json))
|
||||
entry = eval_entry_draft(
|
||||
batch,
|
||||
results,
|
||||
variant_key=args.variant_key,
|
||||
entry_id=args.entry_id,
|
||||
baseline_image=args.baseline_image,
|
||||
candidate_id=args.candidate_id,
|
||||
source=args.source,
|
||||
result=args.result,
|
||||
decision=args.decision,
|
||||
entry_date=args.date,
|
||||
allow_geometry_only=args.allow_geometry_only,
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
print(json.dumps(entry, ensure_ascii=True, indent=2))
|
||||
return 0
|
||||
parser.error(f"unknown command: {args.command}")
|
||||
return 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user