Files
ComfyUI-Prompt-Calibrator/nodes/qwen_judge.py
T
Ethanfel 0e9e99b8b2 Handle reasoning models (Qwen3.5/3.6): no-think + JSON-only + prose fallback
Qwen3.5/3.6 are reasoning models — they 'think out loud' in markdown and never
reach the JSON, then get cut off at the token limit -> '(no parseable judgement)'.
Fixes: apply_chat_template(enable_thinking=False) + strip <think> blocks; hardened
'output ONLY JSON, do not think out loud' instruction; default max_new_tokens
1024->2048 (max 8192); and a markdown fallback parser (_parse_markdown_verdicts /
_parse_axes) that extracts per-axis {verdict,ref,gen} from the prose the model
reliably emits. describe falls back to using the raw text as the caption.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-27 10:25:16 +02:00

963 lines
46 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Qwen3-VL Image-Similarity Judge node for ComfyUI.
The "vllm node" of the Prompt Calibrator. It takes a REFERENCE image and a
GENERATED image and asks a local Qwen3-VL model how close the generated image is
to the reference, returning a machine-readable score + per-axis difference
analysis that the calibration controller can act on.
Reuses the standard transformers Qwen3-VL plumbing (the same approach used by
ComfyUI-QwenVL-MultiImage / ComfyUI_Qwen3-VL-Instruct), but forces strict JSON
output so the result is usable by an automated loop rather than a human reader.
Default model is the locally converted huihui-ai Qwen3-VL-4B-Instruct
*abliterated* (uncensored) weights, which do not refuse to analyze adult imagery.
"""
from __future__ import annotations
import json
import os
import re
import numpy as np
import torch
from PIL import Image
# Default to the model already converted on this machine (works out of the box).
# Both precisions sit side by side under one conversion directory.
_LOCAL_CONVERT_ROOT = "/media/p5/qwen3vl_4b_abliterated_comfy_convert"
DEFAULT_MODEL_PATH = f"{_LOCAL_CONVERT_ROOT}/hf_bf16"
DEFAULT_MODEL_PATH_FP8 = f"{_LOCAL_CONVERT_ROOT}/hf_fp8"
# Recommended abliterated upgrades for the RTX 5090 32 GB (latest Qwen VL family).
# Download with: hf download <repo> --local-dir <dir>, then point model_path at it.
# Rows are (short alias, Hugging Face repo id); an alias typed into model_path is
# expanded to its repo id by the model resolver.
_RECOMMENDED_ROWS = (
    # Best judge that fits 32 GB. MoE (3B active -> fast). Use precision="nf4"
    # (~18 GB) on 32 GB, or the GGUF quants via a GGUF node. transformers class:
    # Qwen3VLMoeForConditionalGeneration (auto-detected by the loader).
    ("30b-a3b", "huihui-ai/Huihui-Qwen3-VL-30B-A3B-Instruct-abliterated"),
    # Easy middle ground: bf16 ~17 GB, no quantization hassle, drop-in here.
    ("8b", "huihui-ai/Huihui-Qwen3-VL-8B-Instruct-abliterated"),
    # Lightweight, already local.
    ("4b", "huihui-ai/Huihui-Qwen3-VL-4B-Instruct-abliterated"),
    # Newer natively-multimodal Qwen3.5/3.6 abliterated (need a recent transformers).
    ("3.5-9b", "huihui-ai/Huihui-Qwen3.5-9B-abliterated"),  # dense 10B, fast, newer
    ("3.6-27b", "huihui-ai/Huihui-Qwen3.6-27B-abliterated"),  # dense 28B, strong (nf4)
    ("3.6-35b", "huihui-ai/Huihui-Qwen3.6-35B-A3B-abliterated"),  # MoE, top (nf4)
)
RECOMMENDED_MODELS = dict(_RECOMMENDED_ROWS)
# Curated model dropdown: label shown in the node -> how to load it. Each label
# carries the model size as a VRAM hint. Every entry is a multimodal safetensors
# checkpoint loaded through transformers (auto-downloaded). The Qwen3.5/3.6 entries
# are natively multimodal and need a recent transformers (AutoModelForMultimodalLM).
# `model_path` overrides this selection.
# GGUF-only models still need a dedicated GGUF node — not run here (transformers only).
# model_select picks the MODEL (name only); `precision` is the separate quant dropdown.
# VRAM ≈ params × bytes/param: bf16 ≈ 2×, fp8 ≈ 1×, nf4 ≈ 0.6× (GB ≈ params·0.6). So on
# 32 GB: 8-10B fits bf16; 27-35B need nf4 (or fp8 if an fp8 checkpoint exists).
# `repo_by_precision` routes a precision to a different checkpoint (the local 4B
# has separate bf16/fp8 dirs).
MANUAL_CHOICE = "(manual — use model_path below)"
# Rows: (dropdown label, checkpoint dir or HF repo id, per-precision overrides or None).
_PRESET_ROWS = (
    ("Qwen3-VL-4B abliterated (huihui, local) · 4B",
     DEFAULT_MODEL_PATH, {"fp8": DEFAULT_MODEL_PATH_FP8}),
    ("Qwen3-VL-8B abliterated (huihui) · 8B",
     "huihui-ai/Huihui-Qwen3-VL-8B-Instruct-abliterated", None),
    ("Qwen3.5-9B abliterated (huihui) · 10B dense · newer",
     "huihui-ai/Huihui-Qwen3.5-9B-abliterated", None),
    ("Qwen3-VL-30B-A3B abliterated (huihui) · 30B MoE",
     "huihui-ai/Huihui-Qwen3-VL-30B-A3B-Instruct-abliterated", None),
    ("Qwen3.6-27B abliterated (huihui) · 28B dense",
     "huihui-ai/Huihui-Qwen3.6-27B-abliterated", None),
    ("Qwen3.6-35B-A3B abliterated (huihui) · 35B MoE · top",
     "huihui-ai/Huihui-Qwen3.6-35B-A3B-abliterated", None),
)
# Entries only carry "repo_by_precision" when a row actually has overrides.
MODEL_PRESETS = {
    label: {"repo": repo, **({"repo_by_precision": overrides} if overrides else {})}
    for label, repo, overrides in _PRESET_ROWS
}
# Difference axes + a one-line definition each. Definitions are injected into the
# prompt so the model fills the right axis (e.g. gender_mix = a count, not a position)
# and the action/pose cluster is captured in detail. Fully configurable on the node;
# any axis not in this map is still allowed (listed to the model by name, with the
# placeholder definition 'as named').
# Keys are axis names; values are the human-readable definition sent to the model.
AXIS_DEFS = {
    # identity / cast
    "subject_count": "how many people are present (a count)",
    "gender_mix": "composition BY GENDER as a count, e.g. '1 female, 1 male' (NOT positions)",
    "age_appearance": "apparent age range of each subject",
    "ethnicity_skin": "ethnicity and skin tone",
    # body
    "body_type": "overall physique / build (slim, curvy, athletic, BBW...)",
    "breast_size": "breast size and shape of female subject(s)",
    "distinctive_features": "tattoos, piercings, nail polish, scars — identity anchors",
    "hair": "hair length, color, texture, and style",
    # wardrobe
    "clothing_state": "degree of undress and any garments / lingerie / accessories",
    # action & pose cluster — OBSERVABLE GEOMETRY, not named labels. Naming a position
    # ("doggy"/"cowgirl") is unreliable even at 30B; describe what is visible and let the
    # agent compose any label from these primitives.
    "sexual_act": "type of activity: vaginal, anal, oral/blowjob, handjob, fingering, none...",
    "body_orientation": "who is on top / bottom / side / kneeling / standing, and which way each body faces (facing partner, same direction, or away). Describe the geometry; do NOT guess a named position.",
    "limb_arrangement": "placement of legs and arms (spread, bent, raised, over shoulder, kneeling) and hand placement",
    "penetration": "penetration type, depth (shallow/full), angle, and how visible it is",
    "contact_points": "where bodies touch: grip/hands location, mouth, points of contact",
    "genital_visibility": "which genitals are visible and how explicitly the frame shows them",
    "pose": "overall body posture: torso/head lean, arch, twist, hip angle",
    # affect
    "facial_expression": "facial expression / affect (eyes, mouth, brow)",
    "gaze": "gaze direction / eye contact (at camera, partner, away, eyes closed)",
    # camera
    "framing": "shot type and crop (close-up, medium, full body) and what the frame centers on",
    "camera_angle": "camera angle / POV (low, high, eye-level, POV/first-person)",
    # render
    "scene": "location, furniture, props, background",
    "lighting_color": "lighting quality and color palette / grade",
    "art_style": "rendering style and realism (photoreal, anime, illustration, 3D)",
    # --- DISTANCE-AWARE act-specific axes (used by analysis profiles) ---
    # Oral: capture proximity/depth explicitly so 'mouth touching' vs 'head far away'
    # is a measurable difference, not hidden behind a coarse 'oral' label.
    "mouth_genital_contact": "is the mouth in contact with the penis/genitals? options: lips on tip / tip in mouth / shaft in mouth / licking shaft / kissing / NOT in contact",
    "mouth_genital_distance": "if NOT in contact, how far is the mouth from the genitals: touching (~0), very close (<5cm), near (~10-20cm), far (>20cm). If in contact, say 'contact'.",
    "oral_depth": "how much of the penis is inside the mouth: none / tip only / about half / deep (throat)",
    "tongue": "tongue visible and where: not visible / on tip / on shaft / flat / extended",
    "hand_on_shaft": "hands on the penis/shaft: none / one hand (base/mid/tip) / two hands",
    "gaze_up": "is the giver looking up at the partner, at the camera, down, or eyes closed",
    # Penetration depth/angle as measurable values
    "insertion_depth": "how deep is penetration: tip only / shallow / half / full/hilt / pulling out",
    "insertion_angle": "angle of penetration relative to the body: vertical / horizontal / oblique",
    # Handjob
    "grip_style": "how the penis is held: loose / firm / two-handed / fingertips / not held",
    "stroke_position": "where the hand is along the shaft: base / mid / tip / gliding full length",
    # Solo
    "self_touch_location": "where the subject is touching themselves: clitoris / labia / breasts / penetration / none",
    "toy_use": "any toy/object in use and where: none / dildo / vibrator / other (location)",
}
# Shared identity/body/wardrobe/affect/camera/render axes (act-independent).
_BASE_AXES = [
"subject_count", "gender_mix", "age_appearance", "ethnicity_skin", "body_type",
"breast_size", "distinctive_features", "hair", "clothing_state",
"facial_expression", "gaze", "framing", "camera_angle", "scene",
"lighting_color", "art_style",
]
# Analysis profiles: act-specialized axis sets. The act-critical axes are made
# distance/proximity-aware so magnitude (e.g. mouth-to-penis distance) is captured.
# Pick on the node via `profile`; leave `axes` empty to use the profile, or set
# `axes` to override entirely.
PROFILES = {
"general": _BASE_AXES + [
"sexual_act", "body_orientation", "limb_arrangement", "penetration",
"contact_points", "genital_visibility", "pose",
],
"oral": _BASE_AXES + [
"body_orientation", "mouth_genital_contact", "mouth_genital_distance",
"oral_depth", "tongue", "hand_on_shaft", "gaze_up", "genital_visibility", "pose",
],
"penetration": _BASE_AXES + [
"sexual_act", "body_orientation", "limb_arrangement", "insertion_depth",
"insertion_angle", "penetration", "contact_points", "genital_visibility", "pose",
],
"handjob": _BASE_AXES + [
"body_orientation", "hand_on_shaft", "grip_style", "stroke_position",
"mouth_genital_contact", "gaze_up", "genital_visibility",
],
"solo": _BASE_AXES + [
"self_touch_location", "toy_use", "insertion_depth", "limb_arrangement",
"genital_visibility", "pose",
],
}
DEFAULT_AXES = ", ".join(PROFILES["general"])
# Cache loaded (model, processor) keyed by (path, precision) so the loop does not
# reload weights every iteration.
_MODEL_CACHE: dict[tuple[str, str], tuple] = {}
def _looks_like_repo_id(s: str) -> bool:
"""'org/name' HF repo id, not an absolute/local filesystem path."""
return ("/" in s) and (" " not in s) and (not os.path.isabs(s)) and (not s.startswith("."))
def _download_target_dir(repo_id: str) -> str:
"""Where to put downloaded weights — prefer ComfyUI's models/prompt_generator/."""
name = repo_id.split("/")[-1]
try:
import folder_paths # available when running inside ComfyUI
base = os.path.join(folder_paths.models_dir, "prompt_generator")
except Exception:
base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
return os.path.join(base, name)
def _resolve_model_source(model_path: str, auto_download: bool) -> str:
    """Turn *model_path* into a local checkpoint directory.

    Accepted forms, checked in order:
      * a short alias from RECOMMENDED_MODELS (e.g. "30b-a3b", "8b", "4b"),
        expanded to its full Hub repo id;
      * an existing local directory (a leading ``~`` is expanded);
      * a Hugging Face repo id ``org/name``. The download target is reused if
        it already holds a config.json; otherwise it is downloaded when
        *auto_download* is on.

    Surrounding whitespace (e.g. from copy/paste) is ignored.

    Raises:
        FileNotFoundError: if the repo is not downloaded and auto_download is
            off, or if model_path is neither an existing directory nor a repo id.
    """
    model_path = model_path.strip()
    # Short alias -> full repo id (e.g. "30b-a3b", "8b", "4b").
    if model_path in RECOMMENDED_MODELS:
        model_path = RECOMMENDED_MODELS[model_path]
    local_dir = os.path.expanduser(model_path)
    if os.path.isdir(local_dir):
        return local_dir
    if _looks_like_repo_id(model_path):
        target = _download_target_dir(model_path)
        # Already downloaded? (a config.json is enough to trust the local copy)
        if os.path.isfile(os.path.join(target, "config.json")):
            return target
        if not auto_download:
            raise FileNotFoundError(
                f"[QwenVLImageJudge] '{model_path}' is not downloaded and auto_download is off. "
                f"Enable auto_download or pre-fetch it to {target}.")
        from huggingface_hub import snapshot_download
        print(f"[QwenVLImageJudge] downloading {model_path} -> {target} (first run only, may be large)...")
        local = snapshot_download(
            repo_id=model_path,
            local_dir=target,
            # weights + processor/tokenizer/config/template; skip duplicate GGUF/onnx blobs.
            # ("*.txt" already covers merges.txt.)
            allow_patterns=["*.json", "*.jinja", "*.safetensors", "*.txt", "*.model", "*.py"],
        )
        print(f"[QwenVLImageJudge] download complete: {local}")
        return local
    # A local path that simply doesn't exist.
    raise FileNotFoundError(
        f"[QwenVLImageJudge] model_path not found: {model_path}. "
        f"Use a local checkpoint dir, a HF repo id (org/name), or an alias "
        f"({', '.join(RECOMMENDED_MODELS)}).")
def _tensor_to_pil(image: "torch.Tensor") -> Image.Image:
    """Convert a ComfyUI IMAGE to an RGB ``PIL.Image`` of its first frame.

    Accepts a torch tensor or array-like of floats in 0..1 shaped (B,H,W,C),
    (H,W,C) or (H,W). Grayscale (H,W) and single-channel (...,1) inputs are
    replicated to RGB. An alpha channel is dropped. NaN/inf are mapped into
    range before quantising to uint8.

    Raises:
        ValueError: if *image* is None.
    """
    if image is None:
        raise ValueError("Judge node received an empty image input.")
    arr = image
    if hasattr(arr, "detach"):
        # .float() first: Tensor.numpy() rejects bfloat16 tensors.
        arr = arr.detach().cpu().float().numpy()
    arr = np.asarray(arr)
    if arr.ndim == 4:  # batch -> take first frame
        arr = arr[0]
    # NaN would make the uint8 cast undefined; map non-finite values into range.
    arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=0.0)
    arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
    if arr.ndim == 2:  # grayscale (H, W)
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.shape[-1] == 1:  # single channel (H, W, 1)
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.shape[-1] == 4:  # drop alpha
        arr = arr[..., :3]
    # Let Pillow infer the mode from the uint8 array; the `mode` argument of
    # Image.fromarray is deprecated in recent Pillow. convert() guarantees RGB
    # for any other channel layout Pillow accepts.
    pil = Image.fromarray(np.ascontiguousarray(arr))
    return pil if pil.mode == "RGB" else pil.convert("RGB")
def _resolve_vl_classes(model_path: str):
    """Return the transformers model classes to try for *model_path*, most likely first.

    Qwen3-VL (4B/8B/30B) registers with AutoModelForImageTextToText, while the
    natively multimodal Qwen3.5/3.6 register with AutoModelForMultimodalLM.
    The two registries are separate, so the auto class that matches the path
    name goes first and the other follows. The concrete Qwen3-VL class (MoE
    or dense, also guessed from the name) is appended last for older
    transformers releases. Names the installed transformers lacks are skipped.
    """
    import transformers

    lowered = model_path.lower()
    newer_family = any(tag in lowered for tag in ("3.5", "3.6", "qwen3_5", "qwen3_6", "qwen3.5", "qwen3.6"))
    auto_names = ("AutoModelForImageTextToText", "AutoModelForMultimodalLM")
    if newer_family:
        auto_names = auto_names[::-1]
    looks_moe = any(tag in lowered for tag in ("a3b", "moe", "30b", "235b"))
    concrete = "Qwen3VLMoeForConditionalGeneration" if looks_moe else "Qwen3VLForConditionalGeneration"

    resolved = []
    for attr in (*auto_names, concrete):
        cls = getattr(transformers, attr, None)
        if cls:
            resolved.append(cls)
    return resolved
def _load_model(model_path: str, precision: str):
    """Load (and cache) the VLM and its processor for *model_path* at *precision*.

    precision values:
      * "nf4" -> 4-bit bitsandbytes quantization;
      * "fp8" -> dtype taken from the (pre-quantized) checkpoint;
      * "bf16" -> bfloat16;
      * anything else -> float16.

    Results are cached per (model_path, precision) in _MODEL_CACHE.

    Raises:
        RuntimeError: if none of the candidate classes can load the checkpoint.
            The message lists every candidate's error, not only the last one.
            The first (most likely) class often fails with the meaningful cause
            (e.g. out of memory, missing bitsandbytes), while the fallback
            classes fail with a generic "unrecognized configuration" error.
    """
    key = (model_path, precision)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    # Imported lazily so the node can be registered even if transformers is old.
    from transformers import AutoProcessor
    candidates = _resolve_vl_classes(model_path)
    load_kwargs = dict(device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
    if precision == "nf4":
        # 4-bit (bitsandbytes) — lets the 30B-A3B abliterated MoE fit in ~18 GB on 32 GB.
        from transformers import BitsAndBytesConfig
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif precision == "fp8":
        # Pre-quantized FP8 weights: let the checkpoint dictate dtype.
        pass
    else:
        load_kwargs["dtype"] = torch.bfloat16 if precision == "bf16" else torch.float16
    model, last_err, failures = None, None, []
    for cls in candidates:
        try:
            model = cls.from_pretrained(model_path, **load_kwargs)
            break
        except Exception as e:  # arch not in this class's registry (or a real load error)
            # Keep only the text of earlier failures: holding every exception
            # object would also keep its traceback (and whatever its frames
            # reference) alive until the loop ends.
            failures.append(f"{cls.__name__}: {type(e).__name__}: {e}")
            last_err = e
            model = None
    if model is None:
        detail = " | ".join(failures) if failures else "no candidate model classes found in transformers"
        raise RuntimeError(
            f"[QwenVLImageJudge] could not load {model_path} with any of "
            f"{[c.__name__ for c in candidates]}. Newer Qwen3.5/3.6 need a recent "
            f"transformers (AutoModelForMultimodalLM). Errors: {detail}") from last_err
    model.eval()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    _ensure_chat_template(processor, model_path)
    _MODEL_CACHE[key] = (model, processor)
    return model, processor
def _ensure_chat_template(processor, model_path: str):
"""Some ComfyUI-converted checkpoints ship the template as chat_template.jinja
(or only on the tokenizer), which AutoProcessor doesn't always pick up. Backfill
processor.chat_template from those sources so apply_chat_template works."""
if getattr(processor, "chat_template", None):
return
for fn in ("chat_template.jinja", "chat_template.json"):
fp = os.path.join(model_path, fn)
if os.path.isfile(fp):
try:
with open(fp, "r", encoding="utf-8") as f:
raw = f.read()
processor.chat_template = json.loads(raw).get("chat_template") if fn.endswith(".json") else raw
if processor.chat_template:
return
except (OSError, ValueError):
pass
tok = getattr(processor, "tokenizer", None)
if tok is not None and getattr(tok, "chat_template", None):
processor.chat_template = tok.chat_template
def _axis_definition_block(axes: list[str]) -> str:
    """Render one ' - axis: definition' prompt line per axis, newline-joined.

    Axes without an entry in AXIS_DEFS are still listed, with the placeholder
    definition 'as named'.
    """
    lines = []
    for axis in axes:
        meaning = AXIS_DEFS.get(axis, "as named")
        lines.append(f" - {axis}: {meaning}")
    return "\n".join(lines)
def _build_system_prompt(axes: list[str], reference_description: str = "") -> str:
    """Build the system prompt for a compare pass.

    Two modes:
      * reference_description not blank -> anchored mode: the reference is the
        given canonical text and only the generated image is shown;
      * otherwise -> two-image mode: IMAGE 1 is the reference and IMAGE 2 the
        generated candidate.

    Both modes include the axis definitions, the verdict rules, and a JSON
    skeleton with one {"verdict", "ref", "gen"} object per axis for the model
    to fill in.
    """
    # Entries are comma-SEPARATED, not comma-terminated. The skeleton is shown to
    # the model as "exactly" the format to emit, so a trailing comma after the
    # last axis yielded invalid JSON that models copy verbatim.
    axis_lines = ",\n".join(
        f' "{a}": {{"verdict": "match|partial|mismatch", "ref": "<ref value>", "gen": "<generated image>"}}'
        for a in axes)
    verdict_rule = (
        " - verdict: 'match' if ref and gen are the same; 'mismatch' if they are "
        "opposite or clearly different (e.g. 'on top' vs 'on bottom', 'doggy' vs "
        "'cowgirl', 'short' vs 'long', 'eyes closed' vs 'at camera'); 'partial' ONLY "
        "for a genuine middle ground (same category, minor difference). Do NOT default "
        "to 'partial' — if the values are identical use 'match', if clearly different "
        "use 'mismatch'.\n")
    tail = (
        "Output ONLY the JSON object — no reasoning, no step-by-step analysis, no "
        "markdown, no commentary. Do NOT think out loud. Your entire reply must start "
        "with '{' and end with '}', exactly:\n"
        "{\n"
        ' "axes": {\n'
        f"{axis_lines}\n"
        " }\n"
        "}\n")
    if reference_description.strip():
        # Anchored mode: the reference is a fixed canonical description (text), only the
        # GENERATED image is shown. Keeps the ref side consistent across iterations.
        return (
            "You are a meticulous visual-similarity judge for an image-generation "
            "calibration loop. You are given an AUTHORITATIVE REFERENCE description "
            "(text — the target) and ONE GENERATED image. For every axis report:\n"
            " - ref: the reference value taken FROM THE DESCRIPTION BELOW (quote it; do not invent)\n"
            " - gen: concretely what the GENERATED image shows for this axis\n"
            + verdict_rule +
            "Describe ONLY what you observe in the generated image; do NOT suggest fixes.\n\n"
            "=== AUTHORITATIVE REFERENCE (the target) ===\n"
            f"{reference_description.strip()}\n"
            "=== end reference ===\n\n"
            "Axes and exactly what each one means:\n"
            f"{_axis_definition_block(axes)}\n\n"
            + tail +
            "If the reference does not address an axis, verdict 'match' and ref/gen 'n/a'."
        )
    # Two-image mode: compare the reference image directly against the generated image.
    return (
        "You are a meticulous visual-similarity judge for an image-generation "
        "calibration loop. You are shown two images: IMAGE 1 is the REFERENCE "
        "(the target) and IMAGE 2 is the GENERATED candidate.\n\n"
        "For every axis report THREE things:\n"
        " - ref: concretely what IMAGE 1 (reference) shows for this axis\n"
        " - gen: concretely what IMAGE 2 (generated) shows for this axis\n"
        + verdict_rule +
        "Use specific concrete values (e.g. ref 'doggy style', gen 'cowgirl'), not "
        "vague notes. Describe ONLY what you observe — do NOT suggest fixes.\n\n"
        "Axes and exactly what each one means:\n"
        f"{_axis_definition_block(axes)}\n\n"
        + tail +
        "If an axis does not apply to either image, verdict 'match' and ref/gen 'n/a'."
    )
def _format_chatml_qwenvl(messages):
"""Manual Qwen-VL ChatML prompt, used when the processor has no chat template
(e.g. checkpoints converted for ComfyUI that drop chat_template.json). Mirrors
apply_chat_template: each image -> <|vision_start|><|image_pad|><|vision_end|>,
which the processor then expands to the right number of image tokens."""
parts = []
for msg in messages:
parts.append(f"<|im_start|>{msg['role']}\n")
content = msg["content"]
if isinstance(content, str):
parts.append(content)
else:
for item in content:
if item.get("type") == "image":
parts.append("<|vision_start|><|image_pad|><|vision_end|>")
elif item.get("type") == "text":
parts.append(item.get("text", ""))
parts.append("<|im_end|>\n")
parts.append("<|im_start|>assistant\n")
return "".join(parts)
def _apply_template(processor, messages):
"""apply_chat_template with thinking disabled (Qwen3.5/3.6 are reasoning models that
otherwise 'think out loud' in prose and never reach the JSON). Falls back gracefully."""
try:
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
except TypeError:
pass # template doesn't accept enable_thinking
except (ValueError, AttributeError):
return _format_chatml_qwenvl(messages)
try:
return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except (ValueError, AttributeError):
return _format_chatml_qwenvl(messages)
def _strip_think(text: str) -> str:
    """Remove reasoning-model <think> output from *text*, keeping the answer.

    Closed ``<think>...</think>`` blocks are deleted. If a ``</think>`` remains
    without its opener (the opener was presumably part of the prompt), the reply
    is whatever follows the last ``</think>``. If nothing follows it, the
    reasoning is kept (minus the stray tags) so the prose fallback parsers still
    have something to read.
    """
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    head, sep, answer = text.rpartition("</think>")
    if sep:
        text = answer if answer.strip() else head.replace("</think>", "")
    return text.strip()


def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
    """Apply the chat template, run generation and return the decoded reply.

    Only the newly generated tokens are decoded (the prompt is sliced off) and
    special tokens are skipped. Reasoning-model <think> output is removed via
    _strip_think. temperature > 0 enables sampling; otherwise decoding is greedy.
    """
    text = _apply_template(processor, messages)
    inputs = processor(text=[text], images=images, return_tensors="pt")
    inputs = inputs.to(model.device)
    gen_kwargs = dict(max_new_tokens=max_new_tokens)
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature))
    else:
        gen_kwargs.update(do_sample=False)
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)
    trimmed = out[:, inputs.input_ids.shape[1]:]
    decoded = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
    return _strip_think(decoded)
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
    """Two-image compare pass: returns the model's raw judgement text.

    ref_pil is sent first as IMAGE 1 and gen_pil second as IMAGE 2, matching the
    two-image system prompt.
    """
    user_turn = [
        {"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
        {"type": "image", "image": ref_pil},
        {"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
        {"type": "image", "image": gen_pil},
        {"type": "text", "text": "Now return the strict JSON judgement."},
    ]
    conversation = [
        {"role": "system", "content": _build_system_prompt(axes)},
        {"role": "user", "content": user_turn},
    ]
    return _generate_from_messages(
        model, processor, conversation, [ref_pil, gen_pil], max_new_tokens, temperature)
def _run_anchored(model, processor, gen_pil, axes, max_new_tokens, temperature, reference_description):
    """Anchored compare pass: canonical reference text plus one generated image.

    The reference exists only as *reference_description* inside the system
    prompt, so the reference side stays identical across iterations. gen_pil is
    the only image sent.
    """
    system_text = _build_system_prompt(axes, reference_description)
    user_turn = [
        {"type": "text", "text": "GENERATED candidate image:"},
        {"type": "image", "image": gen_pil},
        {"type": "text", "text": "Compare it to the reference description and return the strict JSON."},
    ]
    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_turn},
    ]
    return _generate_from_messages(
        model, processor, conversation, [gen_pil], max_new_tokens, temperature)
def _build_describe_prompt(axes: list[str]) -> str:
    """Build the system prompt for the describe pass (reference image only).

    Asks for one self-consistent, prompt-ready paragraph plus a concrete value
    per axis, returned as {"description": ..., "axes": {axis: value}}.
    """
    # Entries are comma-SEPARATED, not comma-terminated. A trailing comma after
    # the last axis would make the skeleton (shown as the exact format to emit)
    # invalid JSON, and models copy it verbatim.
    axis_lines = ",\n".join(f' "{a}": "<concrete value or n/a>"' for a in axes)
    return (
        "You are writing the ONE canonical description of a REFERENCE image that an "
        "image generator must reproduce. This description is the single source of truth "
        "for the whole calibration loop, so it must be coherent and internally "
        "consistent: the per-axis values must agree with each other and with the "
        "paragraph (e.g. if the woman is on top, every axis that mentions arrangement "
        "must say so). Describe ONLY what you observe, concretely, in prompt-ready "
        "phrasing (the words a text-to-image prompt would use).\n\n"
        "Axes and exactly what each one means:\n"
        f"{_axis_definition_block(axes)}\n\n"
        "Output ONLY the JSON object — no reasoning, no analysis, no markdown, no "
        "commentary. Do NOT think out loud. Start with '{' and end with '}', exactly:\n"
        "{\n"
        ' "description": "<one detailed, self-consistent paragraph describing the whole scene as a generation prompt>",\n'
        ' "axes": {\n'
        f"{axis_lines}\n"
        " }\n"
        "}\n"
        "Each axis value is a concrete description of that aspect (or \"n/a\" if absent) "
        "and must not contradict the paragraph. The description is directly usable as a prompt."
    )
def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature):
    """Free-form VLM pass: caller-supplied prompts over *images*, returns raw text.

    Every image is placed before the user text, and an empty user_prompt
    becomes "Describe this image.". A system turn is only sent when
    system_prompt has non-whitespace content.
    """
    user_content = [{"type": "image", "image": picture} for picture in images]
    user_content.append({"type": "text", "text": user_prompt or "Describe this image."})
    conversation = (
        [{"role": "system", "content": system_prompt}] if system_prompt.strip() else []
    )
    conversation.append({"role": "user", "content": user_content})
    return _generate_from_messages(
        model, processor, conversation, images, max_new_tokens, temperature)
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
    """Describe pass: reference image only, returns the model's raw reply.

    The reply is expected to be the JSON requested by the describe prompt:
    {"description": ..., "axes": {axis: value}}.
    """
    user_turn = [
        {"type": "text", "text": "Describe this reference image:"},
        {"type": "image", "image": ref_pil},
        {"type": "text", "text": "Return the strict JSON description."},
    ]
    conversation = [
        {"role": "system", "content": _build_describe_prompt(axes)},
        {"role": "user", "content": user_turn},
    ]
    return _generate_from_messages(
        model, processor, conversation, [ref_pil], max_new_tokens, temperature)
def _parse_json(raw: str) -> dict | None:
"""Best-effort: pull the first balanced JSON object out of the model output."""
# Strip code fences if present.
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
candidate = fenced.group(1) if fenced else None
if candidate is None:
start = raw.find("{")
if start == -1:
return None
depth = 0
for i in range(start, len(raw)):
if raw[i] == "{":
depth += 1
elif raw[i] == "}":
depth -= 1
if depth == 0:
candidate = raw[start:i + 1]
break
if candidate is None:
return None
try:
return json.loads(candidate)
except json.JSONDecodeError:
return None
def _parse_markdown_verdicts(raw: str, axes: list[str]) -> dict:
"""Fallback for reasoning models that emit prose instead of JSON. Reasoning models
reliably write a block per axis like:
**hair:**
- Ref: short, curly, brown
- Gen: long, straight, blonde
- Verdict: mismatch
Extract {verdict, ref, gen} per axis from that. Returns {} if nothing parseable."""
out = {}
for ax in axes:
m = re.search(rf"(?im)^[\s\d.>*\-]*\**\s*{re.escape(ax)}\s*\**\s*:?\s*$"
rf"|\**\s*{re.escape(ax)}\s*\**\s*:", raw)
if not m:
continue
seg = raw[m.end(): m.end() + 500]
vd = re.search(r"(?i)verdict[\s*:>-]*\**\s*(match|partial|mismatch)", seg)
if not vd:
continue
ref = re.search(r"(?im)^\W*ref[a-z]*\W*[:\-]\s*\**\s*(.+?)\s*$", seg)
gen = re.search(r"(?im)^\W*gen[a-z]*\W*[:\-]\s*\**\s*(.+?)\s*$", seg)
clean = lambda s: s.group(1).strip().strip("*").strip(" .") if s else ""
out[ax] = {"verdict": vd.group(1).lower(), "ref": clean(ref), "gen": clean(gen)}
return {"axes": out} if out else {}
def _parse_axes(raw: str, axes: list[str]) -> dict:
    """Parse a compare-pass reply into {"axes": {...}}.

    A JSON object whose "axes" is a non-empty dict is returned as-is. Anything
    else (no JSON, no "axes", empty "axes") falls back to the markdown verdict
    parser.
    """
    parsed = _parse_json(raw)
    axes_block = parsed.get("axes") if parsed else None
    if isinstance(axes_block, dict) and axes_block:
        return parsed
    return _parse_markdown_verdicts(raw, axes)
_VERDICT_ORDINAL = {"match": 1.0, "partial": 0.5, "mismatch": 0.0}
def _verdict_ordinal(verdict) -> float:
return _VERDICT_ORDINAL.get(str(verdict).strip().lower(), 0.0)
def _ordinal_verdict(x: float) -> str:
return "match" if x >= 0.75 else ("partial" if x >= 0.25 else "mismatch")
def _normalize_value(s) -> str:
return re.sub(r"\s+", " ", str(s).strip().lower()).strip(" .,:;")
def _apply_identical_match(axes: dict) -> dict:
"""Deterministic correction: small VLMs over-use 'partial', mislabeling axes
where ref and gen are identical. Force 'match' when the texts are equal — this
doesn't depend on the model getting the verdict right."""
for v in axes.values():
ref = v.get("ref", "")
if ref and _normalize_value(ref) == _normalize_value(v.get("gen", "")):
v["verdict"] = "match"
return axes
def _score_from_axes(axes: dict) -> tuple[float, int]:
"""Deterministic overall score (mean verdict ordinal) + mismatch count.
Computed here, not by the model, so it's reliable and monotonic."""
if not axes:
return 0.0, 0
ordinals = [_verdict_ordinal(v.get("verdict")) for v in axes.values()]
mismatches = sum(1 for o in ordinals if o == 0.0)
return round(sum(ordinals) / len(ordinals), 4), mismatches
def _merge_swapped(a: dict, b: dict) -> dict:
"""Average two judgements (normal + order-swapped) to cut position bias."""
if not b:
return a
if not a:
return b
out = {"axes": {}}
axes = set(a.get("axes", {})) | set(b.get("axes", {}))
for ax in axes:
sa = a.get("axes", {}).get(ax, {})
sb = b.get("axes", {}).get(ax, {})
# Average the two passes' verdicts on a 0/0.5/1 scale, then re-bucket.
ord_avg = (_verdict_ordinal(sa.get("verdict")) + _verdict_ordinal(sb.get("verdict"))) / 2.0
# In pass b the images were swapped, so b.ref describes the generated image
# and b.gen the reference -> invert b when falling back.
ref = sa.get("ref") or sb.get("gen") or ""
gen = sa.get("gen") or sb.get("ref") or ""
out["axes"][ax] = {"verdict": _ordinal_verdict(ord_avg), "ref": ref, "gen": gen}
return out
def _report_base_dir(report_dir: str) -> str:
if report_dir:
return report_dir
try:
import folder_paths
return os.path.join(folder_paths.get_output_directory(), "calibrator")
except Exception:
return os.path.join(os.path.dirname(os.path.dirname(__file__)), "output", "calibrator")
def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all,
                  mismatch_count=0):
    """Persist a compare-pass analysis so the external CLI agent can read it.

    Files written under the report directory:
      * calib_<tag>.json, where tag is run_tag with characters outside
        [A-Za-z0-9._-] replaced by '_', or "latest" when run_tag is empty;
      * latest.json, a stable copy of the same payload the agent can poll;
      * calib_<tag>.md, a plain-text summary.

    JSON files go to a temporary sibling and are then moved into place with
    os.replace (atomic on POSIX), so a poller never reads a half-written file.

    Returns the per-run JSON path, or "" if the report directory could not be
    created or the per-run file could not be written. I/O failures are printed,
    not raised.
    """
    base = _report_base_dir(report_dir)
    try:
        os.makedirs(base, exist_ok=True)
    except OSError as e:
        print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
        return ""
    payload = {
        "run_tag": run_tag,
        "overall_score": round(float(overall), 4),
        "mismatch_count": mismatch_count,
        "axes": (merged or {}).get("axes", {}),
        "diff_analysis": diff_analysis,
        "raw": raw_all,
    }

    def dump_atomically(path):
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the write error itself is reported by the caller

    tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "latest"
    run_path = os.path.join(base, f"calib_{tag}.json")
    run_written = False
    for path in (run_path, os.path.join(base, "latest.json")):
        try:
            dump_atomically(path)
        except OSError as e:
            print(f"[QwenVLImageJudge] failed writing report {path}: {e}")
        else:
            run_written = run_written or path == run_path
    # A markdown sibling is handy for the agent to read as plain text.
    md_path = os.path.join(base, f"calib_{tag}.md")
    md = (f"# Calibration analysis ({tag})\n\n"
          f"**overall_score:** {payload['overall_score']}\n\n"
          f"## per-axis\n\n{diff_analysis}\n")
    try:
        with open(md_path, "w", encoding="utf-8") as f:
            f.write(md)
    except OSError as e:
        print(f"[QwenVLImageJudge] failed writing report {md_path}: {e}")
    return run_path if run_written else ""
def _write_chat_report(report_dir, run_tag, system_prompt, user_prompt, response):
    """Save a chat-mode (general VLM) exchange as calib_<tag>.json and latest.json.

    tag is run_tag with characters outside [A-Za-z0-9._-] replaced by '_'
    ("chat" when run_tag is empty). Returns the per-run path, or "" if the
    report directory cannot be created. Individual write failures are printed.
    """
    base = _report_base_dir(report_dir)
    try:
        os.makedirs(base, exist_ok=True)
    except OSError as e:
        print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
        return ""
    record = dict(mode="chat", run_tag=run_tag, system_prompt=system_prompt,
                  user_prompt=user_prompt, response=response)
    safe_tag = "chat" if not run_tag else re.sub(r"[^A-Za-z0-9._-]", "_", run_tag)
    run_path = os.path.join(base, f"calib_{safe_tag}.json")
    targets = [run_path, os.path.join(base, "latest.json")]
    for target in targets:
        try:
            with open(target, "w", encoding="utf-8") as fh:
                json.dump(record, fh, ensure_ascii=False, indent=2)
        except OSError as e:
            print(f"[QwenVLImageJudge] failed writing report {target}: {e}")
    return run_path
def _format_canonical_reference(caption: str, axes_spec: dict) -> str:
"""One canonical reference description = the paragraph + the per-axis target
values. The compare pass anchors on this so the reference side stays consistent
across iterations (no re-describing the reference each time)."""
lines = [caption.strip()] if caption else []
if axes_spec:
lines.append("")
for ax, val in axes_spec.items():
lines.append(f"- {ax}: {val}")
return "\n".join(lines).strip()
def _write_describe_report(report_dir, run_tag, caption, axes_spec, raw, canonical=""):
    """Save the describe (first) pass — the target spec — as JSON to seed from.

    The payload holds the caption, the per-axis target values (the agent's initial
    axis_state), the canonical reference text (``canonical`` if given, otherwise
    built from caption + axes) and the raw model output. It is written to
    ``calib_<tag>.json`` and ``latest.json``. ``<tag>`` is the sanitized
    ``run_tag``, or ``describe`` when it is empty. Individual write failures are
    printed, not raised.

    Returns the per-run JSON path, or "" if the report directory could not be created.
    """
    base = _report_base_dir(report_dir)
    try:
        os.makedirs(base, exist_ok=True)
    except OSError as exc:
        print(f"[QwenVLImageJudge] could not create report dir {base}: {exc}")
        return ""

    reference_text = canonical or _format_canonical_reference(caption, axes_spec)
    record = dict(
        mode="describe",
        run_tag=run_tag,
        caption=caption,
        axes=axes_spec,  # per-axis target values -> the agent's initial axis_state
        canonical_reference=reference_text,
        raw=raw,
    )

    safe_tag = "describe"
    if run_tag:
        safe_tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag)
    run_path = os.path.join(base, f"calib_{safe_tag}.json")

    targets = [run_path, os.path.join(base, "latest.json")]
    for target in targets:
        try:
            with open(target, "w", encoding="utf-8") as fh:
                json.dump(record, fh, ensure_ascii=False, indent=2)
        except OSError as exc:
            print(f"[QwenVLImageJudge] failed writing report {target}: {exc}")
    return run_path
class QwenVLImageJudge:
    """ComfyUI node: describe a reference, or score how close a generated image is to it.

    ``mode`` selects the pass:
      * ``compare``  - reference vs. generated image -> per-axis verdicts + overall score.
      * ``describe`` - reference only -> canonical target description (first pass).
      * ``chat``     - general VLM: your own system/user prompt over the image(s).

    Every mode returns the 5-tuple named in RETURN_NAMES. On success each mode also
    writes a JSON report (``report_dir`` / ``run_tag``) that the calibration agent
    reads back. Input and model-selection errors are not raised. They are printed
    and returned as a 0.0-score result whose analysis/raw outputs carry the message.
    """

    CATEGORY = "prompt_calibrator"
    FUNCTION = "judge"
    RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING")
    RETURN_NAMES = ("overall_score", "axis_scores_json", "analysis", "raw", "report_path")

    @classmethod
    def INPUT_TYPES(cls):
        """ComfyUI widget/socket schema for this node."""
        presets = list(MODEL_PRESETS.keys())
        return {
            "required": {
                "reference_image": ("IMAGE",),
                # compare = ref vs generated -> per-axis scoring. describe = reference only
                # -> target description (first pass). chat = general VLM: your own
                # system_prompt + user_prompt over the image(s) -> raw text.
                "mode": (["compare", "describe", "chat"], {"default": "compare"}),
                # Analysis profile: act-specialized axis set (distance-aware where it
                # matters). `axes` below overrides it when non-empty.
                "profile": (list(PROFILES.keys()), {"default": "general"}),
                # Curated model dropdown (label shows VRAM). model_path below overrides it.
                # Falls back to the manual entry if no presets are defined.
                "model_select": ([MANUAL_CHOICE] + presets,
                                 {"default": presets[0] if presets else MANUAL_CHOICE}),
                "model_path": ("STRING", {"default": ""}),  # manual override (local dir / HF repo / alias)
                "precision": (["bf16", "fp8", "nf4"], {"default": "bf16"}),
                "axes": ("STRING", {"default": "", "multiline": True}),
                "max_new_tokens": ("INT", {"default": 2048, "min": 64, "max": 8192}),
                "temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}),
                "swap_eval": ("BOOLEAN", {"default": True}),
                "keep_loaded": ("BOOLEAN", {"default": True}),
                "auto_download": ("BOOLEAN", {"default": True}),
                # The agent reads the analysis from these files after each queue.
                "report_dir": ("STRING", {"default": ""}),
                "run_tag": ("STRING", {"default": ""}),
                # compare: canonical reference text (from describe). When set, compare
                # anchors on it instead of re-reading the reference image each time.
                "reference_description": ("STRING", {"default": "", "multiline": True}),
                # chat mode: use the node as a general VLM with your own prompts.
                "system_prompt": ("STRING", {"default": "", "multiline": True}),
                "user_prompt": ("STRING", {"default": "Describe this image.", "multiline": True}),
            },
            # Only genuine node-to-node wires stay optional (widgets in `optional` render
            # as input sockets instead of editable fields in some ComfyUI frontends).
            "optional": {
                "generated_image": ("IMAGE",),  # required for compare, ignored for describe/chat
            },
        }

    @staticmethod
    def _fail(msg):
        """Print ``msg`` and return it as the node's zero-score error result."""
        print(msg)
        return (0.0, "{}", msg, msg, "")

    @staticmethod
    def _release(resolved_path, precision):
        """Evict the cached model and hand freed CUDA memory back to the driver.

        Callers must drop their own references to the model first. Any live
        reference keeps the weights allocated, and ``empty_cache()`` then has
        nothing to release.
        """
        import gc

        _MODEL_CACHE.pop((resolved_path, precision), None)
        gc.collect()  # collect reference cycles (if any) that could still pin the weights
        torch.cuda.empty_cache()

    def judge(self, reference_image, mode, model_path, precision, axes,
              max_new_tokens, temperature, swap_eval, profile="general",
              model_select=MANUAL_CHOICE, generated_image=None,
              keep_loaded=True, auto_download=True,
              report_dir="", run_tag="", reference_description="",
              system_prompt="", user_prompt="Describe this image."):
        """Node entry point: resolve and load the model, then dispatch on ``mode``.

        ``axes`` (comma/newline separated) overrides the profile's axis set. A
        non-empty ``model_path`` overrides ``model_select``. Cheap input checks run
        before the model is resolved (which may download) or loaded.

        When ``keep_loaded`` is False, the model is evicted from ``_MODEL_CACHE``
        after the pass. This also happens when the pass raises. Eviction runs only
        after this frame has dropped its own reference, so
        ``torch.cuda.empty_cache()`` can actually release the memory.
        """
        # `axes` overrides the profile when provided; otherwise use the profile's axis set.
        axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()]
        if not axis_list:
            axis_list = list(PROFILES.get(profile, PROFILES["general"]))

        # Resolve the model: manual model_path overrides the dropdown. `precision` is the
        # quant dropdown and applies to whichever model is chosen.
        eff_precision = precision
        manual_path = model_path.strip()
        if manual_path:
            eff_repo = manual_path
        else:
            preset = MODEL_PRESETS.get(model_select)
            if not preset:
                return self._fail(
                    "[QwenVLImageJudge] pick a model in model_select, or fill model_path.")
            # repo_by_precision routes a quant to a different checkpoint (e.g. local fp8 dir).
            eff_repo = preset.get("repo_by_precision", {}).get(precision, preset["repo"])
        if eff_repo.lower().endswith(".gguf"):
            return self._fail(
                f"[QwenVLImageJudge] '{eff_repo}' is GGUF — this node is transformers "
                f"(safetensors) only. Run GGUF models in a dedicated GGUF node "
                f"(1038lab/ComfyUI-QwenVL or KLL535 Simple-Qwen3-VL-gguf).")

        # Any mode other than chat/describe runs the compare pass, which needs both
        # images. Check before resolving (may download) or loading the weights, so
        # a missing wire costs no load and cannot leave an unwanted model cached.
        if mode not in ("chat", "describe") and generated_image is None:
            return self._fail(
                "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe).")

        try:
            resolved_path = _resolve_model_source(eff_repo, auto_download)
        except Exception as e:  # missing model / download failure -> surface as score 0
            return self._fail(str(e))

        ref_pil = _tensor_to_pil(reference_image)
        gen_pil = None
        if mode != "describe" and generated_image is not None:
            gen_pil = _tensor_to_pil(generated_image)

        model, processor = _load_model(resolved_path, eff_precision)
        try:
            # The helpers get keep_loaded=True. Eviction happens in `finally`, after
            # this frame has dropped its own reference; evicting inside a helper
            # would run while this frame still keeps the weights alive.
            if mode == "chat":
                return self._chat(model, processor, ref_pil, gen_pil, system_prompt,
                                  user_prompt, max_new_tokens, temperature,
                                  resolved_path, eff_precision, True, report_dir, run_tag)
            if mode == "describe":
                return self._describe(model, processor, ref_pil, axis_list, max_new_tokens,
                                      temperature, resolved_path, eff_precision, True,
                                      report_dir, run_tag)
            return self._compare(model, processor, ref_pil, gen_pil, axis_list,
                                 max_new_tokens, temperature, swap_eval,
                                 reference_description, report_dir, run_tag)
        finally:
            del model, processor
            if not keep_loaded:
                self._release(resolved_path, eff_precision)

    def _compare(self, model, processor, ref_pil, gen_pil, axis_list, max_new_tokens,
                 temperature, swap_eval, reference_description, report_dir, run_tag):
        """Score ``gen_pil`` against the reference; returns the node's 5-tuple.

        With a non-empty ``reference_description``, the pass is anchored on that
        fixed canonical text and shows one image, with no swap. Otherwise both
        images are shown. If ``swap_eval`` is set, the pass is repeated with the
        roles swapped and the two results are merged.
        """
        if reference_description.strip():
            # Anchored: fixed canonical reference text + one generated image. No swap
            # (single image), and the reference side stays identical across iterations.
            raw_all = _run_anchored(model, processor, gen_pil, axis_list, max_new_tokens,
                                    temperature, reference_description)
            merged = _parse_axes(raw_all, axis_list)
        else:
            raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list,
                             max_new_tokens, temperature)
            parsed1 = _parse_axes(raw1, axis_list)
            raw_all, merged = raw1, parsed1
            if swap_eval:
                # Swap which image is called REFERENCE to average out position bias.
                raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list,
                                 max_new_tokens, temperature)
                parsed2 = _parse_axes(raw2, axis_list)
                merged = _merge_swapped(parsed1, parsed2)
                raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2

        axes_map = merged.get("axes", {}) if merged else {}
        # Correct the 4B's bias toward 'partial' on identical values, then score.
        axes_map = _apply_identical_match(axes_map)
        overall, mismatch_count = _score_from_axes(axes_map)
        axis_scores = json.dumps(axes_map, ensure_ascii=False, indent=2) if axes_map else "{}"

        # Summary worst-first: mismatch, then partial, then match.
        items = sorted(axes_map.items(), key=lambda kv: _verdict_ordinal(kv[1].get("verdict")))
        diff_lines = [
            f"- {ax}: {str(info.get('verdict', '?')).upper():8} "
            f"ref:[{info.get('ref', '')}] gen:[{info.get('gen', '')}]"
            for ax, info in items
        ]
        if diff_lines:
            header = f"overall {overall:.2f} | {mismatch_count} mismatch(es) of {len(axes_map)} axes"
            diff_analysis = header + "\n" + "\n".join(diff_lines)
        else:
            diff_analysis = "(no parseable judgement)"

        report_path = _write_report(
            report_dir, run_tag, overall, merged, diff_analysis, raw_all, mismatch_count)
        return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path)

    def _chat(self, model, processor, ref_pil, gen_pil, system_prompt, user_prompt,
              max_new_tokens, temperature, resolved_path, precision, keep_loaded,
              report_dir, run_tag):
        """General-VLM mode: not a judge — just runs your prompt over the image(s).

        The reference image goes first, followed by ``gen_pil`` when given. The
        stripped reply is returned as both ``analysis`` and ``raw``.
        ``overall_score`` is a neutral 1.0 placeholder. If ``keep_loaded`` is
        False, the cached model is evicted here; ``judge`` passes True and evicts
        it itself.
        """
        images = [ref_pil] if gen_pil is None else [ref_pil, gen_pil]
        text = _run_chat(model, processor, images, system_prompt, user_prompt,
                         max_new_tokens, temperature).strip()
        if not keep_loaded:
            del model
            self._release(resolved_path, precision)
        report_path = _write_chat_report(report_dir, run_tag, system_prompt, user_prompt, text)
        return (1.0, "{}", text, text, report_path)

    def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens,
                  temperature, resolved_path, precision, keep_loaded, report_dir, run_tag):
        """First pass: describe the reference image the generator must reproduce.

        Outputs the target spec (per-axis values) and a prompt-ready caption. The
        model is asked for JSON. If it emitted prose instead, or JSON without a
        usable string description/caption, the raw text itself becomes the
        caption. ``overall_score`` does not apply here and is returned as a
        neutral 1.0 placeholder. If ``keep_loaded`` is False, the cached model is
        evicted here; ``judge`` passes True and evicts it itself.
        """
        raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature)
        parsed = _parse_json(raw)
        if not isinstance(parsed, dict):
            # e.g. nothing parsed, or a bare JSON array: treat the output as prose.
            parsed = {}
        if not keep_loaded:
            del model
            self._release(resolved_path, precision)

        # Prefer a non-blank *string* description/caption field; else use the raw text.
        caption = next(
            (value for value in (parsed.get("description"), parsed.get("caption"))
             if isinstance(value, str) and value.strip()),
            raw or "",
        ).strip()
        axes_spec = parsed.get("axes") if isinstance(parsed.get("axes"), dict) else {}
        axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2)
        # The canonical reference text the compare pass will anchor on: paragraph + axes.
        canonical = _format_canonical_reference(caption, axes_spec)
        analysis = canonical if caption else "(no parseable description)"
        report_path = _write_describe_report(report_dir, run_tag, caption, axes_spec, raw, canonical)
        return (1.0, axis_scores, analysis, raw, report_path)
# Node registration tables (ComfyUI's NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS
# convention): node id -> implementing class, and node id -> title shown in the UI.
NODE_CLASS_MAPPINGS = {
    "QwenVLImageJudge": QwenVLImageJudge,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "QwenVLImageJudge": "Qwen3-VL Image Judge (Calibrator)",
}