From 738190aa9cefc724fea4cfb735e2b5e111b30b97 Mon Sep 17 00:00:00 2001 From: ethanfel Date: Wed, 31 Dec 2025 14:30:05 +0100 Subject: [PATCH] Refactor JSONLoader and add batch loading classes --- json_loader.py | 152 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 138 insertions(+), 14 deletions(-) diff --git a/json_loader.py b/json_loader.py index 651c8d8..c5aa616 100644 --- a/json_loader.py +++ b/json_loader.py @@ -1,7 +1,7 @@ import json import os -# --- Helper --- +# --- Shared Helper --- def read_json_data(json_path): if not os.path.exists(json_path): print(f"[JSON Loader] Warning: File not found at {json_path}") @@ -14,7 +14,7 @@ def read_json_data(json_path): return {} # ========================================== -# 1. DEDICATED LORA NODE +# 1. DEDICATED LORA NODE (Existing) # ========================================== class JSONLoaderLoRA: @classmethod @@ -42,19 +42,18 @@ class JSONLoaderLoRA: ) # ========================================== -# 2. MAIN NODES +# 2. MAIN NODES (Existing) # ========================================== -# --- Node A: Standard (I2V) --- class JSONLoaderStandard: @classmethod def INPUT_TYPES(s): return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} RETURN_TYPES = ( - "STRING", "STRING", "STRING", "STRING", # GenP, GenN, CurP, CurN - "STRING", "FLOAT", "INT", # Cam, FLF, Seed - "STRING", "STRING", "STRING" # Paths + "STRING", "STRING", "STRING", "STRING", + "STRING", "FLOAT", "INT", + "STRING", "STRING", "STRING" ) RETURN_NAMES = ( "general_prompt", "general_negative", "current_prompt", "negative", @@ -86,17 +85,16 @@ class JSONLoaderStandard: str(data.get("flf image path", "")) ) -# --- Node B: VACE Full --- class JSONLoaderVACE: @classmethod def INPUT_TYPES(s): return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} RETURN_TYPES = ( - "STRING", "STRING", "STRING", "STRING", # GenP, GenN, CurP, CurN - "STRING", "FLOAT", "INT", # Cam, FLF, Seed - "INT", "INT", "INT", "STRING", "INT", "INT", # VACE Specs - "STRING", "STRING" # Paths + "STRING", "STRING", "STRING", "STRING", + "STRING", "FLOAT", "INT", + "INT", "INT", "INT", "STRING", "INT", "INT", + "STRING", "STRING" ) RETURN_NAMES = ( "general_prompt", "general_negative", "current_prompt", "negative", @@ -136,15 +134,141 @@ class JSONLoaderVACE: str(data.get("reference image path", "")) ) +# ========================================== +# 3. NEW BATCH NODES +# ========================================== + +class JSONLoaderBatchI2V: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "json_path": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}) + } + } + + RETURN_TYPES = ( + "STRING", "STRING", "STRING", "STRING", + "STRING", "FLOAT", "INT", + "STRING", "STRING", "STRING" + ) + RETURN_NAMES = ( + "general_prompt", "general_negative", "current_prompt", "negative", + "camera", "flf", "seed", + "video_file_path", "reference_image_path", "flf_image_path" + ) + FUNCTION = "load_batch_i2v" + CATEGORY = "utils/json" + + def load_batch_i2v(self, json_path, sequence_number): + data = read_json_data(json_path) + + # Batch Logic: Select specific sequence + target_data = data + if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0: + # Adjust 1-based index to 0-based + idx = sequence_number - 1 + # Modulo for looping (safely handles index > length) + idx = idx % len(data["batch_data"]) + target_data = data["batch_data"][idx] + + def to_float(val): + try: return float(val) + except: return 0.0 + def to_int(val): + try: return int(float(val)) + except: return 0 + + return ( + str(target_data.get("general_prompt", "")), + str(target_data.get("general_negative", "")), + str(target_data.get("current_prompt", "")), + str(target_data.get("negative", "")), + str(target_data.get("camera", "")), + to_float(target_data.get("flf", 0.0)), + to_int(target_data.get("seed", 0)), + str(target_data.get("video file path", "")), + str(target_data.get("reference image path", "")), + str(target_data.get("flf image path", "")) + ) + + +class JSONLoaderBatchVACE: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "json_path": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}) + } + } + + RETURN_TYPES = ( + "STRING", "STRING", "STRING", "STRING", + "STRING", "FLOAT", "INT", + "INT", "INT", "INT", "STRING", "INT", "INT", + "STRING", "STRING" + ) + RETURN_NAMES = ( + "general_prompt", "general_negative", "current_prompt", "negative", + "camera", "flf", "seed", + "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", + "video_file_path", "reference_image_path" + ) + FUNCTION = "load_batch_vace" + CATEGORY = "utils/json" + + def load_batch_vace(self, json_path, sequence_number): + data = read_json_data(json_path) + + # Batch Logic + target_data = data + if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0: + idx = sequence_number - 1 + idx = idx % len(data["batch_data"]) + target_data = data["batch_data"][idx] + + def to_float(val): + try: return float(val) + except: return 0.0 + def to_int(val): + try: return int(float(val)) + except: return 0 + + return ( + str(target_data.get("general_prompt", "")), + str(target_data.get("general_negative", "")), + str(target_data.get("current_prompt", "")), + str(target_data.get("negative", "")), + str(target_data.get("camera", "")), + to_float(target_data.get("flf", 0.0)), + to_int(target_data.get("seed", 0)), + + to_int(target_data.get("frame_to_skip", 81)), + to_int(target_data.get("input_a_frames", 0)), + to_int(target_data.get("input_b_frames", 0)), + str(target_data.get("reference path", "")), + to_int(target_data.get("reference switch", 1)), + to_int(target_data.get("vace schedule", 1)), + + str(target_data.get("video file path", "")), + str(target_data.get("reference image path", "")) + ) + # --- Mappings --- NODE_CLASS_MAPPINGS = { "JSONLoaderLoRA": JSONLoaderLoRA, "JSONLoaderStandard": JSONLoaderStandard, - "JSONLoaderVACE": JSONLoaderVACE + "JSONLoaderVACE": JSONLoaderVACE, + "JSONLoaderBatchI2V": JSONLoaderBatchI2V, + "JSONLoaderBatchVACE": JSONLoaderBatchVACE } NODE_DISPLAY_NAME_MAPPINGS = { "JSONLoaderLoRA": "JSON Loader (LoRAs Only)", "JSONLoaderStandard": "JSON Loader (Standard/I2V)", - "JSONLoaderVACE": "JSON Loader (VACE Full)" + "JSONLoaderVACE": "JSON Loader (VACE Full)", + "JSONLoaderBatchI2V": "JSON Batch Loader (I2V)", + "JSONLoaderBatchVACE": "JSON Batch Loader (VACE)" }