diff --git a/README.md b/README.md
index 8f23079..f077520 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ can act on it.
 | `precision` | bf16 / fp8 / nf4 | bf16 | **the quant** — applies to the selected model (VRAM table below) |
 | `model_path` | STRING | "" (empty) | **manual override** of the dropdown — local dir, HF repo id, or alias (`8b`/`30b-a3b`/`3.5-9b`/`3.6-27b`/`3.6-35b`). Empty = use `model_select` |
 | `axes` | STRING | "" (empty) | **override** the profile's axis set with a custom comma/newline list; empty = use `profile` |
-| `max_new_tokens` | INT | 1024 | |
+| `max_new_tokens` | INT | 2048 | raise it if a reasoning model (Qwen3.5/3.6) gets cut off before finishing |
 | `temperature` | FLOAT | 0.0 | 0 = greedy/repeatable |
 | `swap_eval` | BOOL | true | run twice with images swapped, average → cuts position bias |
 | `keep_loaded` | BOOL | true | cache weights across loop iterations |
diff --git a/nodes/qwen_judge.py b/nodes/qwen_judge.py
index a6ec95f..965027b 100644
--- a/nodes/qwen_judge.py
+++ b/nodes/qwen_judge.py
@@ -348,7 +348,9 @@ def _build_system_prompt(axes: list[str], reference_description: str = "") -> st
             "to 'partial' — if the values are identical use 'match', if clearly different "
             "use 'mismatch'.\n")
     tail = (
-        "Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
+        "Output ONLY the JSON object — no reasoning, no step-by-step analysis, no "
+        "markdown, no commentary. Do NOT think out loud. Your entire reply must start "
+        "with '{' and end with '}', exactly:\n"
         "{\n"
         ' "axes": {\n'
         f"{axis_lines}\n"
@@ -415,13 +417,25 @@ def _format_chatml_qwenvl(messages):
     return "".join(parts)
 
 
+def _apply_template(processor, messages):
+    """apply_chat_template with thinking disabled (Qwen3.5/3.6 are reasoning models that
+    otherwise 'think out loud' in prose and never reach the JSON). Falls back gracefully."""
+    try:
+        return processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
+    except TypeError:
+        pass  # template doesn't accept enable_thinking
+    except (ValueError, AttributeError):
+        return _format_chatml_qwenvl(messages)
+    try:
+        return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    except (ValueError, AttributeError):
+        return _format_chatml_qwenvl(messages)
+
+
 def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
     """Template + forward pass for a chat-message list; returns the decoded string."""
-    try:
-        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    except (ValueError, AttributeError):
-        # Processor/tokenizer carries no chat template -> build ChatML by hand.
-        text = _format_chatml_qwenvl(messages)
+    text = _apply_template(processor, messages)
     inputs = processor(text=[text], images=images, return_tensors="pt")
     inputs = inputs.to(model.device)
 
@@ -435,6 +449,8 @@ def _generate_from_messages(model, processor, messages, images, max_new_tokens,
         out = model.generate(**inputs, **gen_kwargs)
     trimmed = out[:, inputs.input_ids.shape[1]:]
     decoded = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
+    # Strip any <think>...</think> block a reasoning model may still emit.
+    decoded = re.sub(r"<think>.*?</think>", "", decoded, flags=re.DOTALL)
     return decoded.strip()
 
 
@@ -486,7 +502,8 @@ def _build_describe_prompt(axes: list[str]) -> str:
         "phrasing (the words a text-to-image prompt would use).\n\n"
         "Axes and exactly what each one means:\n"
         f"{_axis_definition_block(axes)}\n\n"
-        "Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
+        "Output ONLY the JSON object — no reasoning, no analysis, no markdown, no "
+        "commentary. Do NOT think out loud. Start with '{' and end with '}', exactly:\n"
         "{\n"
         ' "description": "",\n'
         ' "axes": {\n'
@@ -552,6 +569,49 @@ def _parse_json(raw: str) -> dict | None:
     return None
 
 
+def _parse_markdown_verdicts(raw: str, axes: list[str]) -> dict:
+    """Fallback for reasoning models that emit prose instead of JSON. Reasoning models
+    reliably write a block per axis like:
+        **hair:**
+        - Ref: short, curly, brown
+        - Gen: long, straight, blonde
+        - Verdict: mismatch
+    Extract {verdict, ref, gen} per axis from that. Returns {} if nothing parseable."""
+    def clean(m):
+        return m.group(1).strip().strip("*").strip(" .") if m else ""
+
+    # A line holding only a (decorated) axis name opens that axis's block; use it to cap
+    # the previous block so an axis never inherits a neighbour's verdict.
+    names = "|".join(re.escape(a) for a in axes)
+    block_start = re.compile(rf"(?im)^[\s\d.>*\-]*\**\s*(?:{names})\s*\**\s*:?\s*\**\s*$")
+    out = {}
+    for ax in axes:
+        # (?<!\w) stops 'hair' from matching inside 'chair:'.
+        m = re.search(rf"(?im)^[\s\d.>*\-]*\**\s*{re.escape(ax)}\s*\**\s*:?\s*$"
+                      rf"|\**\s*(?<!\w){re.escape(ax)}\s*\**\s*:", raw)
+        if not m:
+            continue
+        seg = raw[m.end(): m.end() + 500]
+        nxt = block_start.search(seg)
+        if nxt:
+            seg = seg[:nxt.start()]
+        vd = re.search(r"(?i)verdict[\s*:>-]*\**\s*(match|partial|mismatch)", seg)
+        if not vd:
+            continue
+        ref = re.search(r"(?im)^\W*ref[a-z]*\W*[:\-]\s*\**\s*(.+?)\s*$", seg)
+        gen = re.search(r"(?im)^\W*gen[a-z]*\W*[:\-]\s*\**\s*(.+?)\s*$", seg)
+        out[ax] = {"verdict": vd.group(1).lower(), "ref": clean(ref), "gen": clean(gen)}
+    return {"axes": out} if out else {}
+
+
+def _parse_axes(raw: str, axes: list[str]) -> dict:
+    """JSON first; if the model emitted prose instead, fall back to markdown verdicts."""
+    j = _parse_json(raw)
+    if j and isinstance(j.get("axes"), dict) and j["axes"]:
+        return j
+    return _parse_markdown_verdicts(raw, axes)
+
+
 _VERDICT_ORDINAL = {"match": 1.0, "partial": 0.5, "mismatch": 0.0}
 
 
@@ -746,7 +806,7 @@ class QwenVLImageJudge:
                 "model_path": ("STRING", {"default": ""}),  # manual override (local dir / HF repo / alias)
                 "precision": (["bf16", "fp8", "nf4"], {"default": "bf16"}),
                 "axes": ("STRING", {"default": "", "multiline": True}),
-                "max_new_tokens": ("INT", {"default": 1024, "min": 64, "max": 4096}),
+                "max_new_tokens": ("INT", {"default": 2048, "min": 64, "max": 8192}),
                 "temperature": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.5, "step": 0.05}),
                 "swap_eval": ("BOOLEAN", {"default": True}),
                 "keep_loaded": ("BOOLEAN", {"default": True}),
@@ -830,16 +890,16 @@ class QwenVLImageJudge:
             # (single image), and the reference side stays identical across iterations.
             raw_all = _run_anchored(model, processor, gen_pil, axis_list, max_new_tokens,
                                     temperature, reference_description)
-            merged = _parse_json(raw_all) or {}
+            merged = _parse_axes(raw_all, axis_list)
         else:
             raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
-            parsed1 = _parse_json(raw1) or {}
+            parsed1 = _parse_axes(raw1, axis_list)
             raw_all = raw1
             merged = parsed1
             if swap_eval:
                 # Swap which image is called REFERENCE to average out position bias.
                 raw2 = _run_once(model, processor, gen_pil, ref_pil, axis_list, max_new_tokens, temperature)
-                parsed2 = _parse_json(raw2) or {}
+                parsed2 = _parse_axes(raw2, axis_list)
                 merged = _merge_swapped(parsed1, parsed2)
                 raw_all = raw1 + "\n--- SWAPPED ---\n" + raw2
 
@@ -895,7 +955,8 @@ class QwenVLImageJudge:
                 del model
                 torch.cuda.empty_cache()
 
-        caption = (parsed.get("description") or parsed.get("caption") or "").strip()
+        # Fall back to the raw text as the caption if the model emitted prose, not JSON.
+        caption = (parsed.get("description") or parsed.get("caption") or raw).strip()
         axes_spec = parsed.get("axes", {}) if isinstance(parsed.get("axes"), dict) else {}
         axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2)
         # The canonical reference text the compare pass will anchor on: paragraph + axes.