feat: add batch_size parameter to training (default 4)
Replaces single-sample steps with batched sampling via random.choices(). Tensors are stacked to [B, T, C] before the forward pass; t is now [B]. Default grad_accum lowered to 1 since real batching gives stable gradients. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+16
-13
@@ -260,9 +260,11 @@ class SelvaLoraTrainer:
|
||||
"default": "attn.qkv",
|
||||
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
|
||||
}),
|
||||
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
|
||||
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
|
||||
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
|
||||
"grad_accum": ("INT", {"default": 4, "min": 1, "max": 32,
|
||||
"tooltip": "Gradient accumulation steps."}),
|
||||
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
|
||||
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
|
||||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||
"resume_path": ("STRING", {
|
||||
"default": "",
|
||||
@@ -289,8 +291,8 @@ class SelvaLoraTrainer:
|
||||
)
|
||||
|
||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||
alpha=0.0, target="attn.qkv", warmup_steps=100,
|
||||
grad_accum=4, save_every=500, resume_path="", seed=42):
|
||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||
grad_accum=1, save_every=500, resume_path="", seed=42):
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
@@ -392,7 +394,7 @@ class SelvaLoraTrainer:
|
||||
model, dataset, feature_utils_orig, seq_cfg,
|
||||
device, dtype, variant, mode,
|
||||
data_dir, output_dir, steps, rank, lr,
|
||||
alpha_val, target_suffixes, warmup_steps,
|
||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
)
|
||||
|
||||
@@ -400,7 +402,7 @@ class SelvaLoraTrainer:
|
||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||
device, dtype, variant, mode,
|
||||
data_dir, output_dir, steps, rank, lr,
|
||||
alpha_val, target_suffixes, warmup_steps,
|
||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
):
|
||||
# --- Prepare generator copy with LoRA ---
|
||||
@@ -469,19 +471,20 @@ class SelvaLoraTrainer:
|
||||
}
|
||||
|
||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||
f"(step {start_step + 1} → {steps})\n", flush=True)
|
||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
||||
|
||||
for step in range(start_step + 1, steps + 1):
|
||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||
batch = random.choices(dataset, k=batch_size)
|
||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||
|
||||
x1 = x1_cpu.to(device, dtype)
|
||||
clip_f = clip_f_cpu.to(device, dtype)
|
||||
sync_f = sync_f_cpu.to(device, dtype)
|
||||
text_clip = text_clip_cpu.to(device, dtype)
|
||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||
|
||||
generator.normalize(x1)
|
||||
|
||||
t = torch.rand(1, device=device, dtype=dtype)
|
||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user