diff --git a/README.md b/README.md index 6ba91bc..893abca 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,8 @@ can act on it. | `precision` | bf16 / fp8 / nf4 | bf16 | **the quant** — applies to the selected model (VRAM table below) | | `model_path` | STRING | "" (empty) | **manual override** of the dropdown — local dir, HF repo id, or alias (`8b`/`30b-a3b`/`3.5-9b`/`3.6-27b`/`3.6-35b`). Empty = use `model_select` | | `axes` | STRING **input** | — | (socket) optional override of the profile's axis set; wire a text node or leave unconnected to use `profile` | -| `max_new_tokens` | INT | 2048 | raise it if a reasoning model (Qwen3.5/3.6) gets cut off before finishing | +| `max_new_tokens` | INT | 3072 | reasoning models (Qwen3.5/3.6) need room; raise it if the verdict gets cut off | +| `enable_thinking` | BOOL | true | let the model reason before judging. **Keep on for accurate verdicts** — off makes reasoning models rubber-stamp `match`. Off is faster | | `temperature` | FLOAT | 0.0 | 0 = greedy/repeatable | | `swap_eval` | BOOL | true | run twice with images swapped, average → cuts position bias | | `keep_loaded` | BOOL | true | cache weights across loop iterations | diff --git a/nodes/qwen_judge.py b/nodes/qwen_judge.py index 27c5ad9..b387ff2 100644 --- a/nodes/qwen_judge.py +++ b/nodes/qwen_judge.py @@ -336,26 +336,35 @@ def _axis_definition_block(axes: list[str]) -> str: return "\n".join(f" - {a}: {AXIS_DEFS.get(a, 'as named')}" for a in axes) -def _build_system_prompt(axes: list[str], reference_description: str = "") -> str: +def _build_system_prompt(axes: list[str], reference_description: str = "", think: bool = True) -> str: axis_lines = "\n".join( f' "{a}": {{"verdict": "match|partial|mismatch", "ref": "", "gen": ""}},' for a in axes) verdict_rule = ( - " - verdict: 'match' if ref and gen are the same; 'mismatch' if they are " - "opposite or clearly different (e.g. 'on top' vs 'on bottom', 'doggy' vs " - "'cowgirl', 'short' vs 'long', 'eyes closed' vs 'at camera'); 'partial' ONLY " - "for a genuine middle ground (same category, minor difference). Do NOT default " - "to 'partial' — if the values are identical use 'match', if clearly different " - "use 'mismatch'.\n") - tail = ( - "Output ONLY the JSON object — no reasoning, no step-by-step analysis, no " - "markdown, no commentary. Do NOT think out loud. Your entire reply must start " - "with '{' and end with '}', exactly:\n" - "{\n" - ' "axes": {\n' - f"{axis_lines}\n" - " }\n" - "}\n") + " - verdict: COMPARE ref vs gen carefully. 'match' only if they are the same; " + "'mismatch' if opposite or clearly different (e.g. 'on top' vs 'on bottom', " + "'short' vs 'long', 'brown' vs 'blonde', 'eyes closed' vs 'eyes open'); 'partial' " + "for same category with a clear difference. Do NOT lazily mark everything 'match' " + "— if the words differ, it is NOT a match.\n") + if think: + tail = ( + "Examine each axis and decide its verdict by actually comparing ref and gen. " + "You may reason first. END your reply with the result for every axis as a JSON " + "object (or a per-axis list with ref/gen/verdict), schema:\n" + "{\n" + ' "axes": {\n' + f"{axis_lines}\n" + " }\n" + "}\n") + else: + tail = ( + "Output ONLY the JSON object — no prose, no markdown. Start with '{' end with " + "'}', exactly:\n" + "{\n" + ' "axes": {\n' + f"{axis_lines}\n" + " }\n" + "}\n") if reference_description.strip(): # Anchored mode: the reference is a fixed canonical description (text), only the @@ -417,12 +426,14 @@ def _format_chatml_qwenvl(messages): return "".join(parts) -def _apply_template(processor, messages): - """apply_chat_template with thinking disabled (Qwen3.5/3.6 are reasoning models that - otherwise 'think out loud' in prose and never reach the JSON). Falls back gracefully.""" +def _apply_template(processor, messages, think=True): + """apply_chat_template, optionally toggling reasoning. Reasoning models (Qwen3.5/3.6) + judge verdicts far better WITH thinking on (off -> they rubber-stamp 'match'); the + markdown fallback parser reads the reasoned per-axis output. Set think=False for a + faster, JSON-only pass. Falls back to a hand-built ChatML prompt if no template.""" try: return processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, enable_thinking=False) + messages, tokenize=False, add_generation_prompt=True, enable_thinking=think) except TypeError: pass # template doesn't accept enable_thinking except (ValueError, AttributeError): @@ -433,9 +444,9 @@ def _apply_template(processor, messages): return _format_chatml_qwenvl(messages) -def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature): +def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature, think=True): """Template + forward pass for a chat-message list; returns the decoded string.""" - text = _apply_template(processor, messages) + text = _apply_template(processor, messages, think) inputs = processor(text=[text], images=images, return_tensors="pt") inputs = inputs.to(model.device) @@ -454,10 +465,10 @@ def _generate_from_messages(model, processor, messages, images, max_new_tokens, return decoded.strip() -def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature): - """Compare pass: ref vs gen -> raw JSON judgement string.""" +def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature, think=True): + """Compare pass: ref vs gen -> raw judgement string (JSON or reasoned prose).""" messages = [ - {"role": "system", "content": _build_system_prompt(axes)}, + {"role": "system", "content": _build_system_prompt(axes, think=think)}, { "role": "user", "content": [ @@ -465,29 +476,30 @@ def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperat {"type": "image", "image": ref_pil}, {"type": "text", "text": "IMAGE 2 = GENERATED candidate:"}, {"type": "image", "image": gen_pil}, - {"type": "text", "text": "Now return the strict JSON judgement."}, + {"type": "text", "text": "Now judge every axis."}, ], }, ] return _generate_from_messages(model, processor, messages, [ref_pil, gen_pil], - max_new_tokens, temperature) + max_new_tokens, temperature, think) -def _run_anchored(model, processor, gen_pil, axes, max_new_tokens, temperature, reference_description): +def _run_anchored(model, processor, gen_pil, axes, max_new_tokens, temperature, + reference_description, think=True): """Anchored compare: fixed canonical reference text + one generated image.""" messages = [ - {"role": "system", "content": _build_system_prompt(axes, reference_description)}, + {"role": "system", "content": _build_system_prompt(axes, reference_description, think=think)}, { "role": "user", "content": [ {"type": "text", "text": "GENERATED candidate image:"}, {"type": "image", "image": gen_pil}, - {"type": "text", "text": "Compare it to the reference description and return the strict JSON."}, + {"type": "text", "text": "Compare it to the reference description and judge every axis."}, ], }, ] return _generate_from_messages(model, processor, messages, [gen_pil], - max_new_tokens, temperature) + max_new_tokens, temperature, think) def _build_describe_prompt(axes: list[str]) -> str: @@ -515,7 +527,7 @@ def _build_describe_prompt(axes: list[str]) -> str: ) -def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature): +def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature, think=True): """General VLM pass: your own system/user prompt over the image(s) -> raw text.""" content = [{"type": "image", "image": img} for img in images] content.append({"type": "text", "text": user_prompt or "Describe this image."}) @@ -523,11 +535,11 @@ def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_toke if system_prompt.strip(): messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": content}) - return _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature) + return _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature, think) -def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature): - """Describe pass: reference only -> raw JSON {caption, axes} string.""" +def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature, think=True): + """Describe pass: reference only -> raw {description, axes} (JSON or reasoned prose).""" messages = [ {"role": "system", "content": _build_describe_prompt(axes)}, { @@ -535,38 +547,41 @@ def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature): "content": [ {"type": "text", "text": "Describe this reference image:"}, {"type": "image", "image": ref_pil}, - {"type": "text", "text": "Return the strict JSON description."}, + {"type": "text", "text": "Give the full description."}, ], }, ] return _generate_from_messages(model, processor, messages, [ref_pil], - max_new_tokens, temperature) + max_new_tokens, temperature, think) def _parse_json(raw: str) -> dict | None: - """Best-effort: pull the first balanced JSON object out of the model output.""" - # Strip code fences if present. - fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL) - candidate = fenced.group(1) if fenced else None - if candidate is None: - start = raw.find("{") - if start == -1: - return None - depth = 0 - for i in range(start, len(raw)): - if raw[i] == "{": - depth += 1 - elif raw[i] == "}": - depth -= 1 - if depth == 0: - candidate = raw[start:i + 1] - break - if candidate is None: - return None - try: - return json.loads(candidate) - except json.JSONDecodeError: - return None + """Pull a JSON object out of the output. Reasoning models put the JSON AFTER prose, + so collect all balanced top-level objects and return the last one that parses and + contains 'axes' (or 'description') — falling back to the last that parses at all.""" + candidates = [] + depth = start = 0 + for i, ch in enumerate(raw): + if ch == "{": + if depth == 0: + start = i + depth += 1 + elif ch == "}" and depth > 0: + depth -= 1 + if depth == 0: + candidates.append(raw[start:i + 1]) + best = None + for cand in candidates: + try: + obj = json.loads(cand) + except json.JSONDecodeError: + continue + if isinstance(obj, dict): + best = obj + if "axes" in obj or "description" in obj: + # keep scanning; prefer the LAST such object (final answer) + best = obj + return best def _parse_markdown_verdicts(raw: str, axes: list[str]) -> dict: @@ -795,9 +810,12 @@ class QwenVLImageJudge: {"default": list(MODEL_PRESETS.keys())[0]}), "model_path": ("STRING", {"default": ""}), # manual override (local dir / HF repo / alias) "precision": (["bf16", "fp8", "nf4"], {"default": "bf16"}), - "max_new_tokens": ("INT", {"default": 2048, "min": 64, "max": 8192}), + "max_new_tokens": ("INT", {"default": 3072, "min": 64, "max": 8192}), "temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}), "swap_eval": ("BOOLEAN", {"default": True}), + # Reasoning models (Qwen3.5/3.6) judge verdicts FAR better with thinking on + # (off -> they rubber-stamp 'match'). Costs more tokens; raise max_new_tokens. + "enable_thinking": ("BOOLEAN", {"default": True}), "keep_loaded": ("BOOLEAN", {"default": True}), "auto_download": ("BOOLEAN", {"default": True}), # Small config values stay as typeable fields. @@ -820,7 +838,7 @@ class QwenVLImageJudge: def judge(self, reference_image, mode, model_path, precision, max_new_tokens, temperature, swap_eval, profile="general", - model_select=MANUAL_CHOICE, generated_image=None, + enable_thinking=True, model_select=MANUAL_CHOICE, generated_image=None, keep_loaded=True, auto_download=True, report_dir="", run_tag="", axes="", reference_description="", system_prompt="", user_prompt="Describe this image."): @@ -862,12 +880,12 @@ class QwenVLImageJudge: gen_pil = _tensor_to_pil(generated_image) if generated_image is not None else None return self._chat(model, processor, ref_pil, gen_pil, system_prompt, user_prompt, max_new_tokens, temperature, resolved_path, eff_precision, - keep_loaded, report_dir, run_tag) + keep_loaded, report_dir, run_tag, enable_thinking) if mode == "describe": return self._describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature, resolved_path, eff_precision, keep_loaded, - report_dir, run_tag) + report_dir, run_tag, enable_thinking) if generated_image is None: msg = "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe)." @@ -879,16 +897,18 @@ class QwenVLImageJudge: # Anchored: fixed canonical reference text + one generated image. No swap # (single image), and the reference side stays identical across iterations. raw_all = _run_anchored(model, processor, gen_pil, axis_list, max_new_tokens, - temperature, reference_description) + temperature, reference_description, enable_thinking) merged = _parse_axes(raw_all, axis_list) else: - raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature) + raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, + temperature, enable_thinking) parsed1 = _parse_axes(raw1, axis_list) raw_all = raw1 merged = parsed1 if swap_eval: # Swap which image is called REFERENCE to average out position bias. - raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature) + raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, + temperature, enable_thinking) parsed2 = _parse_axes(raw2, axis_list) merged = _merge_swapped(parsed1, parsed2) raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2 @@ -921,11 +941,11 @@ class QwenVLImageJudge: def _chat(self, model, processor, ref_pil, gen_pil, system_prompt, user_prompt, max_new_tokens, temperature, resolved_path, precision, keep_loaded, - report_dir, run_tag): + report_dir, run_tag, think=True): """General-VLM mode: not a judge — just runs your prompt over the image(s).""" images = [ref_pil] + ([gen_pil] if gen_pil is not None else []) text = _run_chat(model, processor, images, system_prompt, user_prompt, - max_new_tokens, temperature).strip() + max_new_tokens, temperature, think).strip() if not keep_loaded: _MODEL_CACHE.pop((resolved_path, precision), None) del model @@ -934,10 +954,11 @@ class QwenVLImageJudge: return (1.0, "{}", text, text, report_path) def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens, - temperature, resolved_path, precision, keep_loaded, report_dir, run_tag): + temperature, resolved_path, precision, keep_loaded, report_dir, run_tag, + think=True): """First pass: describe the reference image the generator must reproduce. Outputs the target spec (per-axis values) + a prompt-ready caption.""" - raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature) + raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature, think) parsed = _parse_json(raw) or {} if not keep_loaded: diff --git a/workflow/workflow_api.json b/workflow/workflow_api.json index 7e86cac..5d07205 100644 --- a/workflow/workflow_api.json +++ b/workflow/workflow_api.json @@ -68,7 +68,7 @@ "model_path": "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16", "precision": "bf16", "profile": "general", - "max_new_tokens": 2048, + "max_new_tokens": 3072, "temperature": 0.0, "swap_eval": true, "keep_loaded": true, diff --git a/workflow/workflow_describe_api.json b/workflow/workflow_describe_api.json index 2aa4cd1..1f6ba6c 100644 --- a/workflow/workflow_describe_api.json +++ b/workflow/workflow_describe_api.json @@ -12,7 +12,7 @@ "profile": "general", "model_path": "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16", "precision": "bf16", - "max_new_tokens": 2048, + "max_new_tokens": 3072, "temperature": 0.0, "swap_eval": false, "keep_loaded": true,