diff --git a/README.md b/README.md index 7d8aba9..1b304b1 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ can act on it. | name | type | default | notes | |---|---|---|---| | `reference_image` | IMAGE | — | the target | -| `mode` | compare / describe | compare | `describe` = first pass over the reference only → caption + target spec (seeds the prompt). `compare` = score ref vs generated | +| `mode` | compare / describe / chat | compare | `compare` = score ref vs generated. `describe` = first pass over the reference → caption + target spec. `chat` = **general VLM**: your `system_prompt` + `user_prompt` over the image(s) → raw text | | `profile` | general / oral / penetration / handjob / solo | general | **analysis profile** — act-specialized axis set; the act-critical axes are distance/proximity-aware (e.g. `mouth_genital_distance`) so magnitude isn't hidden behind a coarse label | | `generated_image` | IMAGE (optional) | — | the candidate to score (required for `compare`, ignored for `describe`) | | `model_select` | dropdown (model name) | 4B local | **which judge** (transformers/safetensors, auto-downloaded): Qwen3-VL 4B/8B/30B-A3B, **Qwen3.5-9B**, **Qwen3.6-27B/35B-A3B** (newer, natively multimodal). Param size shown in the label | @@ -43,19 +43,27 @@ can act on it. | `swap_eval` | BOOL | true | run twice with images swapped, average → cuts position bias | | `keep_loaded` | BOOL | true | cache weights across loop iterations | | `auto_download` | BOOL | true | if `model_path` is a repo id/alias and not local, fetch it from HF into `models/prompt_generator/` | +| `system_prompt` | STRING | "" | **chat mode**: your system prompt | +| `user_prompt` | STRING | "Describe this image." | **chat mode**: your instruction over the image(s) | **Auto-download:** set `model_path` to `30b-a3b` (alias) or any `org/name` repo id and leave `auto_download` on — the node snapshot-downloads it on first run (into ComfyUI's `models/prompt_generator/`) and reuses the local copy afterward. Local paths and the default skip download entirely. +**General VLM (chat mode):** set `mode=chat` and the node becomes a plain vision-language +node — feed an image (and optionally a second), write your own `system_prompt`/`user_prompt`, +and read the model's text from the `analysis` output. Reuses the same model dropdown, quant, +and auto-download as the judge, so it's a one-node abliterated VLM for captioning, tagging, +Q&A, prompt-from-image, etc. (CLI: `agent_bridge.py --mode chat --user-prompt "..."`). + **Outputs** | name | type | use | |---|---|---| | `overall_score` | FLOAT 0..1 | compare: mean verdict (computed here, not by the model). describe: `1.0` placeholder | | `axis_scores_json` | STRING (JSON) | compare: per-axis `{verdict, ref, gen}` (verdict = match/partial/mismatch). describe: `{axis: value}` | -| `analysis` | STRING | compare: header (`overall, N mismatches`) + axes worst-first (`VERDICT ref:[…] gen:[…]`). describe: the `caption` | +| `analysis` | STRING | compare: header (`overall, N mismatches`) + axes worst-first (`VERDICT ref:[…] gen:[…]`). describe: the `caption`. chat: the model's response | | `raw` | STRING | raw model output (both passes if `swap_eval`) | | `report_path` | STRING | path to the written `calib_.json` (carries `mismatch_count`) | diff --git a/agent_bridge.py b/agent_bridge.py index 6b84646..a2c74a0 100644 --- a/agent_bridge.py +++ b/agent_bridge.py @@ -49,7 +49,7 @@ def _http_json(url: str, payload: dict | None = None, timeout: int = 30): def _inject(graph: dict, prompt: str, negative: str, seed: int, run_tag: str, mode: str, reference_description: str = "", profile: str = "", model_select: str = "", - model_path: str = ""): + model_path: str = "", system_prompt: str = "", user_prompt: str = ""): """Set the receptor's prompt/seed and the judge's mode/run_tag in-place. compare mode needs a receptor (to inject the prompt). describe mode is the first @@ -76,6 +76,10 @@ def _inject(graph: dict, prompt: str, negative: str, seed: int, run_tag: str, mo inputs["model_select"] = model_select if model_path: inputs["model_path"] = model_path + if system_prompt: + inputs["system_prompt"] = system_prompt + if user_prompt: + inputs["user_prompt"] = user_prompt if mode == "compare" and not found_receptor: raise SystemExit( f"[agent_bridge] no '{RECEPTOR_CLASS}' node in the workflow — add the " @@ -116,8 +120,10 @@ def main(argv=None): ap = argparse.ArgumentParser(description="Drive one ComfyUI calibration iteration.") ap.add_argument("--server", default="127.0.0.1:8188") ap.add_argument("--workflow", required=True, help="API-format workflow JSON") - ap.add_argument("--mode", choices=["compare", "describe"], default="compare", - help="describe = first pass over the reference only (no prompt needed)") + ap.add_argument("--mode", choices=["compare", "describe", "chat"], default="compare", + help="describe = first pass over the reference; chat = general VLM with your prompts") + ap.add_argument("--system-prompt", default="", help="chat mode: system prompt") + ap.add_argument("--user-prompt", default="", help="chat mode: user prompt over the image(s)") ap.add_argument("--prompt", default="", help="generation prompt (required for compare)") ap.add_argument("--negative", default="") ap.add_argument("--seed", type=int, default=0) @@ -150,7 +156,7 @@ def main(argv=None): graph = json.load(f) _inject(graph, args.prompt, args.negative, args.seed, args.run_tag, args.mode, ref_desc, - args.profile, args.model_select, args.model_path) + args.profile, args.model_select, args.model_path, args.system_prompt, args.user_prompt) client_id = uuid.uuid4().hex try: diff --git a/nodes/qwen_judge.py b/nodes/qwen_judge.py index 38dc126..17f1df4 100644 --- a/nodes/qwen_judge.py +++ b/nodes/qwen_judge.py @@ -498,6 +498,17 @@ def _build_describe_prompt(axes: list[str]) -> str: ) +def _run_chat(model, processor, images, system_prompt, user_prompt, max_new_tokens, temperature): + """General VLM pass: your own system/user prompt over the image(s) -> raw text.""" + content = [{"type": "image", "image": img} for img in images] + content.append({"type": "text", "text": user_prompt or "Describe this image."}) + messages = [] + if system_prompt.strip(): + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": content}) + return _generate_from_messages(model, processor, messages, images, max_new_tokens, temperature) + + def _run_describe(model, processor, ref_pil, axes, max_new_tokens, temperature): """Describe pass: reference only -> raw JSON {caption, axes} string.""" messages = [ @@ -651,6 +662,27 @@ def _write_report(report_dir, run_tag, overall, merged, diff_analysis, raw_all, return run_path +def _write_chat_report(report_dir, run_tag, system_prompt, user_prompt, response): + """Persist a general-VLM (chat) response so the agent/loop can read it.""" + base = _report_base_dir(report_dir) + try: + os.makedirs(base, exist_ok=True) + except OSError as e: + print(f"[QwenVLImageJudge] could not create report dir {base}: {e}") + return "" + payload = {"mode": "chat", "run_tag": run_tag, "system_prompt": system_prompt, + "user_prompt": user_prompt, "response": response} + tag = re.sub(r"[^A-Za-z0-9._-]", "_", run_tag) if run_tag else "chat" + run_path = os.path.join(base, f"calib_{tag}.json") + for path in (run_path, os.path.join(base, "latest.json")): + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + except OSError as e: + print(f"[QwenVLImageJudge] failed writing report {path}: {e}") + return run_path + + def _format_canonical_reference(caption: str, axes_spec: dict) -> str: """One canonical reference description = the paragraph + the per-axis target values. The compare pass anchors on this so the reference side stays consistent @@ -703,9 +735,10 @@ class QwenVLImageJudge: return { "required": { "reference_image": ("IMAGE",), - # describe = reference only -> target description (first pass, seeds the - # initial prompt). compare = ref vs generated -> per-axis scoring. - "mode": (["compare", "describe"], {"default": "compare"}), + # compare = ref vs generated -> per-axis scoring. describe = reference only + # -> target description (first pass). chat = general VLM: your own + # system_prompt + user_prompt over the image(s) -> raw text. + "mode": (["compare", "describe", "chat"], {"default": "compare"}), # Analysis profile: act-specialized axis set (distance-aware where it # matters). `axes` below overrides it when non-empty. "profile": (list(PROFILES.keys()), {"default": "general"}), @@ -730,6 +763,9 @@ class QwenVLImageJudge: # compare: canonical reference text (from describe). When set, compare # anchors on it instead of re-reading the reference image each time. "reference_description": ("STRING", {"default": "", "multiline": True}), + # chat mode: use the node as a general VLM with your own prompts. + "system_prompt": ("STRING", {"default": "", "multiline": True}), + "user_prompt": ("STRING", {"default": "Describe this image.", "multiline": True}), }, } @@ -737,7 +773,8 @@ class QwenVLImageJudge: max_new_tokens, temperature, swap_eval, profile="general", model_select=MANUAL_CHOICE, generated_image=None, keep_loaded=True, auto_download=True, - report_dir="", run_tag="", prompt_used="", reference_description=""): + report_dir="", run_tag="", prompt_used="", reference_description="", + system_prompt="", user_prompt="Describe this image."): # `axes` overrides the profile when provided; otherwise use the profile's axis set. axis_list = [a.strip() for a in re.split(r"[,\n]", axes) if a.strip()] if not axis_list: @@ -772,6 +809,12 @@ class QwenVLImageJudge: ref_pil = _tensor_to_pil(reference_image) model, processor = _load_model(resolved_path, eff_precision) + if mode == "chat": + gen_pil = _tensor_to_pil(generated_image) if generated_image is not None else None + return self._chat(model, processor, ref_pil, gen_pil, system_prompt, user_prompt, + max_new_tokens, temperature, resolved_path, eff_precision, + keep_loaded, report_dir, run_tag) + if mode == "describe": return self._describe(model, processor, ref_pil, axis_list, max_new_tokens, temperature, resolved_path, eff_precision, keep_loaded, @@ -827,6 +870,20 @@ class QwenVLImageJudge: return (round(overall, 4), axis_scores, diff_analysis, raw_all, report_path) + def _chat(self, model, processor, ref_pil, gen_pil, system_prompt, user_prompt, + max_new_tokens, temperature, resolved_path, precision, keep_loaded, + report_dir, run_tag): + """General-VLM mode: not a judge — just runs your prompt over the image(s).""" + images = [ref_pil] + ([gen_pil] if gen_pil is not None else []) + text = _run_chat(model, processor, images, system_prompt, user_prompt, + max_new_tokens, temperature).strip() + if not keep_loaded: + _MODEL_CACHE.pop((resolved_path, precision), None) + del model + torch.cuda.empty_cache() + report_path = _write_chat_report(report_dir, run_tag, system_prompt, user_prompt, text) + return (1.0, "{}", text, text, report_path) + def _describe(self, model, processor, ref_pil, axis_list, max_new_tokens, temperature, resolved_path, precision, keep_loaded, report_dir, run_tag): """First pass: describe the reference image the generator must reproduce.