feat: add batch_size parameter to training (default 4)
Replaces single-sample steps with batched sampling via random.choices(), which draws batch_size clips with replacement. Per-clip tensors are stacked to [B, T, C] before the forward pass; t is now sampled once per clip and has shape [B]. The default grad_accum is lowered from 4 to 1, since real batching already gives more stable gradients. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+2
-1
@@ -107,7 +107,8 @@ The script will:
|
|||||||
| `--lr` | `1e-4` | Learning rate |
|
| `--lr` | `1e-4` | Learning rate |
|
||||||
| `--steps` | `2000` | Total training steps |
|
| `--steps` | `2000` | Total training steps |
|
||||||
| `--warmup_steps` | `100` | Linear LR warmup steps |
|
| `--warmup_steps` | `100` | Linear LR warmup steps |
|
||||||
| `--grad_accum` | `4` | Gradient accumulation steps (effective batch = grad_accum × 1) |
|
| `--batch_size` | `4` | Clips per training step |
|
||||||
|
| `--grad_accum` | `1` | Gradient accumulation steps (effective batch = batch_size × grad_accum) |
|
||||||
| `--save_every` | `500` | Save a checkpoint every N steps |
|
| `--save_every` | `500` | Save a checkpoint every N steps |
|
||||||
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) |
|
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step01000.pt`) |
|
||||||
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
||||||
|
|||||||
+16
-13
@@ -260,9 +260,11 @@ class SelvaLoraTrainer:
|
|||||||
"default": "attn.qkv",
|
"default": "attn.qkv",
|
||||||
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
|
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
|
||||||
}),
|
}),
|
||||||
|
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
|
||||||
|
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
|
||||||
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
|
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
|
||||||
"grad_accum": ("INT", {"default": 4, "min": 1, "max": 32,
|
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
|
||||||
"tooltip": "Gradient accumulation steps."}),
|
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
|
||||||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||||
"resume_path": ("STRING", {
|
"resume_path": ("STRING", {
|
||||||
"default": "",
|
"default": "",
|
||||||
@@ -289,8 +291,8 @@ class SelvaLoraTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||||
alpha=0.0, target="attn.qkv", warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=4, save_every=500, resume_path="", seed=42):
|
grad_accum=1, save_every=500, resume_path="", seed=42):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -392,7 +394,7 @@ class SelvaLoraTrainer:
|
|||||||
model, dataset, feature_utils_orig, seq_cfg,
|
model, dataset, feature_utils_orig, seq_cfg,
|
||||||
device, dtype, variant, mode,
|
device, dtype, variant, mode,
|
||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -400,7 +402,7 @@ class SelvaLoraTrainer:
|
|||||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
device, dtype, variant, mode,
|
device, dtype, variant, mode,
|
||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
@@ -469,19 +471,20 @@ class SelvaLoraTrainer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps})\n", flush=True)
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
||||||
|
|
||||||
for step in range(start_step + 1, steps + 1):
|
for step in range(start_step + 1, steps + 1):
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
batch = random.choices(dataset, k=batch_size)
|
||||||
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
|
|
||||||
x1 = x1_cpu.to(device, dtype)
|
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||||
text_clip = text_clip_cpu.to(device, dtype)
|
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||||
|
|
||||||
generator.normalize(x1)
|
generator.normalize(x1)
|
||||||
|
|
||||||
t = torch.rand(1, device=device, dtype=dtype)
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
|
|||||||
+10
-8
@@ -160,7 +160,8 @@ def main():
|
|||||||
parser.add_argument("--lr", type=float, default=1e-4)
|
parser.add_argument("--lr", type=float, default=1e-4)
|
||||||
parser.add_argument("--steps", type=int, default=2000)
|
parser.add_argument("--steps", type=int, default=2000)
|
||||||
parser.add_argument("--warmup_steps",type=int, default=100)
|
parser.add_argument("--warmup_steps",type=int, default=100)
|
||||||
parser.add_argument("--grad_accum", type=int, default=4, help="Gradient accumulation steps")
|
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
|
||||||
|
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
|
||||||
parser.add_argument("--save_every", type=int, default=500)
|
parser.add_argument("--save_every", type=int, default=500)
|
||||||
parser.add_argument("--resume", default=None,
|
parser.add_argument("--resume", default=None,
|
||||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||||
@@ -326,21 +327,22 @@ def main():
|
|||||||
|
|
||||||
remaining = args.steps - start_step
|
remaining = args.steps - start_step
|
||||||
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), "
|
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), "
|
||||||
f"lr={args.lr}, grad_accum={args.grad_accum}")
|
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
|
||||||
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
for step in range(start_step + 1, args.steps + 1):
|
for step in range(start_step + 1, args.steps + 1):
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
batch = random.choices(dataset, k=args.batch_size)
|
||||||
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
|
|
||||||
x1 = x1_cpu.to(device, dtype)
|
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||||
text_clip = text_clip_cpu.to(device, dtype)
|
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||||
|
|
||||||
net_generator.normalize(x1)
|
net_generator.normalize(x1)
|
||||||
|
|
||||||
t = torch.rand(1, device=device, dtype=dtype)
|
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user