""" Qwen3-VL Image-Similarity Judge node for ComfyUI. The "vllm node" of the Prompt Calibrator. It takes a REFERENCE image and a GENERATED image and asks a local Qwen3-VL model how close the generated image is to the reference, returning a machine-readable score + per-axis difference analysis that the calibration controller can act on. Reuses the standard transformers Qwen3-VL plumbing (the same approach used by ComfyUI-QwenVL-MultiImage / ComfyUI_Qwen3-VL-Instruct), but forces strict JSON output so the result is usable by an automated loop rather than a human reader. Default model is the locally converted huihui-ai Qwen3-VL-4B-Instruct *abliterated* (uncensored) weights, which do not refuse to analyze adult imagery. """ from __future__ import annotations import json import os import re import numpy as np import torch from PIL import Image # Default to the model already converted on this machine (works out of the box). DEFAULT_MODEL_PATH = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16" DEFAULT_MODEL_PATH_FP8 = "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_fp8" # Recommended abliterated upgrades for the RTX 5090 32 GB (latest Qwen VL family). # Download with: hf download --local-dir , then point model_path at it. RECOMMENDED_MODELS = { # Best judge that fits 32 GB. MoE (3B active -> fast). Use precision="nf4" # (~18 GB) on 32 GB, or the GGUF quants via a GGUF node. transformers class: # Qwen3VLMoeForConditionalGeneration (auto-detected below). "30b-a3b": "huihui-ai/Huihui-Qwen3-VL-30B-A3B-Instruct-abliterated", # Easy middle ground: bf16 ~17 GB, no quantization hassle, drop-in here. "8b": "huihui-ai/Huihui-Qwen3-VL-8B-Instruct-abliterated", # Lightweight, already local. "4b": "huihui-ai/Huihui-Qwen3-VL-4B-Instruct-abliterated", } # Difference axes the judge scores. Granular by default so the comparison is # discriminative for explicit/adult imagery (where coarse axes blur the differences # that matter). Fully configurable on the node — trim or extend per use case. # subject_count number of people # gender_mix gender composition (e.g. 1F, 2F1M) # body_type physique / build / proportions per subject # distinctive_features tattoos / piercings / marks (identity anchors) # age_appearance apparent age # ethnicity_skin ethnicity / skin tone # hair length, color, style # clothing_state degree of undress + specific garments # sexual_act the act / activity being performed # position sexual position / arrangement of bodies # penetration type & visibility of penetration # explicitness how graphic / genital visibility level # body_contact who contacts whom; interaction between subjects # pose non-act body positioning # facial_expression face / affect # gaze eye contact / look direction # framing shot type / crop (close-up <-> full body) # camera_angle POV / angle / perspective # scene location / setting / background # lighting_color palette, lighting, color grade # art_style photoreal vs anime/illustrated, render style DEFAULT_AXES = ( "subject_count, gender_mix, body_type, distinctive_features, age_appearance, " "ethnicity_skin, hair, clothing_state, sexual_act, position, penetration, " "explicitness, body_contact, pose, facial_expression, gaze, framing, " "camera_angle, scene, lighting_color, art_style" ) # Cache loaded (model, processor) keyed by (path, precision) so the loop does not # reload weights every iteration. _MODEL_CACHE: dict[tuple[str, str], tuple] = {} def _looks_like_repo_id(s: str) -> bool: """'org/name' HF repo id, not an absolute/local filesystem path.""" return ("/" in s) and (" " not in s) and (not os.path.isabs(s)) and (not s.startswith(".")) def _download_target_dir(repo_id: str) -> str: """Where to put downloaded weights — prefer ComfyUI's models/prompt_generator/.""" name = repo_id.split("/")[-1] try: import folder_paths # available when running inside ComfyUI base = os.path.join(folder_paths.models_dir, "prompt_generator") except Exception: base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models") return os.path.join(base, name) def _resolve_model_source(model_path: str, auto_download: bool) -> str: """Turn model_path (local dir | short alias | HF repo id) into a local dir. Downloads from the Hub on first use if needed (and auto_download is on). """ # Short alias -> full repo id (e.g. "30b-a3b", "8b", "4b"). if model_path in RECOMMENDED_MODELS: model_path = RECOMMENDED_MODELS[model_path] if os.path.isdir(model_path): return model_path if _looks_like_repo_id(model_path): target = _download_target_dir(model_path) # Already downloaded? (a config.json is enough to trust the local copy) if os.path.isfile(os.path.join(target, "config.json")): return target if not auto_download: raise FileNotFoundError( f"[QwenVLImageJudge] '{model_path}' is not downloaded and auto_download is off. " f"Enable auto_download or pre-fetch it to {target}.") from huggingface_hub import snapshot_download print(f"[QwenVLImageJudge] downloading {model_path} -> {target} (first run only, may be large)...") local = snapshot_download( repo_id=model_path, local_dir=target, # weights + processor/tokenizer/config/template; skip duplicate GGUF/onnx blobs. allow_patterns=["*.json", "*.jinja", "*.safetensors", "*.txt", "*.model", "merges.txt", "*.py"], ) print(f"[QwenVLImageJudge] download complete: {local}") return local # A local path that simply doesn't exist. raise FileNotFoundError( f"[QwenVLImageJudge] model_path not found: {model_path}. " f"Use a local checkpoint dir, a HF repo id (org/name), or an alias " f"({', '.join(RECOMMENDED_MODELS)}).") def _tensor_to_pil(image: "torch.Tensor") -> Image.Image: """ComfyUI IMAGE tensor (B,H,W,C float 0..1) -> first-frame PIL.Image (RGB).""" if image is None: raise ValueError("Judge node received an empty image input.") arr = image if hasattr(arr, "detach"): arr = arr.detach().cpu().numpy() arr = np.asarray(arr) if arr.ndim == 4: # batch -> take first frame arr = arr[0] arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8) if arr.ndim == 2: arr = np.stack([arr] * 3, axis=-1) if arr.shape[-1] == 4: # drop alpha arr = arr[..., :3] return Image.fromarray(arr, mode="RGB") def _resolve_vl_class(model_path: str): """Pick the right transformers class. AutoModelForImageTextToText reads the checkpoint's `architectures` and instantiates the correct dense (Qwen3VLForConditionalGeneration) or MoE (Qwen3VLMoeForConditionalGeneration) class automatically — so 4B/8B *and* 30B-A3B all work without branching.""" try: from transformers import AutoModelForImageTextToText as _Auto return _Auto except ImportError: # pragma: no cover - older transformers name = model_path.lower() is_moe = any(t in name for t in ("a3b", "moe", "30b", "235b")) if is_moe: from transformers import Qwen3VLMoeForConditionalGeneration as _C else: from transformers import Qwen3VLForConditionalGeneration as _C return _C def _load_model(model_path: str, precision: str): key = (model_path, precision) if key in _MODEL_CACHE: return _MODEL_CACHE[key] # Imported lazily so the node can be registered even if transformers is old. from transformers import AutoProcessor _VLModel = _resolve_vl_class(model_path) load_kwargs = dict(device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True) if precision == "nf4": # 4-bit (bitsandbytes) — lets the 30B-A3B abliterated MoE fit in ~18 GB on 32 GB. from transformers import BitsAndBytesConfig load_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) elif precision == "fp8": # Pre-quantized FP8 weights: let the checkpoint dictate dtype. pass else: load_kwargs["dtype"] = torch.bfloat16 if precision == "bf16" else torch.float16 model = _VLModel.from_pretrained(model_path, **load_kwargs) model.eval() processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) _ensure_chat_template(processor, model_path) _MODEL_CACHE[key] = (model, processor) return model, processor def _ensure_chat_template(processor, model_path: str): """Some ComfyUI-converted checkpoints ship the template as chat_template.jinja (or only on the tokenizer), which AutoProcessor doesn't always pick up. Backfill processor.chat_template from those sources so apply_chat_template works.""" if getattr(processor, "chat_template", None): return for fn in ("chat_template.jinja", "chat_template.json"): fp = os.path.join(model_path, fn) if os.path.isfile(fp): try: with open(fp, "r", encoding="utf-8") as f: raw = f.read() processor.chat_template = json.loads(raw).get("chat_template") if fn.endswith(".json") else raw if processor.chat_template: return except (OSError, ValueError): pass tok = getattr(processor, "tokenizer", None) if tok is not None and getattr(tok, "chat_template", None): processor.chat_template = tok.chat_template def _build_system_prompt(axes: list[str]) -> str: axis_lines = "\n".join( f' "{a}": {{"score": <0..1>, "ref": "", "gen": ""}},' for a in axes) return ( "You are a meticulous visual-similarity judge for an image-generation " "calibration loop. You are shown two images: IMAGE 1 is the REFERENCE " "(the target) and IMAGE 2 is the GENERATED candidate. Judge how closely " "the GENERATED image reproduces the REFERENCE.\n\n" "For every axis report THREE things:\n" " - ref: concretely what IMAGE 1 (reference / target) shows for this axis\n" " - gen: concretely what IMAGE 2 (generated) shows for this axis\n" " - score: 0..1 closeness, where 0.0 = unrelated, 0.5 = same general " "category but clearly different details, 1.0 = near-identical.\n" "Use specific concrete values (e.g. ref 'doggy style', gen 'missionary'), " "not vague notes. Describe ONLY what you observe — do NOT suggest fixes or " "prompt changes; correction is handled by a separate model.\n\n" "Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n" "{\n" ' "overall_score": <0..1>,\n' ' "axes": {\n' f"{axis_lines}\n" " }\n" "}\n" "overall_score must be consistent with the per-axis scores. If an axis is " "not applicable to either image, set score 1.0 and ref/gen to \"n/a\"." ) def _format_chatml_qwenvl(messages): """Manual Qwen-VL ChatML prompt, used when the processor has no chat template (e.g. checkpoints converted for ComfyUI that drop chat_template.json). Mirrors apply_chat_template: each image -> <|vision_start|><|image_pad|><|vision_end|>, which the processor then expands to the right number of image tokens.""" parts = [] for msg in messages: parts.append(f"<|im_start|>{msg['role']}\n") content = msg["content"] if isinstance(content, str): parts.append(content) else: for item in content: if item.get("type") == "image": parts.append("<|vision_start|><|image_pad|><|vision_end|>") elif item.get("type") == "text": parts.append(item.get("text", "")) parts.append("<|im_end|>\n") parts.append("<|im_start|>assistant\n") return "".join(parts) def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature): """Template + forward pass for a chat-message list; returns the decoded string.""" try: text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) except (ValueError, AttributeError): # Processor/tokenizer carries no chat template -> build ChatML by hand. text = _format_chatml_qwenvl(messages) inputs = processor(text=[text], images=images, return_tensors="pt") inputs = inputs.to(model.device) gen_kwargs = dict(max_new_tokens=max_new_tokens) if temperature and temperature > 0: gen_kwargs.update(do_sample=True, temperature=float(temperature)) else: gen_kwargs.update(do_sample=False) with torch.inference_mode(): out = model.generate(**inputs, **gen_kwargs) trimmed = out[:, inputs.input_ids.shape[1]:] decoded = processor.batch_decode(trimmed, skip_special_tokens=True)[0] return decoded.strip() def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature): """Compare pass: ref vs gen -> raw JSON judgement string.""" messages = [ {"role": "system", "content": _build_system_prompt(axes)}, { "role": "user", "content": [ {"type": "text", "text": "IMAGE 1 = REFERENCE (target):"}, {"type": "image", "image": ref_pil}, {"type": "text", "text": "IMAGE 2 = GENERATED candidate:"}, {"type": "image", "image": gen_pil}, {"type": "text", "text": "Now return the strict JSON judgement."}, ], }, ] return _generate_from_messages(model, processor, messages, [ref_pil, gen_pil], max_new_tokens, temperature) def _build_describe_prompt(axes: list[str]) -> str: axis_lines = "\n".join(f' "{a}": "",' for a in axes) return ( "You are describing a REFERENCE image that an image generator must try to " "reproduce. Describe ONLY what you observe, concretely, in prompt-ready " "phrasing (the words a text-to-image prompt would use).\n\n" "Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n" "{\n" ' "caption": "",\n' ' "axes": {\n' f"{axis_lines}\n" " }\n" "}\n" "Each axis value is a concrete description of that aspect of the image " "(or \"n/a\" if not present). The caption should be directly usable as a prompt." ) def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature): """Describe pass: reference only -> raw JSON {caption, axes} string.""" messages = [ {"role": "system", "content": _build_describe_prompt(axes)}, { "role": "user", "content": [ {"type": "text", "text": "Describe this reference image:"}, {"type": "image", "image": ref_pil}, {"type": "text", "text": "Return the strict JSON description."}, ], }, ] return _generate_from_messages(model, processor, messages, [ref_pil], max_new_tokens, temperature) def _parse_json(raw: str) -> dict | None: """Best-effort: pull the first balanced JSON object out of the model output.""" # Strip code fences if present. fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL) candidate = fenced.group(1) if fenced else None if candidate is None: start = raw.find("{") if start == -1: return None depth = 0 for i in range(start, len(raw)): if raw[i] == "{": depth += 1 elif raw[i] == "}": depth -= 1 if depth == 0: candidate = raw[start:i + 1] break if candidate is None: return None try: return json.loads(candidate) except json.JSONDecodeError: return None def _merge_swapped(a: dict, b: dict) -> dict: """Average two judgements (normal + order-swapped) to cut position bias.""" if not b: return a if not a: return b out = {"axes": {}} out["overall_score"] = round( (float(a.get("overall_score", 0)) + float(b.get("overall_score", 0))) / 2.0, 4 ) axes = set(a.get("axes", {})) | set(b.get("axes", {})) for ax in axes: sa = a.get("axes", {}).get(ax, {}) sb = b.get("axes", {}).get(ax, {}) score = (float(sa.get("score", 0)) + float(sb.get("score", 0))) / 2.0 # In pass b the images were swapped, so b.ref describes the generated image # and b.gen the reference -> invert b when falling back. ref = sa.get("ref") or sb.get("gen") or "" gen = sa.get("gen") or sb.get("ref") or "" out["axes"][ax] = {"score": round(score, 4), "ref": ref, "gen": gen} return out def _report_base_dir(report_dir: str) -> str: if report_dir: return report_dir try: import folder_paths return os.path.join(folder_paths.get_output_directory(), "calibrator") except Exception: return os.path.join(os.path.dirname(os.path.dirname(__file__)), "output", "calibrator") def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used): """Persist the analysis so the external CLI agent can read it after a queue. Writes a per-run file plus a stable `latest.json` the agent can always poll. Returns the per-run file path (or "" on failure).""" base = _report_base_dir(report_dir) try: os.makedirs(base, exist_ok=True) except OSError as e: print(f"[QwenVLImageJudge] could not create report dir {base}: {e}") return "" payload = { "run_tag": run_tag, "overall_score": round(float(overall), 4), "axes": (merged or {}).get("axes", {}), "diff_analysis": diff_analysis, "prompt_used": prompt_used, "raw": raw_all, } tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "latest" run_path = os.path.join(base, f"calib_{tag}.json") for path in (run_path, os.path.join(base, "latest.json")): try: with open(path, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) except OSError as e: print(f"[QwenVLImageJudge] failed writing report {path}: {e}") # A markdown sibling is handy for the agent to read as plain text. try: md = (f"# Calibration analysis ({tag})\n\n" f"**overall_score:** {payload['overall_score']}\n\n" f"**prompt_used:**\n\n{prompt_used or '(not provided)'}\n\n" f"## per-axis\n\n{diff_analysis}\n") with open(os.path.join(base, f"calib_{tag}.md"), "w", encoding="utf-8") as f: f.write(md) except OSError: pass return run_path def _write_describe_report(report_dir, run_tag, caption, axes_spec, raw): """Persist the first-pass description (target spec) for the agent to seed from.""" base = _report_base_dir(report_dir) try: os.makedirs(base, exist_ok=True) except OSError as e: print(f"[QwenVLImageJudge] could not create report dir {base}: {e}") return "" payload = { "mode": "describe", "run_tag": run_tag, "caption": caption, "axes": axes_spec, # per-axis target values -> the agent's initial axis_state "raw": raw, } tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "describe" run_path = os.path.join(base, f"calib_{tag}.json") for path in (run_path, os.path.join(base, "latest.json")): try: with open(path, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) except OSError as e: print(f"[QwenVLImageJudge] failed writing report {path}: {e}") return run_path class QwenVLImageJudge: """ComfyUI node: describe a reference, or score how close a generated image is to it.""" CATEGORY = "prompt_calibrator" FUNCTION = "judge" RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING") RETURN_NAMES = ("overall_score", "axis_scores_json", "analysis", "raw", "report_path") @classmethod def INPUT_TYPES(cls): return { "required": { "reference_image": ("IMAGE",), # describe = reference only -> target description (first pass, seeds the # initial prompt). compare = ref vs generated -> per-axis scoring. "mode": (["compare", "describe"], {"default": "compare"}), "model_path": ("STRING", {"default": DEFAULT_MODEL_PATH}), "precision": (["bf16", "fp16", "fp8", "nf4"], {"default": "bf16"}), "axes": ("STRING", {"default": DEFAULT_AXES, "multiline": True}), "max_new_tokens": ("INT", {"default": 1024, "min": 64, "max": 4096}), "temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}), "swap_eval": ("BOOLEAN", {"default": True}), }, "optional": { "generated_image": ("IMAGE",), # required for compare, ignored for describe "keep_loaded": ("BOOLEAN", {"default": True}), "auto_download": ("BOOLEAN", {"default": True}), # The agent reads the analysis from these files after each queue. "report_dir": ("STRING", {"default": ""}), "run_tag": ("STRING", {"default": ""}), "prompt_used": ("STRING", {"default": "", "multiline": True}), }, } def judge(self, reference_image, mode, model_path, precision, axes, max_new_tokens, temperature, swap_eval, generated_image=None, keep_loaded=True, auto_download=True, report_dir="", run_tag="", prompt_used=""): axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()] if not axis_list: axis_list = [a.strip() for a in DEFAULT_AXES.split(",")] try: resolved_path = _resolve_model_source(model_path, auto_download) except Exception as e: # missing model / download failure -> surface as score 0 msg = str(e) print(msg) return (0.0, "{}", msg, msg, "") ref_pil = _tensor_to_pil(reference_image) model, processor = _load_model(resolved_path, precision) if mode == "describe": return self._describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature, resolved_path, precision, keep_loaded, report_dir, run_tag) if generated_image is None: msg = "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe)." print(msg) return (0.0, "{}", msg, msg, "") gen_pil = _tensor_to_pil(generated_image) raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature) parsed1 = _parse_json(raw1) or {} raw_all = raw1 merged = parsed1 if swap_eval: # Swap which image is called REFERENCE to average out position bias. raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature) parsed2 = _parse_json(raw2) or {} merged = _merge_swapped(parsed1, parsed2) raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2 if not keep_loaded: _MODEL_CACHE.pop((resolved_path, precision), None) del model torch.cuda.empty_cache() overall = float(merged.get("overall_score", 0.0)) if merged else 0.0 axis_scores = json.dumps(merged.get("axes", {}), ensure_ascii=False, indent=2) if merged else "{}" # Human/controller-readable diff summary, worst axes first (biggest gap). items = sorted((merged.get("axes", {}) if merged else {}).items(), key=lambda kv: float(kv[1].get("score", 0))) diff_lines = [ f"- {ax}: {info.get('score', 0):.2f} ref:[{info.get('ref', '')}] gen:[{info.get('gen', '')}]" for ax, info in items ] diff_analysis = "\n".join(diff_lines) if diff_lines else "(no parseable judgement)" report_path = _write_report( report_dir, run_tag, overall, merged, diff_analysis, raw_all, prompt_used) return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path) def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens, temperature, resolved_path, precision, keep_loaded, report_dir, run_tag): """First pass: describe the reference image the generator must reproduce. Outputs the target spec (per-axis values) + a prompt-ready caption.""" raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature) parsed = _parse_json(raw) or {} if not keep_loaded: _MODEL_CACHE.pop((resolved_path, precision), None) del model torch.cuda.empty_cache() caption = (parsed.get("caption") or "").strip() axes_spec = parsed.get("axes", {}) if isinstance(parsed.get("axes"), dict) else {} axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2) analysis = caption if caption else "(no parseable description)" report_path = _write_describe_report(report_dir, run_tag, caption, axes_spec, raw) # overall_score is n/a in describe mode; return 1.0 as a neutral placeholder. return (1.0, axis_scores, analysis, raw, report_path) NODE_CLASS_MAPPINGS = {"QwenVLImageJudge": QwenVLImageJudge} NODE_DISPLAY_NAME_MAPPINGS = {"QwenVLImageJudge": "Qwen3-VL Image Judge (Calibrator)"}