feat: add batch_size parameter to training (default 4)

Replaces single-sample steps with batched sampling via random.choices(), which
draws with replacement, so a batch may contain the same clip more than once.
Each clip tensor has its leading size-1 dimension squeezed and is then stacked
to [B, T, C] before the forward pass; t is now sampled per clip with shape [B].
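Default grad_accum lowered from 4 to 1 since real batching gives stable gradients;
with the defaults, clips per optimizer update stay at 4 (previously 1 clip x 4
accumulation steps, now 4 clips x 1).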

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 23:36:12 +02:00
parent 3f67de694c
commit 09b3b94ddd
3 changed files with 28 additions and 22 deletions
+10 -8
View File
@@ -160,7 +160,8 @@ def main():
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--steps", type=int, default=2000)
parser.add_argument("--warmup_steps",type=int, default=100)
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--save_every", type=int, default=500)
parser.add_argument("--resume", default=None,
help="Path to a step checkpoint (.pt) to resume training from.")
@@ -326,21 +327,22 @@ def main():
remaining = args.steps - start_step
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), "
f"lr={args.lr}, grad_accum={args.grad_accum}")
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
total_loss = 0.0
for step in range(start_step + 1, args.steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
batch = random.choices(dataset, k=args.batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = x1_cpu.to(device, dtype)
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
net_generator.normalize(x1)
t = torch.rand(1, device=device, dtype=dtype)
t = torch.rand(args.batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)