feat: add batch_size parameter to training (default 4)
Replaces single-sample steps with batched sampling via random.choices() (sampling with replacement, so a batch may contain the same clip more than once). Tensors are stacked to [B, T, C] before the forward pass; t is now [B].

Default grad_accum lowered from 4 to 1, since real batching already gives stable gradients.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+10
-8
@@ -160,7 +160,8 @@ def main():
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--steps", type=int, default=2000)
|
||||
parser.add_argument("--warmup_steps",type=int, default=100)
|
||||
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
|
||||
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
|
||||
parser.add_argument("--save_every", type=int, default=500)
|
||||
parser.add_argument("--resume", default=None,
|
||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||
@@ -326,21 +327,22 @@ def main():
|
||||
|
||||
remaining = args.steps - start_step
|
||||
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), "
|
||||
f"lr={args.lr}, grad_accum={args.grad_accum}")
|
||||
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
|
||||
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
||||
|
||||
total_loss = 0.0
|
||||
for step in range(start_step + 1, args.steps + 1):
|
||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||
batch = random.choices(dataset, k=args.batch_size)
|
||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||
|
||||
x1 = x1_cpu.to(device, dtype)
|
||||
clip_f = clip_f_cpu.to(device, dtype)
|
||||
sync_f = sync_f_cpu.to(device, dtype)
|
||||
text_clip = text_clip_cpu.to(device, dtype)
|
||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||
|
||||
net_generator.normalize(x1)
|
||||
|
||||
t = torch.rand(1, device=device, dtype=dtype)
|
||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user