c7ef756a71
New mode on QwenVLImageJudge: 'describe' looks at the reference alone and returns a prompt-ready caption + per-axis target spec to seed the very first prompt (the generator has nothing to reproduce yet). 'compare' is the existing ref-vs-gen scoring. generated_image is now optional (required only for compare); shared generation refactored into _generate_from_messages; third output renamed diff_analysis -> analysis (mode-agnostic). agent_bridge gains --mode (describe needs no receptor/prompt); added workflow_describe_api.json. Docs updated with the first-pass bootstrap step. Fixed error-return arity to 5-tuple. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
602 lines
26 KiB
Python
602 lines
26 KiB
Python
"""
|
|
Qwen3-VL Image-Similarity Judge node for ComfyUI.
|
|
|
|
The "vllm node" of the Prompt Calibrator. It takes a REFERENCE image and a
|
|
GENERATED image and asks a local Qwen3-VL model how close the generated image is
|
|
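to the reference, returning a machine-readable score + per-axis difference
analysis that the calibration controller can act on. In "describe" mode it
looks at the reference alone (no generated image needed) and returns a
prompt-ready caption plus a per-axis target spec that seeds the first prompt.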
|
|
|
|
Reuses the standard transformers Qwen3-VL plumbing (the same approach used by
|
|
ComfyUI-QwenVL-MultiImage / ComfyUI_Qwen3-VL-Instruct), but forces strict JSON
|
|
output so the result is usable by an automated loop rather than a human reader.
|
|
|
|
Default model is the locally converted huihui-ai Qwen3-VL-4B-Instruct
|
|
*abliterated* (uncensored) weights, which do not refuse to analyze adult imagery.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
# Defaults point at the checkpoint already converted on this machine, so the
# node works out of the box here; set model_path explicitly anywhere else.
DEFAULT_MODEL_PATH = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16"
DEFAULT_MODEL_PATH_FP8 = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_fp8"

# Short aliases accepted by model_path -> Hugging Face repo ids (abliterated
# Qwen3-VL family, sized for an RTX 5090 32 GB). Fetch manually with
#   hf download <repo> --local-dir <dir>
# and point model_path at the directory, or pass the alias and let
# auto_download pull the weights on first use.
#   "30b-a3b": best judge that fits 32 GB; MoE with ~3B active params (fast).
#              Use precision="nf4" (~18 GB) on 32 GB, or GGUF quants via a
#              GGUF node. Loads as Qwen3VLMoeForConditionalGeneration.
#   "8b":      easy middle ground, bf16 ~17 GB, no quantization hassle.
#   "4b":      lightweight; already present locally.
RECOMMENDED_MODELS = dict(
    [
        ("30b-a3b", "huihui-ai/Huihui-Qwen3-VL-30B-A3B-Instruct-abliterated"),
        ("8b", "huihui-ai/Huihui-Qwen3-VL-8B-Instruct-abliterated"),
        ("4b", "huihui-ai/Huihui-Qwen3-VL-4B-Instruct-abliterated"),
    ]
)
|
|
|
|
# Axes the judge scores (compare mode) or describes (describe mode). Granular
# by default so comparisons of explicit/adult imagery stay discriminative
# where coarse axes would blur the differences that matter. Editable on the
# node: trim or extend per use case. Stored as one comma-separated string
# because that is the format of the node's multiline "axes" widget.
DEFAULT_AXES = ", ".join((
    "subject_count",         # number of people
    "gender_mix",            # gender composition (e.g. 1F, 2F1M)
    "body_type",             # physique / build / proportions per subject
    "distinctive_features",  # tattoos / piercings / marks (identity anchors)
    "age_appearance",        # apparent age
    "ethnicity_skin",        # ethnicity / skin tone
    "hair",                  # length, color, style
    "clothing_state",        # degree of undress + specific garments
    "sexual_act",            # the act / activity being performed
    "position",              # arrangement of bodies
    "penetration",           # type & visibility
    "explicitness",          # how graphic / visibility level
    "body_contact",          # who contacts whom; interaction between subjects
    "pose",                  # non-act body positioning
    "facial_expression",     # face / affect
    "gaze",                  # eye contact / look direction
    "framing",               # shot type / crop (close-up <-> full body)
    "camera_angle",          # POV / angle / perspective
    "scene",                 # location / setting / background
    "lighting_color",        # palette, lighting, color grade
    "art_style",             # photoreal vs anime/illustrated, render style
))
|
|
|
|
# Cache loaded (model, processor) keyed by (path, precision) so the loop does not
|
|
# reload weights every iteration.
|
|
_MODEL_CACHE: dict[tuple[str, str], tuple] = {}
|
|
|
|
|
|
def _looks_like_repo_id(s: str) -> bool:
|
|
"""'org/name' HF repo id, not an absolute/local filesystem path."""
|
|
return ("/" in s) and (" " not in s) and (not os.path.isabs(s)) and (not s.startswith("."))
|
|
|
|
|
|
def _download_target_dir(repo_id: str) -> str:
|
|
"""Where to put downloaded weights — prefer ComfyUI's models/prompt_generator/."""
|
|
name = repo_id.split("/")[-1]
|
|
try:
|
|
import folder_paths # available when running inside ComfyUI
|
|
base = os.path.join(folder_paths.models_dir, "prompt_generator")
|
|
except Exception:
|
|
base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
|
|
return os.path.join(base, name)
|
|
|
|
|
|
def _local_copy_complete(target: str) -> bool:
    """Return True when *target* holds a usable checkpoint, not a partial download.

    A config.json alone is not enough: an interrupted snapshot_download leaves
    the small JSON files behind without the weights, and trusting such a copy
    makes every later run fail in from_pretrained without ever re-triggering
    the download. With a safetensors index every listed shard must exist;
    otherwise at least one weight file must be present.
    """
    if not os.path.isfile(os.path.join(target, "config.json")):
        return False
    index_path = os.path.join(target, "model.safetensors.index.json")
    if os.path.isfile(index_path):
        try:
            with open(index_path, "r", encoding="utf-8") as f:
                shards = set(json.load(f).get("weight_map", {}).values())
        except (OSError, ValueError, AttributeError):
            return False
        return bool(shards) and all(
            os.path.isfile(os.path.join(target, shard)) for shard in shards)
    try:
        return any(name.endswith((".safetensors", ".bin")) for name in os.listdir(target))
    except OSError:
        return False


def _resolve_model_source(model_path: str, auto_download: bool) -> str:
    """Turn model_path (local dir | short alias | HF repo id) into a local dir.

    Resolution order:
      1. a key of RECOMMENDED_MODELS (e.g. "30b-a3b", "8b", "4b") is replaced
         by its full repo id;
      2. an existing directory is returned unchanged;
      3. an ``org/name`` repo id maps to _download_target_dir(repo_id); a
         complete local copy there is reused, otherwise the repo is fetched
         from the Hub (first use only, and only when auto_download is on).

    Raises:
        FileNotFoundError: the path does not exist, or the repo is not (fully)
            downloaded and auto_download is off.
    """
    # Text widgets frequently carry a stray space/newline from pasting, which
    # would otherwise make a valid repo id fail _looks_like_repo_id.
    model_path = (model_path or "").strip()
    model_path = RECOMMENDED_MODELS.get(model_path, model_path)

    if os.path.isdir(model_path):
        return model_path

    if _looks_like_repo_id(model_path):
        target = _download_target_dir(model_path)
        if _local_copy_complete(target):
            return target
        if not auto_download:
            raise FileNotFoundError(
                f"[QwenVLImageJudge] '{model_path}' is not downloaded and auto_download is off. "
                f"Enable auto_download or pre-fetch it to {target}.")
        from huggingface_hub import snapshot_download
        print(f"[QwenVLImageJudge] downloading {model_path} -> {target} (first run only, may be large)...")
        # snapshot_download resumes into an existing local_dir, so re-running
        # after an interrupted download only fetches what is missing.
        # NOTE(review): "*.py" pulls repo-shipped code that trust_remote_code
        # will execute at load time; only use repo ids you trust.
        local = snapshot_download(
            repo_id=model_path,
            local_dir=target,
            # weights + processor/tokenizer/config/template; skip duplicate GGUF/onnx blobs.
            allow_patterns=["*.json", "*.jinja", "*.safetensors", "*.txt", "*.model", "merges.txt", "*.py"],
        )
        print(f"[QwenVLImageJudge] download complete: {local}")
        return local

    # A local path that simply doesn't exist.
    raise FileNotFoundError(
        f"[QwenVLImageJudge] model_path not found: {model_path}. "
        f"Use a local checkpoint dir, a HF repo id (org/name), or an alias "
        f"({', '.join(RECOMMENDED_MODELS)}).")
|
|
|
|
|
|
def _tensor_to_pil(image: "torch.Tensor") -> Image.Image:
    """Convert a ComfyUI IMAGE to an RGB PIL.Image, keeping only the first frame.

    ComfyUI IMAGEs are float tensors shaped (B, H, W, C) with values in 0..1.
    Numpy arrays, unbatched (H, W, C) / (H, W) input and 1-, 3- or 4-channel
    data are accepted as well: grey is replicated to RGB, alpha is dropped.
    Values are scaled by 255, clipped to 0..255 and truncated to uint8.

    Raises:
        ValueError: image is None, an empty batch, or has an unsupported shape.
    """
    if image is None:
        raise ValueError("Judge node received an empty image input.")
    arr = image
    if hasattr(arr, "detach"):
        # .float(): numpy has no bfloat16, so half-precision tensors would
        # otherwise fail inside .numpy().
        arr = arr.detach().float().cpu().numpy()
    arr = np.asarray(arr, dtype=np.float32)
    if arr.ndim == 4:  # batch -> take first frame
        if arr.shape[0] == 0:
            raise ValueError("Judge node received an empty image batch.")
        arr = arr[0]
    if arr.ndim == 2:
        arr = arr[..., None]
    if arr.ndim != 3 or arr.shape[-1] not in (1, 3, 4):
        raise ValueError(f"Judge node got an unsupported image shape: {tuple(arr.shape)}")
    if arr.shape[-1] == 1:  # grey -> RGB
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.shape[-1] == 4:  # drop alpha
        arr = arr[..., :3]
    # nan_to_num: a NaN pixel would otherwise cast to an undefined uint8 value.
    arr = np.clip(np.nan_to_num(arr) * 255.0, 0, 255).astype(np.uint8)
    # A contiguous (H, W, 3) uint8 array is inferred as RGB; the explicit
    # mode= argument is deprecated in recent Pillow.
    return Image.fromarray(arr)
|
|
|
|
|
|
def _resolve_vl_class(model_path: str):
|
|
"""Pick the right transformers class. AutoModelForImageTextToText reads the
|
|
checkpoint's `architectures` and instantiates the correct dense
|
|
(Qwen3VLForConditionalGeneration) or MoE (Qwen3VLMoeForConditionalGeneration)
|
|
class automatically — so 4B/8B *and* 30B-A3B all work without branching."""
|
|
try:
|
|
from transformers import AutoModelForImageTextToText as _Auto
|
|
return _Auto
|
|
except ImportError: # pragma: no cover - older transformers
|
|
name = model_path.lower()
|
|
is_moe = any(t in name for t in ("a3b", "moe", "30b", "235b"))
|
|
if is_moe:
|
|
from transformers import Qwen3VLMoeForConditionalGeneration as _C
|
|
else:
|
|
from transformers import Qwen3VLForConditionalGeneration as _C
|
|
return _C
|
|
|
|
|
|
def _load_model(model_path: str, precision: str):
    """Return the cached (model, processor) pair for model_path, loading it on a miss.

    precision:
      "nf4"  - 4-bit bitsandbytes NF4 quantization with bf16 compute
      "fp8"  - no dtype override; the pre-quantized checkpoint decides
      "bf16" - torch.bfloat16
      other  - torch.float16
    The loaded pair is stored in _MODEL_CACHE under (model_path, precision).
    """
    cache_key = (model_path, precision)
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Lazy import so the node can still register when transformers is old.
    from transformers import AutoProcessor

    model_cls = _resolve_vl_class(model_path)
    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # checkpoint; together with auto-download of arbitrary repo ids, only point
    # model_path at repos you trust.
    kwargs = {"device_map": "auto", "trust_remote_code": True, "low_cpu_mem_usage": True}
    if precision == "nf4":
        # 4-bit (bitsandbytes) lets the 30B-A3B MoE fit in ~18 GB on 32 GB.
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    elif precision != "fp8":
        kwargs["dtype"] = torch.bfloat16 if precision == "bf16" else torch.float16

    model = model_cls.from_pretrained(model_path, **kwargs)
    model.eval()
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    _ensure_chat_template(processor, model_path)
    _MODEL_CACHE[cache_key] = (model, processor)
    return model, processor
|
|
|
|
|
|
def _ensure_chat_template(processor, model_path: str):
|
|
"""Some ComfyUI-converted checkpoints ship the template as chat_template.jinja
|
|
(or only on the tokenizer), which AutoProcessor doesn't always pick up. Backfill
|
|
processor.chat_template from those sources so apply_chat_template works."""
|
|
if getattr(processor, "chat_template", None):
|
|
return
|
|
for fn in ("chat_template.jinja", "chat_template.json"):
|
|
fp = os.path.join(model_path, fn)
|
|
if os.path.isfile(fp):
|
|
try:
|
|
with open(fp, "r", encoding="utf-8") as f:
|
|
raw = f.read()
|
|
processor.chat_template = json.loads(raw).get("chat_template") if fn.endswith(".json") else raw
|
|
if processor.chat_template:
|
|
return
|
|
except (OSError, ValueError):
|
|
pass
|
|
tok = getattr(processor, "tokenizer", None)
|
|
if tok is not None and getattr(tok, "chat_template", None):
|
|
processor.chat_template = tok.chat_template
|
|
|
|
|
|
def _build_system_prompt(axes: list[str]) -> str:
|
|
axis_lines = "\n".join(
|
|
f' "{a}": {{"score": <0..1>, "ref": "<what IMAGE 1 shows>", "gen": "<what IMAGE 2 shows>"}},'
|
|
for a in axes)
|
|
return (
|
|
"You are a meticulous visual-similarity judge for an image-generation "
|
|
"calibration loop. You are shown two images: IMAGE 1 is the REFERENCE "
|
|
"(the target) and IMAGE 2 is the GENERATED candidate. Judge how closely "
|
|
"the GENERATED image reproduces the REFERENCE.\n\n"
|
|
"For every axis report THREE things:\n"
|
|
" - ref: concretely what IMAGE 1 (reference / target) shows for this axis\n"
|
|
" - gen: concretely what IMAGE 2 (generated) shows for this axis\n"
|
|
" - score: 0..1 closeness, where 0.0 = unrelated, 0.5 = same general "
|
|
"category but clearly different details, 1.0 = near-identical.\n"
|
|
"Use specific concrete values (e.g. ref 'doggy style', gen 'missionary'), "
|
|
"not vague notes. Describe ONLY what you observe — do NOT suggest fixes or "
|
|
"prompt changes; correction is handled by a separate model.\n\n"
|
|
"Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
|
|
"{\n"
|
|
' "overall_score": <0..1>,\n'
|
|
' "axes": {\n'
|
|
f"{axis_lines}\n"
|
|
" }\n"
|
|
"}\n"
|
|
"overall_score must be consistent with the per-axis scores. If an axis is "
|
|
"not applicable to either image, set score 1.0 and ref/gen to \"n/a\"."
|
|
)
|
|
|
|
|
|
def _format_chatml_qwenvl(messages):
|
|
"""Manual Qwen-VL ChatML prompt, used when the processor has no chat template
|
|
(e.g. checkpoints converted for ComfyUI that drop chat_template.json). Mirrors
|
|
apply_chat_template: each image -> <|vision_start|><|image_pad|><|vision_end|>,
|
|
which the processor then expands to the right number of image tokens."""
|
|
parts = []
|
|
for msg in messages:
|
|
parts.append(f"<|im_start|>{msg['role']}\n")
|
|
content = msg["content"]
|
|
if isinstance(content, str):
|
|
parts.append(content)
|
|
else:
|
|
for item in content:
|
|
if item.get("type") == "image":
|
|
parts.append("<|vision_start|><|image_pad|><|vision_end|>")
|
|
elif item.get("type") == "text":
|
|
parts.append(item.get("text", ""))
|
|
parts.append("<|im_end|>\n")
|
|
parts.append("<|im_start|>assistant\n")
|
|
return "".join(parts)
|
|
|
|
|
|
def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
|
|
"""Template + forward pass for a chat-message list; returns the decoded string."""
|
|
try:
|
|
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
except (ValueError, AttributeError):
|
|
# Processor/tokenizer carries no chat template -> build ChatML by hand.
|
|
text = _format_chatml_qwenvl(messages)
|
|
inputs = processor(text=[text], images=images, return_tensors="pt")
|
|
inputs = inputs.to(model.device)
|
|
|
|
gen_kwargs = dict(max_new_tokens=max_new_tokens)
|
|
if temperature and temperature > 0:
|
|
gen_kwargs.update(do_sample=True, temperature=float(temperature))
|
|
else:
|
|
gen_kwargs.update(do_sample=False)
|
|
|
|
with torch.inference_mode():
|
|
out = model.generate(**inputs, **gen_kwargs)
|
|
trimmed = out[:, inputs.input_ids.shape[1]:]
|
|
decoded = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
|
|
return decoded.strip()
|
|
|
|
|
|
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
    """One compare pass: ref_pil is shown as IMAGE 1, gen_pil as IMAGE 2.

    Returns the model's raw reply (expected to be the strict JSON judgement
    requested by the compare system prompt); parsing is left to the caller.
    """
    system_prompt = _build_system_prompt(axes)
    user_content = [
        {"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
        {"type": "image", "image": ref_pil},
        {"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
        {"type": "image", "image": gen_pil},
        {"type": "text", "text": "Now return the strict JSON judgement."},
    ]
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
    return _generate_from_messages(
        model, processor, conversation, [ref_pil, gen_pil], max_new_tokens, temperature,
    )
|
|
|
|
|
|
def _build_describe_prompt(axes: list[str]) -> str:
|
|
axis_lines = "\n".join(f' "{a}": "<concrete value or n/a>",' for a in axes)
|
|
return (
|
|
"You are describing a REFERENCE image that an image generator must try to "
|
|
"reproduce. Describe ONLY what you observe, concretely, in prompt-ready "
|
|
"phrasing (the words a text-to-image prompt would use).\n\n"
|
|
"Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
|
|
"{\n"
|
|
' "caption": "<one detailed paragraph fully describing the image as a generation prompt>",\n'
|
|
' "axes": {\n'
|
|
f"{axis_lines}\n"
|
|
" }\n"
|
|
"}\n"
|
|
"Each axis value is a concrete description of that aspect of the image "
|
|
"(or \"n/a\" if not present). The caption should be directly usable as a prompt."
|
|
)
|
|
|
|
|
|
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
    """Describe pass: show only the reference image and return the raw reply.

    The reply is expected to be the strict JSON {caption, axes} object
    requested by the describe system prompt; parsing is left to the caller.
    """
    system_turn = {"role": "system", "content": _build_describe_prompt(axes)}
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this reference image:"},
            {"type": "image", "image": ref_pil},
            {"type": "text", "text": "Return the strict JSON description."},
        ],
    }
    return _generate_from_messages(
        model, processor, [system_turn, user_turn], [ref_pil], max_new_tokens, temperature,
    )
|
|
|
|
|
|
def _parse_json(raw: str) -> dict | None:
|
|
"""Best-effort: pull the first balanced JSON object out of the model output."""
|
|
# Strip code fences if present.
|
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
|
|
candidate = fenced.group(1) if fenced else None
|
|
if candidate is None:
|
|
start = raw.find("{")
|
|
if start == -1:
|
|
return None
|
|
depth = 0
|
|
for i in range(start, len(raw)):
|
|
if raw[i] == "{":
|
|
depth += 1
|
|
elif raw[i] == "}":
|
|
depth -= 1
|
|
if depth == 0:
|
|
candidate = raw[start:i + 1]
|
|
break
|
|
if candidate is None:
|
|
return None
|
|
try:
|
|
return json.loads(candidate)
|
|
except json.JSONDecodeError:
|
|
return None
|
|
|
|
|
|
def _merge_swapped(a: dict, b: dict) -> dict:
|
|
"""Average two judgements (normal + order-swapped) to cut position bias."""
|
|
if not b:
|
|
return a
|
|
if not a:
|
|
return b
|
|
out = {"axes": {}}
|
|
out["overall_score"] = round(
|
|
(float(a.get("overall_score", 0)) + float(b.get("overall_score", 0))) / 2.0, 4
|
|
)
|
|
axes = set(a.get("axes", {})) | set(b.get("axes", {}))
|
|
for ax in axes:
|
|
sa = a.get("axes", {}).get(ax, {})
|
|
sb = b.get("axes", {}).get(ax, {})
|
|
score = (float(sa.get("score", 0)) + float(sb.get("score", 0))) / 2.0
|
|
# In pass b the images were swapped, so b.ref describes the generated image
|
|
# and b.gen the reference -> invert b when falling back.
|
|
ref = sa.get("ref") or sb.get("gen") or ""
|
|
gen = sa.get("gen") or sb.get("ref") or ""
|
|
out["axes"][ax] = {"score": round(score, 4), "ref": ref, "gen": gen}
|
|
return out
|
|
|
|
|
|
def _report_base_dir(report_dir: str) -> str:
|
|
if report_dir:
|
|
return report_dir
|
|
try:
|
|
import folder_paths
|
|
return os.path.join(folder_paths.get_output_directory(), "calibrator")
|
|
except Exception:
|
|
return os.path.join(os.path.dirname(os.path.dirname(__file__)), "output", "calibrator")
|
|
|
|
|
|
def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used):
    """Persist a compare-mode analysis so the external CLI agent can read it after a queue.

    Writes calib_<tag>.json plus a stable latest.json the agent can always
    poll, and a calib_<tag>.md plain-text sibling (<tag> is run_tag with
    unsafe characters replaced by "_", or "latest" when run_tag is empty).
    Each file is written to a temporary sibling and moved into place with
    os.replace, so a poller never reads a half-written latest.json.

    Returns the per-run JSON path, or "" if the report directory cannot be
    created. Individual write failures are logged, not raised.
    """
    base = _report_base_dir(report_dir)
    try:
        os.makedirs(base, exist_ok=True)
    except OSError as e:
        print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
        return ""

    def _write_atomic(path, text):
        # Temp file + os.replace: readers see either the old or the new file.
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(text)
            os.replace(tmp_path, path)
        except OSError:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    payload = {
        "mode": "compare",  # lets a latest.json poller tell compare from describe
        "run_tag": run_tag,
        "overall_score": round(float(overall), 4),
        "axes": (merged or {}).get("axes", {}),
        "diff_analysis": diff_analysis,
        "prompt_used": prompt_used,
        "raw": raw_all,
    }
    # Serialize once, before touching any file, so a serialization error
    # cannot leave a truncated report behind.
    text = json.dumps(payload, ensure_ascii=False, indent=2)
    tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "latest"
    run_path = os.path.join(base, f"calib_{tag}.json")
    for path in (run_path, os.path.join(base, "latest.json")):
        try:
            _write_atomic(path, text)
        except OSError as e:
            print(f"[QwenVLImageJudge] failed writing report {path}: {e}")

    # A markdown sibling is handy for the agent to read as plain text (best effort).
    md = (f"# Calibration analysis ({tag})\n\n"
          f"**overall_score:** {payload['overall_score']}\n\n"
          f"**prompt_used:**\n\n{prompt_used or '(not provided)'}\n\n"
          f"## per-axis\n\n{diff_analysis}\n")
    try:
        _write_atomic(os.path.join(base, f"calib_{tag}.md"), md)
    except OSError:
        pass
    return run_path
|
|
|
|
|
|
def _write_describe_report(report_dir, run_tag, caption, axes_spec, raw):
    """Persist the first-pass description (target spec) for the agent to seed from.

    Writes calib_<tag>.json (<tag> is run_tag with unsafe characters replaced
    by "_", or "describe" when run_tag is empty) and overwrites latest.json
    with the same payload. Each file is written to a temporary sibling and
    moved into place with os.replace, so a poller never reads a half-written
    latest.json.

    Returns the per-run path, or "" if the report directory cannot be created.
    Individual write failures are logged, not raised.
    """
    base = _report_base_dir(report_dir)
    try:
        os.makedirs(base, exist_ok=True)
    except OSError as e:
        print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
        return ""
    payload = {
        "mode": "describe",
        "run_tag": run_tag,
        "caption": caption,
        "axes": axes_spec,  # per-axis target values -> the agent's initial axis_state
        "raw": raw,
    }
    # Serialize once, before touching any file.
    text = json.dumps(payload, ensure_ascii=False, indent=2)
    tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "describe"
    run_path = os.path.join(base, f"calib_{tag}.json")
    for path in (run_path, os.path.join(base, "latest.json")):
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(text)
            os.replace(tmp_path, path)
        except OSError as e:
            print(f"[QwenVLImageJudge] failed writing report {path}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    return run_path
|
|
|
|
|
|
class QwenVLImageJudge:
    """ComfyUI node: describe a reference, or score how close a generated image is to it.

    Modes:
      * "describe": reference only; returns a prompt-ready caption (analysis)
        and a per-axis target spec (axis_scores_json) to seed the first prompt.
      * "compare": reference vs generated_image; returns an overall 0..1 score,
        per-axis scores with ref/gen observations, and a worst-first summary.
    Both modes write a JSON report and return its path as report_path.
    """

    CATEGORY = "prompt_calibrator"
    FUNCTION = "judge"
    RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING")
    RETURN_NAMES = ("overall_score", "axis_scores_json", "analysis", "raw", "report_path")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "reference_image": ("IMAGE",),
                # describe = reference only -> target description (first pass, seeds the
                # initial prompt). compare = ref vs generated -> per-axis scoring.
                "mode": (["compare", "describe"], {"default": "compare"}),
                "model_path": ("STRING", {"default": DEFAULT_MODEL_PATH}),
                "precision": (["bf16", "fp16", "fp8", "nf4"], {"default": "bf16"}),
                "axes": ("STRING", {"default": DEFAULT_AXES, "multiline": True}),
                "max_new_tokens": ("INT", {"default": 1024, "min": 64, "max": 4096}),
                "temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}),
                # compare only: also run an image-order-swapped pass and average.
                "swap_eval": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "generated_image": ("IMAGE",),  # required for compare, ignored for describe
                "keep_loaded": ("BOOLEAN", {"default": True}),
                "auto_download": ("BOOLEAN", {"default": True}),
                # The agent reads the analysis from these files after each queue.
                "report_dir": ("STRING", {"default": ""}),
                "run_tag": ("STRING", {"default": ""}),
                "prompt_used": ("STRING", {"default": "", "multiline": True}),
            },
        }

    def judge(self, reference_image, mode, model_path, precision, axes,
              max_new_tokens, temperature, swap_eval, generated_image=None,
              keep_loaded=True, auto_download=True,
              report_dir="", run_tag="", prompt_used=""):
        """Node entry point; returns the 5-tuple named by RETURN_NAMES.

        Configuration problems (compare mode without generated_image, model not
        found / not downloadable) come back as a score-0 result carrying the
        message in analysis and raw instead of raising. With keep_loaded off,
        the model is evicted and CUDA memory released after the run, also when
        the run raises.
        """
        axis_list = self._parse_axes(axes)
        describe = mode == "describe"

        # Validate before touching the model so a misconfigured graph does not
        # pay for (and, with keep_loaded off, leak) a multi-GB load it cannot use.
        if not describe and generated_image is None:
            return self._error(
                "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe).")

        try:
            resolved_path = _resolve_model_source(model_path, auto_download)
        except Exception as e:  # missing model / download failure -> surface as score 0
            return self._error(str(e))

        ref_pil = _tensor_to_pil(reference_image)
        gen_pil = None if describe else _tensor_to_pil(generated_image)

        model, processor = _load_model(resolved_path, precision)
        try:
            if describe:
                return self._describe(model, processor, ref_pil, axis_list,
                                      max_new_tokens, temperature, report_dir, run_tag)
            return self._compare(model, processor, ref_pil, gen_pil, axis_list,
                                 max_new_tokens, temperature, swap_eval,
                                 report_dir, run_tag, prompt_used)
        finally:
            if not keep_loaded:
                # This frame's references must go BEFORE empty_cache(): while
                # any live reference to `model` remains, its CUDA memory cannot
                # be handed back.
                del model, processor
                self._release_model(resolved_path, precision)

    def _compare(self, model, processor, ref_pil, gen_pil, axis_list, max_new_tokens,
                 temperature, swap_eval, report_dir, run_tag, prompt_used):
        """Compare pass(es): score gen_pil against ref_pil and write the report."""
        raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
        merged = _parse_json(raw1) or {}
        raw_all = raw1
        if swap_eval:
            # Swap which image is called REFERENCE to average out position bias.
            raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature)
            merged = _merge_swapped(merged, _parse_json(raw2) or {})
            raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2

        # The model's JSON is untrusted: scores may be "n/a"/null and axis
        # entries may not be objects, so coerce instead of crashing the node.
        overall = self._to_score(merged.get("overall_score")) if merged else 0.0
        axes_dict = merged.get("axes") if isinstance(merged.get("axes"), dict) else {}
        axis_scores = json.dumps(axes_dict, ensure_ascii=False, indent=2) if merged else "{}"

        # Human/controller-readable diff summary, worst axes first (biggest gap).
        rows = []
        for ax, info in axes_dict.items():
            info = info if isinstance(info, dict) else {}
            rows.append((self._to_score(info.get("score")), ax, info))
        rows.sort(key=lambda row: row[0])  # stable: ties keep the model's order
        diff_lines = [
            f"- {ax}: {score:.2f} ref:[{info.get('ref', '')}] gen:[{info.get('gen', '')}]"
            for score, ax, info in rows
        ]
        diff_analysis = "\n".join(diff_lines) if diff_lines else "(no parseable judgement)"

        report_path = _write_report(
            report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used)
        return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path)

    def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens,
                  temperature, report_dir, run_tag):
        """First pass: describe the reference image the generator must reproduce.

        Outputs the target spec (per-axis values) + a prompt-ready caption.
        Model release is handled by judge().
        """
        raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature)
        parsed = _parse_json(raw) or {}

        caption = parsed.get("caption")
        caption = caption.strip() if isinstance(caption, str) else ""
        axes_spec = parsed.get("axes") if isinstance(parsed.get("axes"), dict) else {}
        axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2)
        analysis = caption or "(no parseable description)"

        report_path = _write_describe_report(report_dir, run_tag, caption, axes_spec, raw)
        # overall_score is n/a in describe mode; return 1.0 as a neutral placeholder.
        return (1.0, axis_scores, analysis, raw, report_path)

    @staticmethod
    def _parse_axes(axes):
        """Split the comma/newline separated axes widget into names.

        Blank entries are dropped, duplicates removed (first occurrence wins),
        and DEFAULT_AXES is used when nothing remains.
        """
        names = [a.strip() for a in re.split(r"[,\n]", axes or "") if a.strip()]
        if not names:
            names = [a.strip() for a in DEFAULT_AXES.split(",")]
        return list(dict.fromkeys(names))

    @staticmethod
    def _to_score(value, default=0.0):
        """float(value), or default when the model emitted a non-numeric score."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _error(msg):
        """Log msg and return it as a score-0 node result (5-tuple)."""
        print(msg)
        return (0.0, "{}", msg, msg, "")

    @staticmethod
    def _release_model(resolved_path, precision):
        """Evict the cached model and return freed CUDA memory to the driver.

        Call only after the caller has dropped its own model references.
        gc.collect() runs first so any reference cycles holding the weights
        are broken before empty_cache().
        """
        import gc

        _MODEL_CACHE.pop((resolved_path, precision), None)
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
|
# Registration tables read by ComfyUI when it imports this module:
# node id -> implementing class, and node id -> label shown in the UI.
NODE_CLASS_MAPPINGS = dict(QwenVLImageJudge=QwenVLImageJudge)
NODE_DISPLAY_NAME_MAPPINGS = dict(QwenVLImageJudge="Qwen3-VL Image Judge (Calibrator)")
|