Add describe (first-pass) mode to the judge node

New mode on QwenVLImageJudge: 'describe' looks at the reference alone and returns
a prompt-ready caption + per-axis target spec to seed the very first prompt (the
generator has nothing to reproduce yet). 'compare' is the existing ref-vs-gen
scoring. generated_image is now optional (required only for compare); shared
generation refactored into _generate_from_messages; third output renamed
diff_analysis -> analysis (mode-agnostic). agent_bridge gains --mode (describe
needs no receptor/prompt); added workflow_describe_api.json. Docs updated with the
first-pass bootstrap step. Fixed error-return arity to 5-tuple.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-26 23:04:09 +02:00
parent 959ec70065
commit c7ef756a71
6 changed files with 211 additions and 47 deletions
+125 -25
View File
@@ -275,28 +275,14 @@ def _format_chatml_qwenvl(messages):
return "".join(parts)
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
"""One forward pass; returns the raw decoded string."""
messages = [
{"role": "system", "content": _build_system_prompt(axes)},
{
"role": "user",
"content": [
{"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
{"type": "image", "image": ref_pil},
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
{"type": "image", "image": gen_pil},
{"type": "text", "text": "Now return the strict JSON judgement."},
],
},
]
def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
"""Template + forward pass for a chat-message list; returns the decoded string."""
try:
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except (ValueError, AttributeError):
# Processor/tokenizer carries no chat template -> build ChatML by hand.
text = _format_chatml_qwenvl(messages)
inputs = processor(text=[text], images=[ref_pil, gen_pil], return_tensors="pt")
inputs = processor(text=[text], images=images, return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = dict(max_new_tokens=max_new_tokens)
@@ -312,6 +298,60 @@ def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperat
return decoded.strip()
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
"""Compare pass: ref vs gen -> raw JSON judgement string."""
messages = [
{"role": "system", "content": _build_system_prompt(axes)},
{
"role": "user",
"content": [
{"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
{"type": "image", "image": ref_pil},
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
{"type": "image", "image": gen_pil},
{"type": "text", "text": "Now return the strict JSON judgement."},
],
},
]
return _generate_from_messages(model, processor, messages, [ref_pil, gen_pil],
max_new_tokens, temperature)
def _build_describe_prompt(axes: list[str]) -> str:
axis_lines = "\n".join(f' "{a}": "<concrete value or n/a>",' for a in axes)
return (
"You are describing a REFERENCE image that an image generator must try to "
"reproduce. Describe ONLY what you observe, concretely, in prompt-ready "
"phrasing (the words a text-to-image prompt would use).\n\n"
"Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
"{\n"
' "caption": "<one detailed paragraph fully describing the image as a generation prompt>",\n'
' "axes": {\n'
f"{axis_lines}\n"
" }\n"
"}\n"
"Each axis value is a concrete description of that aspect of the image "
"(or \"n/a\" if not present). The caption should be directly usable as a prompt."
)
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
"""Describe pass: reference only -> raw JSON {caption, axes} string."""
messages = [
{"role": "system", "content": _build_describe_prompt(axes)},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this reference image:"},
{"type": "image", "image": ref_pil},
{"type": "text", "text": "Return the strict JSON description."},
],
},
]
return _generate_from_messages(model, processor, messages, [ref_pil],
max_new_tokens, temperature)
def _parse_json(raw: str) -> dict | None:
"""Best-effort: pull the first balanced JSON object out of the model output."""
# Strip code fences if present.
@@ -412,20 +452,48 @@ def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all,
return run_path
def _write_describe_report(report_dir, run_tag, caption, axes_spec, raw):
"""Persist the first-pass description (target spec) for the agent to seed from."""
base = _report_base_dir(report_dir)
try:
os.makedirs(base, exist_ok=True)
except OSError as e:
print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
return ""
payload = {
"mode": "describe",
"run_tag": run_tag,
"caption": caption,
"axes": axes_spec, # per-axis target values -> the agent's initial axis_state
"raw": raw,
}
tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "describe"
run_path = os.path.join(base, f"calib_{tag}.json")
for path in (run_path, os.path.join(base, "latest.json")):
try:
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
except OSError as e:
print(f"[QwenVLImageJudge] failed writing report {path}: {e}")
return run_path
class QwenVLImageJudge:
"""ComfyUI node: score how close a generated image is to a reference."""
"""ComfyUI node: describe a reference, or score how close a generated image is to it."""
CATEGORY = "prompt_calibrator"
FUNCTION = "judge"
RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING")
RETURN_NAMES = ("overall_score", "axis_scores_json", "diff_analysis", "raw", "report_path")
RETURN_NAMES = ("overall_score", "axis_scores_json", "analysis", "raw", "report_path")
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"reference_image": ("IMAGE",),
"generated_image": ("IMAGE",),
# describe = reference only -> target description (first pass, seeds the
# initial prompt). compare = ref vs generated -> per-axis scoring.
"mode": (["compare", "describe"], {"default": "compare"}),
"model_path": ("STRING", {"default": DEFAULT_MODEL_PATH}),
"precision": (["bf16", "fp16", "fp8", "nf4"], {"default": "bf16"}),
"axes": ("STRING", {"default": DEFAULT_AXES, "multiline": True}),
@@ -434,6 +502,7 @@ class QwenVLImageJudge:
"swap_eval": ("BOOLEAN", {"default": True}),
},
"optional": {
"generated_image": ("IMAGE",), # required for compare, ignored for describe
"keep_loaded": ("BOOLEAN", {"default": True}),
"auto_download": ("BOOLEAN", {"default": True}),
# The agent reads the analysis from these files after each queue.
@@ -443,8 +512,9 @@ class QwenVLImageJudge:
},
}
def judge(self, reference_image, generated_image, model_path, precision, axes,
max_new_tokens, temperature, swap_eval, keep_loaded=True, auto_download=True,
def judge(self, reference_image, mode, model_path, precision, axes,
max_new_tokens, temperature, swap_eval, generated_image=None,
keep_loaded=True, auto_download=True,
report_dir="", run_tag="", prompt_used=""):
axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()]
if not axis_list:
@@ -455,13 +525,22 @@ class QwenVLImageJudge:
except Exception as e: # missing model / download failure -> surface as score 0
msg = str(e)
print(msg)
return (0.0, "{}", msg, msg)
return (0.0, "{}", msg, msg, "")
ref_pil = _tensor_to_pil(reference_image)
gen_pil = _tensor_to_pil(generated_image)
model, processor = _load_model(resolved_path, precision)
if mode == "describe":
return self._describe(model, processor, ref_pil, axis_list, max_new_tokens,
temperature, resolved_path, precision, keep_loaded,
report_dir, run_tag)
if generated_image is None:
msg = "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe)."
print(msg)
return (0.0, "{}", msg, msg, "")
gen_pil = _tensor_to_pil(generated_image)
raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
parsed1 = _parse_json(raw1) or {}
@@ -496,6 +575,27 @@ class QwenVLImageJudge:
return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path)
def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens,
temperature, resolved_path, precision, keep_loaded, report_dir, run_tag):
"""First pass: describe the reference image the generator must reproduce.
Outputs the target spec (per-axis values) + a prompt-ready caption."""
raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature)
parsed = _parse_json(raw) or {}
if not keep_loaded:
_MODEL_CACHE.pop((resolved_path, precision), None)
del model
torch.cuda.empty_cache()
caption = (parsed.get("caption") or "").strip()
axes_spec = parsed.get("axes", {}) if isinstance(parsed.get("axes"), dict) else {}
axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2)
analysis = caption if caption else "(no parseable description)"
report_path = _write_describe_report(report_dir, run_tag, caption, axes_spec, raw)
# overall_score is n/a in describe mode; return 1.0 as a neutral placeholder.
return (1.0, axis_scores, analysis, raw, report_path)
NODE_CLASS_MAPPINGS = {"QwenVLImageJudge": QwenVLImageJudge}
NODE_DISPLAY_NAME_MAPPINGS = {"QwenVLImageJudge": "Qwen3-VL Image Judge (Calibrator)"}