feat: add batch_size parameter to training (default 4)

Replaces single-sample steps with batched sampling via random.choices().
Tensors are stacked to [B, T, C] before the forward pass; t is now [B].
Default grad_accum lowered to 1 since real batching gives stable gradients.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 23:36:12 +02:00
parent 3f67de694c
commit 09b3b94ddd
3 changed files with 28 additions and 22 deletions
+2 -1
View File
@@ -107,7 +107,8 @@ The script will:
| `--lr` | `1e-4` | Learning rate | | `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps | | `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps | | `--warmup_steps` | `100` | Linear LR warmup steps |
| `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) | | `--batch_size` | `4` | Clips per training step |
| `--grad_accum` | `1` | Gradient accumulation steps |
| `--save_every` | `500` | Save a checkpoint every N steps | | `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) | | `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` | | `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
+16 -13
View File
@@ -260,9 +260,11 @@ class SelvaLoraTrainer:
"default": "attn.qkv", "default": "attn.qkv",
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.", "tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
}), }),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}), "warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
"grad_accum": ("INT", {"default": 4, "min": 1, "max": 32, "grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
"tooltip": "Gradient accumulation steps."}), "tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"resume_path": ("STRING", { "resume_path": ("STRING", {
"default": "", "default": "",
@@ -289,8 +291,8 @@ class SelvaLoraTrainer:
) )
def train(self, model, data_dir, output_dir, steps, rank, lr, def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", warmup_steps=100, alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=4, save_every=500, resume_path="", seed=42): grad_accum=1, save_every=500, resume_path="", seed=42):
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
@@ -392,7 +394,7 @@ class SelvaLoraTrainer:
model, dataset, feature_utils_orig, seq_cfg, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode, device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr, data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps, alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed, grad_accum, save_every, resume_path, seed,
) )
@@ -400,7 +402,7 @@ class SelvaLoraTrainer:
self, model, dataset, feature_utils_orig, seq_cfg, self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode, device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr, data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps, alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed, grad_accum, save_every, resume_path, seed,
): ):
# --- Prepare generator copy with LoRA --- # --- Prepare generator copy with LoRA ---
@@ -469,19 +471,20 @@ class SelvaLoraTrainer:
} }
print(f"\n[LoRA Trainer] Training {remaining} steps " print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps})\n", flush=True) f"(step {start_step + 1}{steps}, batch_size={batch_size})\n", flush=True)
for step in range(start_step + 1, steps + 1): for step in range(start_step + 1, steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = x1_cpu.to(device, dtype) x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = clip_f_cpu.to(device, dtype) clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype) sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
generator.normalize(x1) generator.normalize(x1)
t = torch.rand(1, device=device, dtype=dtype) t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)
+10 -8
View File
@@ -160,7 +160,8 @@ def main():
parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--steps", type=int, default=2000) parser.add_argument("--steps", type=int, default=2000)
parser.add_argument("--warmup_steps",type=int, default=100) parser.add_argument("--warmup_steps",type=int, default=100)
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps") parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--save_every", type=int, default=500) parser.add_argument("--save_every", type=int, default=500)
parser.add_argument("--resume", default=None, parser.add_argument("--resume", default=None,
help="Path to a step checkpoint (.pt) to resume training from.") help="Path to a step checkpoint (.pt) to resume training from.")
@@ -326,21 +327,22 @@ def main():
remaining = args.steps - start_step remaining = args.steps - start_step
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), " print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), "
f"lr={args.lr}, grad_accum={args.grad_accum}") f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
total_loss = 0.0 total_loss = 0.0
for step in range(start_step + 1, args.steps + 1): for step in range(start_step + 1, args.steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset) batch = random.choices(dataset, k=args.batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = x1_cpu.to(device, dtype) x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = clip_f_cpu.to(device, dtype) clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype) sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
net_generator.normalize(x1) net_generator.normalize(x1)
t = torch.rand(1, device=device, dtype=dtype) t = torch.rand(args.batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1) x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t) xt = fm.get_conditional_flow(x0, x1, t)