Add describe (first-pass) mode to the judge node
New mode on QwenVLImageJudge: 'describe' looks at the reference alone and returns a prompt-ready caption + per-axis target spec to seed the very first prompt (the generator has nothing to reproduce yet). 'compare' is the existing ref-vs-gen scoring. generated_image is now optional (required only for compare); shared generation refactored into _generate_from_messages; third output renamed diff_analysis -> analysis (mode-agnostic). agent_bridge gains --mode (describe needs no receptor/prompt); added workflow_describe_api.json. Docs updated with the first-pass bootstrap step. Fixed error-return arity to 5-tuple. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+125
-25
@@ -275,28 +275,14 @@ def _format_chatml_qwenvl(messages):
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
|
||||
"""One forward pass; returns the raw decoded string."""
|
||||
messages = [
|
||||
{"role": "system", "content": _build_system_prompt(axes)},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
|
||||
{"type": "image", "image": ref_pil},
|
||||
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
|
||||
{"type": "image", "image": gen_pil},
|
||||
{"type": "text", "text": "Now return the strict JSON judgement."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature):
|
||||
"""Template + forward pass for a chat-message list; returns the decoded string."""
|
||||
try:
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
except (ValueError, AttributeError):
|
||||
# Processor/tokenizer carries no chat template -> build ChatML by hand.
|
||||
text = _format_chatml_qwenvl(messages)
|
||||
inputs = processor(text=[text], images=[ref_pil, gen_pil], return_tensors="pt")
|
||||
inputs = processor(text=[text], images=images, return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
|
||||
gen_kwargs = dict(max_new_tokens=max_new_tokens)
|
||||
@@ -312,6 +298,60 @@ def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperat
|
||||
return decoded.strip()
|
||||
|
||||
|
||||
def _run_once(model, processor, ref_pil, gen_pil, axes, max_new_tokens, temperature):
|
||||
"""Compare pass: ref vs gen -> raw JSON judgement string."""
|
||||
messages = [
|
||||
{"role": "system", "content": _build_system_prompt(axes)},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "IMAGE 1 = REFERENCE (target):"},
|
||||
{"type": "image", "image": ref_pil},
|
||||
{"type": "text", "text": "IMAGE 2 = GENERATED candidate:"},
|
||||
{"type": "image", "image": gen_pil},
|
||||
{"type": "text", "text": "Now return the strict JSON judgement."},
|
||||
],
|
||||
},
|
||||
]
|
||||
return _generate_from_messages(model, processor, messages, [ref_pil, gen_pil],
|
||||
max_new_tokens, temperature)
|
||||
|
||||
|
||||
def _build_describe_prompt(axes: list[str]) -> str:
|
||||
axis_lines = "\n".join(f' "{a}": "<concrete value or n/a>",' for a in axes)
|
||||
return (
|
||||
"You are describing a REFERENCE image that an image generator must try to "
|
||||
"reproduce. Describe ONLY what you observe, concretely, in prompt-ready "
|
||||
"phrasing (the words a text-to-image prompt would use).\n\n"
|
||||
"Reply with STRICT JSON only, no prose, no markdown fences, exactly:\n"
|
||||
"{\n"
|
||||
' "caption": "<one detailed paragraph fully describing the image as a generation prompt>",\n'
|
||||
' "axes": {\n'
|
||||
f"{axis_lines}\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"Each axis value is a concrete description of that aspect of the image "
|
||||
"(or \"n/a\" if not present). The caption should be directly usable as a prompt."
|
||||
)
|
||||
|
||||
|
||||
def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature):
|
||||
"""Describe pass: reference only -> raw JSON {caption, axes} string."""
|
||||
messages = [
|
||||
{"role": "system", "content": _build_describe_prompt(axes)},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this reference image:"},
|
||||
{"type": "image", "image": ref_pil},
|
||||
{"type": "text", "text": "Return the strict JSON description."},
|
||||
],
|
||||
},
|
||||
]
|
||||
return _generate_from_messages(model, processor, messages, [ref_pil],
|
||||
max_new_tokens, temperature)
|
||||
|
||||
|
||||
def _parse_json(raw: str) -> dict | None:
|
||||
"""Best-effort: pull the first balanced JSON object out of the model output."""
|
||||
# Strip code fences if present.
|
||||
@@ -412,20 +452,48 @@ def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all,
|
||||
return run_path
|
||||
|
||||
|
||||
def _write_describe_report(report_dir, run_tag, caption, axes_spec, raw):
|
||||
"""Persist the first-pass description (target spec) for the agent to seed from."""
|
||||
base = _report_base_dir(report_dir)
|
||||
try:
|
||||
os.makedirs(base, exist_ok=True)
|
||||
except OSError as e:
|
||||
print(f"[QwenVLImageJudge] could not create report dir {base}: {e}")
|
||||
return ""
|
||||
payload = {
|
||||
"mode": "describe",
|
||||
"run_tag": run_tag,
|
||||
"caption": caption,
|
||||
"axes": axes_spec, # per-axis target values -> the agent's initial axis_state
|
||||
"raw": raw,
|
||||
}
|
||||
tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "describe"
|
||||
run_path = os.path.join(base, f"calib_{tag}.json")
|
||||
for path in (run_path, os.path.join(base, "latest.json")):
|
||||
try:
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
except OSError as e:
|
||||
print(f"[QwenVLImageJudge] failed writing report {path}: {e}")
|
||||
return run_path
|
||||
|
||||
|
||||
class QwenVLImageJudge:
|
||||
"""ComfyUI node: score how close a generated image is to a reference."""
|
||||
"""ComfyUI node: describe a reference, or score how close a generated image is to it."""
|
||||
|
||||
CATEGORY = "prompt_calibrator"
|
||||
FUNCTION = "judge"
|
||||
RETURN_TYPES = ("FLOAT", "STRING", "STRING", "STRING", "STRING")
|
||||
RETURN_NAMES = ("overall_score", "axis_scores_json", "diff_analysis", "raw", "report_path")
|
||||
RETURN_NAMES = ("overall_score", "axis_scores_json", "analysis", "raw", "report_path")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"reference_image": ("IMAGE",),
|
||||
"generated_image": ("IMAGE",),
|
||||
# describe = reference only -> target description (first pass, seeds the
|
||||
# initial prompt). compare = ref vs generated -> per-axis scoring.
|
||||
"mode": (["compare", "describe"], {"default": "compare"}),
|
||||
"model_path": ("STRING", {"default": DEFAULT_MODEL_PATH}),
|
||||
"precision": (["bf16", "fp16", "fp8", "nf4"], {"default": "bf16"}),
|
||||
"axes": ("STRING", {"default": DEFAULT_AXES, "multiline": True}),
|
||||
@@ -434,6 +502,7 @@ class QwenVLImageJudge:
|
||||
"swap_eval": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {
|
||||
"generated_image": ("IMAGE",), # required for compare, ignored for describe
|
||||
"keep_loaded": ("BOOLEAN", {"default": True}),
|
||||
"auto_download": ("BOOLEAN", {"default": True}),
|
||||
# The agent reads the analysis from these files after each queue.
|
||||
@@ -443,8 +512,9 @@ class QwenVLImageJudge:
|
||||
},
|
||||
}
|
||||
|
||||
def judge(self, reference_image, generated_image, model_path, precision, axes,
|
||||
max_new_tokens, temperature, swap_eval, keep_loaded=True, auto_download=True,
|
||||
def judge(self, reference_image, mode, model_path, precision, axes,
|
||||
max_new_tokens, temperature, swap_eval, generated_image=None,
|
||||
keep_loaded=True, auto_download=True,
|
||||
report_dir="", run_tag="", prompt_used=""):
|
||||
axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()]
|
||||
if not axis_list:
|
||||
@@ -455,13 +525,22 @@ class QwenVLImageJudge:
|
||||
except Exception as e: # missing model / download failure -> surface as score 0
|
||||
msg = str(e)
|
||||
print(msg)
|
||||
return (0.0, "{}", msg, msg)
|
||||
return (0.0, "{}", msg, msg, "")
|
||||
|
||||
ref_pil = _tensor_to_pil(reference_image)
|
||||
gen_pil = _tensor_to_pil(generated_image)
|
||||
|
||||
model, processor = _load_model(resolved_path, precision)
|
||||
|
||||
if mode == "describe":
|
||||
return self._describe(model, processor, ref_pil, axis_list, max_new_tokens,
|
||||
temperature, resolved_path, precision, keep_loaded,
|
||||
report_dir, run_tag)
|
||||
|
||||
if generated_image is None:
|
||||
msg = "[QwenVLImageJudge] compare mode needs generated_image (or set mode=describe)."
|
||||
print(msg)
|
||||
return (0.0, "{}", msg, msg, "")
|
||||
gen_pil = _tensor_to_pil(generated_image)
|
||||
|
||||
raw1 = _run_once(model, processor, ref_pil, gen_pil, axis_list, max_new_tokens, temperature)
|
||||
parsed1 = _parse_json(raw1) or {}
|
||||
|
||||
@@ -496,6 +575,27 @@ class QwenVLImageJudge:
|
||||
|
||||
return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path)
|
||||
|
||||
def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens,
|
||||
temperature, resolved_path, precision, keep_loaded, report_dir, run_tag):
|
||||
"""First pass: describe the reference image the generator must reproduce.
|
||||
Outputs the target spec (per-axis values) + a prompt-ready caption."""
|
||||
raw = _run_describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature)
|
||||
parsed = _parse_json(raw) or {}
|
||||
|
||||
if not keep_loaded:
|
||||
_MODEL_CACHE.pop((resolved_path, precision), None)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
caption = (parsed.get("caption") or "").strip()
|
||||
axes_spec = parsed.get("axes", {}) if isinstance(parsed.get("axes"), dict) else {}
|
||||
axis_scores = json.dumps(axes_spec, ensure_ascii=False, indent=2)
|
||||
analysis = caption if caption else "(no parseable description)"
|
||||
|
||||
report_path = _write_describe_report(report_dir, run_tag, caption, axes_spec, raw)
|
||||
# overall_score is n/a in describe mode; return 1.0 as a neutral placeholder.
|
||||
return (1.0, axis_scores, analysis, raw, report_path)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {"QwenVLImageJudge": QwenVLImageJudge}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"QwenVLImageJudge": "Qwen3-VL Image Judge (Calibrator)"}
|
||||
|
||||
Reference in New Issue
Block a user