Initial commit: VLM-as-judge prompt calibration loop

Qwen3-VL image-similarity judge node, external-prompt receptor node,
agent_bridge CLI, example SDXL workflow, and methodology/agent-loop/
calibration-policy docs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-26 22:15:56 +02:00
commit 95198a15b5
13 changed files with 1294 additions and 0 deletions
View File
+418
View File
@@ -0,0 +1,418 @@
"""
Qwen3-VL Image-Similarity Judge node for ComfyUI.
The "vllm node" of the Prompt Calibrator. It takes a REFERENCE image and a
GENERATED image and asks a local Qwen3-VL model how close the generated image is
to the reference, returning a machine-readable score + per-axis difference
analysis that the calibration controller can act on.
Reuses the standard transformers Qwen3-VL plumbing (the same approach used by
ComfyUI-QwenVL-MultiImage / ComfyUI_Qwen3-VL-Instruct), but forces strict JSON
output so the result is usable by an automated loop rather than a human reader.
Default model is the locally converted huihui-ai Qwen3-VL-4B-Instruct
*abliterated* (uncensored) weights, which do not refuse to analyze adult imagery.
"""
from __future__ import annotations
import json
import os
import re
import numpy as np
import torch
from PIL import Image
# Default to the model already converted on this machine (works out of the box).
DEFAULT_MODEL_PATH = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16"
DEFAULT_MODEL_PATH_FP8 = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_fp8"
# Recommended abliterated upgrades for the RTX 5090 32 GB (latest Qwen VL family).
# Download with: hf download <repo> --local-dir <dir>, then point model_path at it.
RECOMMENDED_MODELS = {
# Best judge that fits 32 GB. MoE (3B active -> fast). Use precision="nf4"
# (~18 GB) on 32 GB, or the GGUF quants via a GGUF node. transformers class:
# Qwen3VLMoeForConditionalGeneration (auto-detected below).
"30b-a3b": "huihui-ai/Huihui-Qwen3-VL-30B-A3B-Instruct-abliterated",
# Easy middle ground: bf16 ~17 GB, no quantization hassle, drop-in here.
"8b": "huihui-ai/Huihui-Qwen3-VL-8B-Instruct-abliterated",
# Lightweight, already local.
"4b": "huihui-ai/Huihui-Qwen3-VL-4B-Instruct-abliterated",
}
DEFAULT_AXES = "cast, clothing, pose, scene, composition, expression, color_light"
# Cache loaded (model, processor) keyed by (path, precision) so the loop does not
# reload weights every iteration.
_MODEL_CACHE: dict[tuple[str, str], tuple] = {}
def _looks_like_repo_id(s: str) -> bool:
"""'org/name' HF repo id, not an absolute/local filesystem path."""
return ("/" in s) and (" " not in s) and (not os.path.isabs(s)) and (not s.startswith("."))
def _download_target_dir(repo_id: str) -> str:
"""Where to put downloaded weights — prefer ComfyUI's models/prompt_generator/."""
name = repo_id.split("/")[-1]
try:
import folder_paths # available when running inside ComfyUI
base = os.path.join(folder_paths.models_dir, "prompt_generator")
except Exception:
base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
return os.path.join(base, name)
def _resolve_model_source(model_path: str, auto_download: bool) -> str:
"""Turn model_path (local dir | short alias | HF repo id) into a local dir.
Downloads from the Hub on first use if needed (and auto_download is on).
"""
# Short alias -> full repo id (e.g. "30b-a3b", "8b", "4b").
if model_path in RECOMMENDED_MODELS:
model_path = RECOMMENDED_MODELS[model_path]
if os.path.isdir(model_path):
return model_path
if _looks_like_repo_id(model_path):
target = _download_target_dir(model_path)
# Already downloaded? (a config.json is enough to trust the local copy)
if os.path.isfile(os.path.join(target, "config.json")):
return target
if not auto_download:
raise FileNotFoundError(
f"[QwenVLImageJudge] '{model_path}' is not downloaded and auto_download is off. "
f"Enable auto_download or pre-fetch it to {target}.")
from huggingface_hub import snapshot_download
print(f"[QwenVLImageJudge] downloading {model_path} -> {target} (first run only, may be large)...")
local = snapshot_download(
repo_id=model_path,
local_dir=target,
# weights + processor/tokenizer/config; skip duplicate GGUF/onnx blobs.
allow_patterns=["*.json", "*.safetensors", "*.txt", "*.model", "merges.txt", "*.py"],
)
print(f"[QwenVLImageJudge] download complete: {local}")
return local
# A local path that simply doesn't exist.
raise FileNotFoundError(
f"[QwenVLImageJudge] model_path not found: {model_path}. "
f"Use a local checkpoint dir, a HF repo id (org/name), or an alias "
f"({', '.join(RECOMMENDED_MODELS)}).")
def _tensor_to_pil(image: "torch.Tensor") -> Image.Image:
"""ComfyUI IMAGE tensor (B,H,W,C float 0..1) -> first-frame PIL.Image (RGB)."""
if image is None:
raise ValueError("Judge node received an empty image input.")
arr = image
if hasattr(arr, "detach"):
arr = arr.detach().cpu().numpy()
arr = np.asarray(arr)
if arr.ndim == 4: # batch -> take first frame
arr = arr[0]
arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
if arr.ndim == 2:
arr = np.stack([arr] * 3, axis=-1)
if arr.shape[-1] == 4: # drop alpha
arr = arr[..., :3]
return Image.fromarray(arr, mode="RGB")
def _resolve_vl_class(model_path: str):
"""Pick the right transformers class. AutoModelForImageTextToText reads the
checkpoint's `architectures` and instantiates the correct dense
(Qwen3VLForConditionalGeneration) or MoE (Qwen3VLMoeForConditionalGeneration)
class automatically — so 4B/8B *and* 30B-A3B all work without branching."""
try:
from transformers import AutoModelForImageTextToText as _Auto
return _Auto
except ImportError: # pragma: no cover - older transformers
name = model_path.lower()
is_moe = any(t in name for t in ("a3b", "moe", "30b", "235b"))
if is_moe:
from transformers import Qwen3VLMoeForConditionalGeneration as _C
else:
from transformers import Qwen3VLForConditionalGeneration as _C
return _C
def _load_model(model_path: str, precision: str):
key = (model_path, precision)
if key in _MODEL_CACHE:
return _MODEL_CACHE[key]
# Imported lazily so the node can be registered even if transformers is old.
from transformers import AutoProcessor
_VLModel = _resolve_vl_class(model_path)
load_kwargs = dict(device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
if precision == "nf4":
# 4-bit (bitsandbytes) — lets the 30B-A3B abliterated MoE fit in ~18 GB on 32 GB.
from transformers import BitsAndBytesConfig
load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
elif precision == "fp8":
# Pre-quantized FP8 weights: let the checkpoint dictate dtype.
pass
else:
load_kwargs["dtype"] = torch.bfloat16 if precision == "bf16" else torch.float16
model = _VLModel.from_pretrained(model_path, **load_kwargs)
model.eval()
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
_MODEL_CACHE[key] = (model, processor)
return model, processor
def _build_system_prompt(axes: list[str]) -> str:
axis_lines = "\n".join(f' "{a}": {{"score": <0..1>, "diff": "<short note>"}},' for a in axes)
return (
"You are a meticulous visual-similarity judge for an image-generation "
"calibration loop. You are shown two images: IMAGE 1 is the REFERENCE "
"(the target) and IMAGE 2 is the GENERATED candidate. Judge how closely "
"the GENERATED image reproduces the REFERENCE.\n\n"
"Score each axis from 0 to 1 using this anchored rubric:\n"
" 0.0 = unrelated; 0.5 = same general category but clearly different "
"details; 1.0 = near-identical.\n"
"For each axis, FIRST note the concrete difference, THEN assign the number.\n\n"
"Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
"{\n"
' "overall_score": <0..1>,\n'
' "axes": {\n'
f"{axis_lines}\n"
" },\n"
' "fix_suggestions": ["<actionable change to the generation prompt>", ...]\n'
"}\n"
"Phrase every diff and fix in terms of the named axes "
"(cast/clothing/pose/scene/composition/expression/color_light). "
"overall_score must be consistent with the per-axis scores."
)
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
"""One forward pass; returns the raw decoded string."""
messages = [
{"role": "system", "content": _build_system_prompt(axes)},
{
"role": "user",
"content": [
{"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
{"type": "image", "image": ref_pil},
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
{"type": "image", "image": gen_pil},
{"type": "text", "text": "Now return the strict JSON judgement."},
],
},
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[ref_pil, gen_pil], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = dict(max_new_tokens=max_new_tokens)
if temperature and temperature > 0:
gen_kwargs.update(do_sample=True, temperature=float(temperature))
else:
gen_kwargs.update(do_sample=False)
with torch.inference_mode():
out = model.generate(**inputs, **gen_kwargs)
trimmed = out[:, inputs.input_ids.shape[1]:]
decoded = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
return decoded.strip()
def _parse_json(raw: str) -> dict | None:
"""Best-effort: pull the first balanced JSON object out of the model output."""
# Strip code fences if present.
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
candidate = fenced.group(1) if fenced else None
if candidate is None:
start = raw.find("{")
if start == -1:
return None
depth = 0
for i in range(start, len(raw)):
if raw[i] == "{":
depth += 1
elif raw[i] == "}":
depth -= 1
if depth == 0:
candidate = raw[start:i + 1]
break
if candidate is None:
return None
try:
return json.loads(candidate)
except json.JSONDecodeError:
return None
def _merge_swapped(a: dict, b: dict) -> dict:
"""Average two judgements (normal + order-swapped) to cut position bias."""
if not b:
return a
if not a:
return b
out = {"axes": {}, "fix_suggestions": []}
out["overall_score"] = round(
(float(a.get("overall_score", 0)) + float(b.get("overall_score", 0))) / 2.0, 4
)
axes = set(a.get("axes", {})) | set(b.get("axes", {}))
for ax in axes:
sa = a.get("axes", {}).get(ax, {})
sb = b.get("axes", {}).get(ax, {})
score = (float(sa.get("score", 0)) + float(sb.get("score", 0))) / 2.0
diff = sa.get("diff") or sb.get("diff") or ""
out["axes"][ax] = {"score": round(score, 4), "diff": diff}
out["fix_suggestions"] = (a.get("fix_suggestions") or []) + (b.get("fix_suggestions") or [])
return out
def _report_base_dir(report_dir: str) -> str:
if report_dir:
return report_dir
try:
import folder_paths
return os.path.join(folder_paths.get_output_directory(), "calibrator")
except Exception:
return os.path.join(os.path.dirname(os.path.dirname(__file__)), "output", "calibrator")
def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used):
"""Persist the analysis so the external CLI agent can read it after a queue.
Writes a per-run file plus a stable `latest.json` the agent can always poll.
Returns the per-run file path (or "" on failure)."""
base = _report_base_dir(report_dir)
try:
os.makedirs(base, exist_ok=True)
except OSError as e:
print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
return ""
payload = {
"run_tag": run_tag,
"overall_score": round(float(overall), 4),
"axes": (merged or {}).get("axes", {}),
"fix_suggestions": (merged or {}).get("fix_suggestions", []),
"diff_analysis": diff_analysis,
"prompt_used": prompt_used,
"raw": raw_all,
}
tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "latest"
run_path = os.path.join(base, f"calib_{tag}.json")
for path in (run_path, os.path.join(base, "latest.json")):
try:
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
except OSError as e:
print(f"[QwenVLImageJudge] failed writing report {path}: {e}")
# A markdown sibling is handy for the agent to read as plain text.
try:
md = (f"# Calibration analysis ({tag})\n\n"
f"**overall_score:** {payload['overall_score']}\n\n"
f"**prompt_used:**\n\n{prompt_used or '(not provided)'}\n\n"
f"## per-axis\n\n{diff_analysis}\n")
with open(os.path.join(base, f"calib_{tag}.md"), "w", encoding="utf-8") as f:
f.write(md)
except OSError:
pass
return run_path
class QwenVLImageJudge:
"""ComfyUI node: score how close a generated image is to a reference."""
CATEGORY = "prompt_calibrator"
FUNCTION = "judge"
RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING")
RETURN_NAMES = ("overall_score", "axis_scores_json", "diff_analysis", "raw", "report_path")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"reference_image": ("IMAGE",),
"generated_image": ("IMAGE",),
"model_path": ("STRING", {"default": DEFAULT_MODEL_PATH}),
"precision": (["bf16", "fp16", "fp8", "nf4"], {"default": "bf16"}),
"axes": ("STRING", {"default": DEFAULT_AXES, "multiline": True}),
"max_new_tokens": ("INT", {"default": 512, "min": 64, "max": 4096}),
"temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}),
"swap_eval": ("BOOLEAN", {"default": True}),
},
"optional": {
"keep_loaded": ("BOOLEAN", {"default": True}),
"auto_download": ("BOOLEAN", {"default": True}),
# The agent reads the analysis from these files after each queue.
"report_dir": ("STRING", {"default": ""}),
"run_tag": ("STRING", {"default": ""}),
"prompt_used": ("STRING", {"default": "", "multiline": True}),
},
}
def judge(self, reference_image, generated_image, model_path, precision, axes,
max_new_tokens, temperature, swap_eval, keep_loaded=True, auto_download=True,
report_dir="", run_tag="", prompt_used=""):
axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()]
if not axis_list:
axis_list = [a.strip() for a in DEFAULT_AXES.split(",")]
try:
resolved_path = _resolve_model_source(model_path, auto_download)
except Exception as e: # missing model / download failure -> surface as score 0
msg = str(e)
print(msg)
return (0.0, "{}", msg, msg)
ref_pil = _tensor_to_pil(reference_image)
gen_pil = _tensor_to_pil(generated_image)
model, processor = _load_model(resolved_path, precision)
raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
parsed1 = _parse_json(raw1) or {}
raw_all = raw1
merged = parsed1
if swap_eval:
# Swap which image is called REFERENCE to average out position bias.
raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature)
parsed2 = _parse_json(raw2) or {}
merged = _merge_swapped(parsed1, parsed2)
raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2
if not keep_loaded:
_MODEL_CACHE.pop((resolved_path, precision), None)
del model
torch.cuda.empty_cache()
overall = float(merged.get("overall_score", 0.0)) if merged else 0.0
axis_scores = json.dumps(merged.get("axes", {}), ensure_ascii=False, indent=2) if merged else "{}"
# Human/controller-readable diff summary.
diff_lines = []
for ax, info in (merged.get("axes", {}) if merged else {}).items():
diff_lines.append(f"- {ax}: {info.get('score', 0):.2f}{info.get('diff', '')}")
fixes = merged.get("fix_suggestions", []) if merged else []
if fixes:
diff_lines.append("fixes: " + "; ".join(str(f) for f in fixes))
diff_analysis = "\n".join(diff_lines) if diff_lines else "(no parseable judgement)"
report_path = _write_report(
report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used)
return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path)
NODE_CLASS_MAPPINGS = {"QwenVLImageJudge": QwenVLImageJudge}
NODE_DISPLAY_NAME_MAPPINGS = {"QwenVLImageJudge": "Qwen3-VL Image Judge (Calibrator)"}
+66
View File
@@ -0,0 +1,66 @@
"""
Calibrator Prompt Receptor node.
The injection point for the external CLI-agent controller. The agent overrides
this node's widget values per queue via the ComfyUI HTTP API (`POST /prompt`,
override by node id), or — as a fallback — points `source_file` at a JSON file
the agent writes. Its outputs feed the T2I sampler in place of a static prompt.
This is the "receptor in ComfyUI" in the loop:
agent -> (sets prompt here) -> T2I -> Qwen3-VL Judge -> analysis -> agent
"""
from __future__ import annotations
import json
import os
class CalibratorPromptReceptor:
CATEGORY = "prompt_calibrator"
FUNCTION = "emit"
RETURN_TYPES = ("STRING", "STRING", "INT")
RETURN_NAMES = ("prompt", "negative", "seed")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": ("STRING", {"default": "", "multiline": True}),
"negative": ("STRING", {"default": "", "multiline": True}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0x7FFFFFFFFFFFFFFF}),
},
"optional": {
# If set and present, a JSON file {prompt, negative, seed} overrides
# the widgets above. Lets the agent drive the loop file-first if it
# prefers that to the HTTP API.
"source_file": ("STRING", {"default": ""}),
},
}
@classmethod
def IS_CHANGED(cls, prompt, negative, seed, source_file=""):
# Re-run whenever the effective inputs change: widget values (API override)
# OR the source file's mtime (file-driven mode).
mtime = ""
if source_file and os.path.isfile(source_file):
mtime = str(os.path.getmtime(source_file))
return f"{prompt}|{negative}|{seed}|{source_file}|{mtime}"
def emit(self, prompt, negative, seed, source_file=""):
if source_file and os.path.isfile(source_file):
try:
with open(source_file, "r", encoding="utf-8") as f:
data = json.load(f)
prompt = data.get("prompt", prompt)
negative = data.get("negative", negative)
seed = int(data.get("seed", seed))
except (OSError, ValueError, json.JSONDecodeError) as e:
print(f"[CalibratorPromptReceptor] could not read {source_file}: {e}")
return (prompt, negative, int(seed))
NODE_CLASS_MAPPINGS = {"CalibratorPromptReceptor": CalibratorPromptReceptor}
NODE_DISPLAY_NAME_MAPPINGS = {
"CalibratorPromptReceptor": "SxCP External Prompt (Receptor)"
}