Re-enable reasoning for accurate verdicts (no-think rubber-stamped 'match')

Disabling thinking made reasoning models mark everything 'match' even when ref/gen
clearly differ. Added an enable_thinking toggle (default ON) threaded through the
generation path; the prompt now allows reasoning then asks for the result, and
verdict_rule explicitly warns against lazy 'match'. _parse_json now scans for the
JSON object AFTER the reasoning prose (last balanced object with 'axes'), and the
markdown fallback already reads reasoned per-axis output. Default max_new_tokens
2048->3072 so verdicts don't get cut off.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-27 10:56:47 +02:00
parent fee136e98c
commit 22fd24b29e
4 changed files with 96 additions and 74 deletions
+2 -1
View File
@@ -38,7 +38,8 @@ can act on it.
| `precision` | bf16 / fp8 / nf4 | bf16 | **the quant** — applies to the selected model (VRAM table below) |
| `model_path` | STRING | "" (empty) | **manual override** of the dropdown — local dir, HF repo id, or alias (`8b`/`30b-a3b`/`3.5-9b`/`3.6-27b`/`3.6-35b`). Empty = use `model_select` |
| `axes` | STRING **input** | — | (socket) optional override of the profile's axis set; wire a text node or leave unconnected to use `profile` |
| `max_new_tokens` | INT | 2048 | raise it if a reasoning model (Qwen3.5/3.6) gets cut off before finishing |
| `max_new_tokens` | INT | 3072 | reasoning models (Qwen3.5/3.6) need room; raise it if the verdict gets cut off |
| `enable_thinking` | BOOL | true | let the model reason before judging. **Keep on for accurate verdicts** — turning it off makes reasoning models rubber-stamp `match`. Turn it off only when you need a faster, JSON-only pass |
| `temperature` | FLOAT | 0.0 | 0 = greedy/repeatable |
| `swap_eval` | BOOL | true | run twice with images swapped, average → cuts position bias |
| `keep_loaded` | BOOL | true | cache weights across loop iterations |
+92 -71
View File
@@ -336,26 +336,35 @@ def _axis_definition_block(axes: list[str]) -> str:
return "\n".join(f" - {a}: {AXIS_DEFS.get(a, 'as named')}" for a in axes)
def _build_system_prompt(axes: list[str], reference_description: str = "") -> str:
def _build_system_prompt(axes: list[str], reference_description: str = "", think: bool = True) -> str:
axis_lines = "\n".join(
f' "{a}": {{"verdict": "match|partial|mismatch", "ref": "<ref value>", "gen": "<generated image>"}},'
for a in axes)
verdict_rule = (
" - verdict: 'match' if ref and gen are the same; 'mismatch' if they are "
"opposite or clearly different (e.g. 'on top' vs 'on bottom', 'doggy' vs "
"'cowgirl', 'short' vs 'long', 'eyes closed' vs 'at camera'); 'partial' ONLY "
"for a genuine middle ground (same category, minor difference). Do NOT default "
"to 'partial' — if the values are identical use 'match', if clearly different "
"use 'mismatch'.\n")
tail = (
"Output ONLY the JSON object — no reasoning, no step-by-step analysis, no "
"markdown, no commentary. Do NOT think out loud. Your entire reply must start "
"with '{' and end with '}', exactly:\n"
"{\n"
' "axes": {\n'
f"{axis_lines}\n"
" }\n"
"}\n")
" - verdict: COMPARE ref vs gen carefully. 'match' only if they are the same; "
"'mismatch' if opposite or clearly different (e.g. 'on top' vs 'on bottom', "
"'short' vs 'long', 'brown' vs 'blonde', 'eyes closed' vs 'eyes open'); 'partial' "
"for same category with a clear difference. Do NOT lazily mark everything 'match' "
"— if the words differ, it is NOT a match.\n")
if think:
tail = (
"Examine each axis and decide its verdict by actually comparing ref and gen. "
"You may reason first. END your reply with the result for every axis as a JSON "
"object (or a per-axis list with ref/gen/verdict), schema:\n"
"{\n"
' "axes": {\n'
f"{axis_lines}\n"
" }\n"
"}\n")
else:
tail = (
"Output ONLY the JSON object — no prose, no markdown. Start with '{' end with "
"'}', exactly:\n"
"{\n"
' "axes": {\n'
f"{axis_lines}\n"
" }\n"
"}\n")
if reference_description.strip():
# Anchored mode: the reference is a fixed canonical description (text), only the
@@ -417,12 +426,14 @@ def _format_chatml_qwenvl(messages):
return "".join(parts)
def _apply_template(processor, messages):
"""apply_chat_template with thinking disabled (Qwen3.5/3.6 are reasoning models that
otherwise 'think out loud' in prose and never reach the JSON). Falls back gracefully."""
def _apply_template(processor, messages, think=True):
"""apply_chat_template, optionally toggling reasoning. Reasoning models (Qwen3.5/3.6)
judge verdicts far better WITH thinking on (off -> they rubber-stamp 'match'); the
markdown fallback parser reads the reasoned per-axis output. Set think=False for a
faster, JSON-only pass. Falls back to a hand-built ChatML prompt if no template."""
try:
return processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
messages, tokenize=False, add_generation_prompt=True, enable_thinking=think)
except TypeError:
pass # template doesn't accept enable_thinking
except (ValueError, AttributeError):
@@ -433,9 +444,9 @@ def _apply_template(processor, messages):
return _format_chatml_qwenvl(messages)
def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature, think=True):
"""Template + forward pass for a chat-message list; returns the decoded string."""
text = _apply_template(processor, messages)
text = _apply_template(processor, messages, think)
inputs = processor(text=[text], images=images, return_tensors="pt")
inputs = inputs.to(model.device)
@@ -454,10 +465,10 @@ def _generate_from_messages(model, processor, messages, images, max_new_tokens,
return decoded.strip()
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
"""Compare pass: ref vs gen -> raw JSON judgement string."""
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature, think=True):
"""Compare pass: ref vs gen -> raw judgement string (JSON or reasoned prose)."""
messages = [
{"role": "system", "content": _build_system_prompt(axes)},
{"role": "system", "content": _build_system_prompt(axes, think=think)},
{
"role": "user",
"content": [
@@ -465,29 +476,30 @@ def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperat
{"type": "image", "image": ref_pil},
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
{"type": "image", "image": gen_pil},
{"type": "text", "text": "Now return the strict JSON judgement."},
{"type": "text", "text": "Now judge every axis."},
],
},
]
return _generate_from_messages(model, processor, messages, [ref_pil, gen_pil],
max_new_tokens, temperature)
max_new_tokens, temperature, think)
def _run_anchored(model, processor, gen_pil, axes, max_new_tokens, temperature,
                  reference_description, think=True):
    """Anchored compare pass: one generated image judged against a fixed reference text.

    The canonical reference description is handed to the system-prompt builder, so the
    user turn carries only the generated image. ``think`` is forwarded both to the
    prompt builder and to the generation call. Returns the decoded model output.
    """
    system_turn = {
        "role": "system",
        "content": _build_system_prompt(axes, reference_description, think=think),
    }
    user_parts = [
        {"type": "text", "text": "GENERATED candidate image:"},
        {"type": "image", "image": gen_pil},
        {"type": "text", "text": "Compare it to the reference description and judge every axis."},
    ]
    user_turn = {"role": "user", "content": user_parts}
    # Only the generated image is passed as pixel input; the reference is text-only here.
    return _generate_from_messages(model, processor, [system_turn, user_turn], [gen_pil],
                                   max_new_tokens, temperature, think)
def _build_describe_prompt(axes: list[str]) -> str:
@@ -515,7 +527,7 @@ def _build_describe_prompt(axes: list[str]) -> str:
)
def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature):
def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature, think=True):
"""General VLM pass: your own system/user prompt over the image(s) -> raw text."""
content = [{"type": "image", "image": img} for img in images]
content.append({"type": "text", "text": user_prompt or "Describe this image."})
@@ -523,11 +535,11 @@ def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_toke
if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": content})
return _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature)
return _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature, think)
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
"""Describe pass: reference only -> raw JSON {caption, axes} string."""
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature, think=True):
"""Describe pass: reference only -> raw {description, axes} (JSON or reasoned prose)."""
messages = [
{"role": "system", "content": _build_describe_prompt(axes)},
{
@@ -535,38 +547,41 @@ def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
"content": [
{"type": "text", "text": "Describe this reference image:"},
{"type": "image", "image": ref_pil},
{"type": "text", "text": "Return the strict JSON description."},
{"type": "text", "text": "Give the full description."},
],
},
]
return _generate_from_messages(model, processor, messages, [ref_pil],
max_new_tokens, temperature)
max_new_tokens, temperature, think)
def _parse_json(raw: str) -> dict | None:
"""Best-effort: pull the first balanced JSON object out of the model output."""
# Strip code fences if present.
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
candidate = fenced.group(1) if fenced else None
if candidate is None:
start = raw.find("{")
if start == -1:
return None
depth = 0
for i in range(start, len(raw)):
if raw[i] == "{":
depth += 1
elif raw[i] == "}":
depth -= 1
if depth == 0:
candidate = raw[start:i + 1]
break
if candidate is None:
return None
try:
return json.loads(candidate)
except json.JSONDecodeError:
return None
"""Pull a JSON object out of the output. Reasoning models put the JSON AFTER prose,
so collect all balanced top-level objects and return the last one that parses and
contains 'axes' (or 'description') — falling back to the last that parses at all."""
candidates = []
depth = start = 0
for i, ch in enumerate(raw):
if ch == "{":
if depth == 0:
start = i
depth += 1
elif ch == "}" and depth > 0:
depth -= 1
if depth == 0:
candidates.append(raw[start:i + 1])
best = None
for cand in candidates:
try:
obj = json.loads(cand)
except json.JSONDecodeError:
continue
if isinstance(obj, dict):
best = obj
if "axes" in obj or "description" in obj:
# keep scanning; prefer the LAST such object (final answer)
best = obj
return best
def _parse_markdown_verdicts(raw: str, axes: list[str]) -> dict:
@@ -795,9 +810,12 @@ class QwenVLImageJudge:
{"default": list(MODEL_PRESETS.keys())[0]}),
"model_path": ("STRING", {"default": ""}), # manual override (local dir / HF repo / alias)
"precision": (["bf16", "fp8", "nf4"], {"default": "bf16"}),
"max_new_tokens": ("INT", {"default": 2048, "min": 64, "max": 8192}),
"max_new_tokens": ("INT", {"default": 3072, "min": 64, "max": 8192}),
"temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}),
"swap_eval": ("BOOLEAN", {"default": True}),
# Reasoning models (Qwen3.5/3.6) judge verdicts FAR better with thinking on
# (off -> they rubber-stamp 'match'). Costs more tokens; raise max_new_tokens.
"enable_thinking": ("BOOLEAN", {"default": True}),
"keep_loaded": ("BOOLEAN", {"default": True}),
"auto_download": ("BOOLEAN", {"default": True}),
# Small config values stay as typeable fields.
@@ -820,7 +838,7 @@ class QwenVLImageJudge:
def judge(self, reference_image, mode, model_path, precision,
max_new_tokens, temperature, swap_eval, profile="general",
model_select=MANUAL_CHOICE, generated_image=None,
enable_thinking=True, model_select=MANUAL_CHOICE, generated_image=None,
keep_loaded=True, auto_download=True,
report_dir="", run_tag="", axes="", reference_description="",
system_prompt="", user_prompt="Describe this image."):
@@ -862,12 +880,12 @@ class QwenVLImageJudge:
gen_pil = _tensor_to_pil(generated_image) if generated_image is not None else None
return self._chat(model, processor, ref_pil, gen_pil, system_prompt, user_prompt,
max_new_tokens, temperature, resolved_path, eff_precision,
keep_loaded, report_dir, run_tag)
keep_loaded, report_dir, run_tag, enable_thinking)
if mode == "describe":
return self._describe(model, processor, ref_pil, axis_list, max_new_tokens,
temperature, resolved_path, eff_precision, keep_loaded,
report_dir, run_tag)
report_dir, run_tag, enable_thinking)
if generated_image is None:
msg = "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe)."
@@ -879,16 +897,18 @@ class QwenVLImageJudge:
# Anchored: fixed canonical reference text + one generated image. No swap
# (single image), and the reference side stays identical across iterations.
raw_all = _run_anchored(model, processor, gen_pil, axis_list, max_new_tokens,
temperature, reference_description)
temperature, reference_description, enable_thinking)
merged = _parse_axes(raw_all, axis_list)
else:
raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens,
temperature, enable_thinking)
parsed1 = _parse_axes(raw1, axis_list)
raw_all = raw1
merged = parsed1
if swap_eval:
# Swap which image is called REFERENCE to average out position bias.
raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature)
raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens,
temperature, enable_thinking)
parsed2 = _parse_axes(raw2, axis_list)
merged = _merge_swapped(parsed1, parsed2)
raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2
@@ -921,11 +941,11 @@ class QwenVLImageJudge:
def _chat(self, model, processor, ref_pil, gen_pil, system_prompt, user_prompt,
max_new_tokens, temperature, resolved_path, precision, keep_loaded,
report_dir, run_tag):
report_dir, run_tag, think=True):
"""General-VLM mode: not a judge — just runs your prompt over the image(s)."""
images = [ref_pil] + ([gen_pil] if gen_pil is not None else [])
text = _run_chat(model, processor, images, system_prompt, user_prompt,
max_new_tokens, temperature).strip()
max_new_tokens, temperature, think).strip()
if not keep_loaded:
_MODEL_CACHE.pop((resolved_path, precision), None)
del model
@@ -934,10 +954,11 @@ class QwenVLImageJudge:
return (1.0, "{}", text, text, report_path)
def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens,
temperature, resolved_path, precision, keep_loaded, report_dir, run_tag):
temperature, resolved_path, precision, keep_loaded, report_dir, run_tag,
think=True):
"""First pass: describe the reference image the generator must reproduce.
Outputs the target spec (per-axis values) + a prompt-ready caption."""
raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature)
raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature, think)
parsed = _parse_json(raw) or {}
if not keep_loaded:
+1 -1
View File
@@ -68,7 +68,7 @@
"model_path": "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16",
"precision": "bf16",
"profile": "general",
"max_new_tokens": 2048,
"max_new_tokens": 3072,
"temperature": 0.0,
"swap_eval": true,
"keep_loaded": true,
+1 -1
View File
@@ -12,7 +12,7 @@
"profile": "general",
"model_path": "/media/p5/qwen3vl_4b_abliterated_comfy_convert/hf_bf16",
"precision": "bf16",
"max_new_tokens": 2048,
"max_new_tokens": 3072,
"temperature": 0.0,
"swap_eval": false,
"keep_loaded": true,