feat: add batch_size parameter to training (default 4)

Replaces single-sample steps with batched sampling via random.choices().
Tensors are stacked to [B, T, C] before the forward pass; t is now [B].
Default grad_accum lowered to 1 since real batching gives stable gradients.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 23:36:12 +02:00
parent 3f67de694c
commit 09b3b94ddd
3 changed files with 28 additions and 22 deletions
+16 -13
View File
@@ -260,9 +260,11 @@ class SelvaLoraTrainer:
"default": "attn.qkv",
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
}),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
"grad_accum": ("INT", {"default": 4, "min": 1, "max": 32,
"tooltip": "Gradient accumulation steps."}),
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"resume_path": ("STRING", {
"default": "",
@@ -289,8 +291,8 @@ class SelvaLoraTrainer:
)
def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", warmup_steps=100,
grad_accum=4, save_every=500, resume_path="", seed=42):
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=1, save_every=500, resume_path="", seed=42):
torch.manual_seed(seed)
random.seed(seed)
@@ -392,7 +394,7 @@ class SelvaLoraTrainer:
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
)
@@ -400,7 +402,7 @@ class SelvaLoraTrainer:
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, warmup_steps,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
):
# --- Prepare generator copy with LoRA ---
@@ -469,19 +471,20 @@ class SelvaLoraTrainer:
}
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps})\n", flush=True)
f"(step {start_step + 1}{steps}, batch_size={batch_size})\n", flush=True)
for step in range(start_step + 1, steps + 1):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = x1_cpu.to(device, dtype)
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
generator.normalize(x1)
t = torch.rand(1, device=device, dtype=dtype)
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)