diff --git a/LORA_TRAINING.md b/LORA_TRAINING.md index 8bd5696..3bc1d5c 100644 --- a/LORA_TRAINING.md +++ b/LORA_TRAINING.md @@ -107,7 +107,8 @@ The script will: | `--lr` | `1e-4` | Learning rate | | `--steps` | `2000` | Total training steps | | `--warmup_steps` | `100` | Linear LR warmup steps | -| `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) | +| `--batch_size` | `4` | Clips per training step | +| `--grad_accum` | `1` | Gradient accumulation steps | | `--save_every` | `500` | Save a checkpoint every N steps | | `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) | | `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` | diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 47f8cb1..ece535f 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -260,9 +260,11 @@ class SelvaLoraTrainer: "default": "attn.qkv", "tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.", }), + "batch_size": ("INT", {"default": 4, "min": 1, "max": 32, + "tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}), "warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}), - "grad_accum": ("INT", {"default": 4, "min": 1, "max": 32, - "tooltip": "Gradient accumulation steps."}), + "grad_accum": ("INT", {"default": 1, "min": 1, "max": 32, + "tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "resume_path": ("STRING", { "default": "", @@ -289,8 +291,8 @@ class SelvaLoraTrainer: ) def train(self, model, data_dir, output_dir, steps, rank, lr, - alpha=0.0, target="attn.qkv", warmup_steps=100, - grad_accum=4, save_every=500, resume_path="", seed=42): + alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, + grad_accum=1, save_every=500, resume_path="", seed=42): torch.manual_seed(seed) random.seed(seed) @@ -392,7 +394,7 @@ class SelvaLoraTrainer: model, dataset, feature_utils_orig, seq_cfg, device, dtype, variant, mode, data_dir, output_dir, steps, rank, lr, - alpha_val, target_suffixes, warmup_steps, + alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, ) @@ -400,7 +402,7 @@ class SelvaLoraTrainer: self, model, dataset, feature_utils_orig, seq_cfg, device, dtype, variant, mode, data_dir, output_dir, steps, rank, lr, - alpha_val, target_suffixes, warmup_steps, + alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, ): # --- Prepare generator copy with LoRA --- @@ -469,19 +471,20 @@ class SelvaLoraTrainer: } print(f"\n[LoRA Trainer] Training {remaining} steps " - f"(step {start_step + 1} → {steps})\n", flush=True) + f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True) for step in range(start_step + 1, steps + 1): - x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) + batch = random.choices(dataset, k=batch_size) + x1_list, clip_list, sync_list, text_list = zip(*batch) - x1 = x1_cpu.to(device, dtype) - clip_f = clip_f_cpu.to(device, dtype) - sync_f = sync_f_cpu.to(device, dtype) - text_clip = text_clip_cpu.to(device, dtype) + x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) + clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) + sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) + text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) generator.normalize(x1) - t = torch.rand(1, device=device, dtype=dtype) + t = torch.rand(batch_size, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) diff --git a/train_lora.py b/train_lora.py index 3ecbe8d..fa9e8d4 100644 --- a/train_lora.py +++ b/train_lora.py @@ -160,7 +160,8 @@ def main(): parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--steps", type=int, default=2000) parser.add_argument("--warmup_steps",type=int, default=100) - parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps") + parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step") + parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--save_every", type=int, default=500) parser.add_argument("--resume", default=None, help="Path to a step checkpoint (.pt) to resume training from.") @@ -326,21 +327,22 @@ def main(): remaining = args.steps - start_step print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), " - f"lr={args.lr}, grad_accum={args.grad_accum}") + f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}") print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") total_loss = 0.0 for step in range(start_step + 1, args.steps + 1): - x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) + batch = random.choices(dataset, k=args.batch_size) + x1_list, clip_list, sync_list, text_list = zip(*batch) - x1 = x1_cpu.to(device, dtype) - clip_f = clip_f_cpu.to(device, dtype) - sync_f = sync_f_cpu.to(device, dtype) - text_clip = text_clip_cpu.to(device, dtype) + x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype) + clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype) + sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) + text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) net_generator.normalize(x1) - t = torch.rand(1, device=device, dtype=dtype) + t = torch.rand(args.batch_size, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t)