feat: add LoRA dropout, LoRA+ asymmetric LR, and curriculum timestep sampling
- LoRA dropout: applied to the LoRA path only (not the frozen base weights); 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610)
- LoRA+: separate optimizer param groups for lora_A and lora_B with a configurable LR ratio; ratio=16 enables LoRA+ (arXiv:2402.12354)
- Curriculum mode: logit_normal for the first N% of steps, then uniform; directly addresses early convergence and fine-detail degradation at boundaries (arXiv:2603.12517)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+54
-13
@@ -271,17 +271,33 @@ class SelvaLoraTrainer:
|
||||
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||||
}),
|
||||
"seed": ("INT", {"default": 42}),
|
||||
"timestep_mode": (["uniform", "logit_normal"], {
|
||||
"timestep_mode": (["uniform", "logit_normal", "curriculum"], {
|
||||
"default": "uniform",
|
||||
"tooltip": "How to sample training timesteps. "
|
||||
"uniform samples all timesteps equally (default, matches original MMAudio training). "
|
||||
"logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.",
|
||||
"uniform: all timesteps equally (matches original MMAudio). "
|
||||
"logit_normal: concentrates near t=0.5. "
|
||||
"curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).",
|
||||
}),
|
||||
"logit_normal_sigma": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
||||
"tooltip": "Spread of the logit-normal distribution. "
|
||||
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
|
||||
"Only used when timestep_mode=logit_normal.",
|
||||
"Used with logit_normal and curriculum modes.",
|
||||
}),
|
||||
"curriculum_switch": ("FLOAT", {
|
||||
"default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05,
|
||||
"tooltip": "Fraction of steps to run logit_normal before switching to uniform. "
|
||||
"0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.",
|
||||
}),
|
||||
"lora_dropout": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01,
|
||||
"tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). "
|
||||
"0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).",
|
||||
}),
|
||||
"lora_plus_ratio": ("FLOAT", {
|
||||
"default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0,
|
||||
"tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. "
|
||||
"1.0 = standard LoRA. 16.0 = LoRA+ (arXiv:2402.12354).",
|
||||
}),
|
||||
},
|
||||
}
|
||||
@@ -305,7 +321,8 @@ class SelvaLoraTrainer:
|
||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0):
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0):
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
@@ -442,7 +459,8 @@ class SelvaLoraTrainer:
|
||||
data_dir, output_dir, steps, rank, lr,
|
||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
timestep_mode, logit_normal_sigma,
|
||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||
lora_dropout, lora_plus_ratio,
|
||||
)
|
||||
|
||||
def _train_inner(
|
||||
@@ -451,19 +469,21 @@ class SelvaLoraTrainer:
|
||||
data_dir, output_dir, steps, rank, lr,
|
||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||
grad_accum, save_every, resume_path, seed,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||
):
|
||||
# --- Prepare generator copy with LoRA ---
|
||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||
|
||||
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||||
target_suffixes=target_suffixes)
|
||||
target_suffixes=target_suffixes, dropout=lora_dropout)
|
||||
if n_lora == 0:
|
||||
raise RuntimeError(
|
||||
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||||
"Check the 'target' field."
|
||||
)
|
||||
print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)
|
||||
print(f"[LoRA Trainer] Wrapped {n_lora} layers "
|
||||
f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True)
|
||||
|
||||
for name, p in generator.named_parameters():
|
||||
p.requires_grad_("lora_" in name)
|
||||
@@ -475,8 +495,16 @@ class SelvaLoraTrainer:
|
||||
)
|
||||
|
||||
# --- Optimizer + scheduler ---
|
||||
lora_params = [p for p in generator.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)
|
||||
# LoRA+: split A and B into separate param groups with different LRs.
|
||||
# ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+.
|
||||
lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad]
|
||||
lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": lora_A_params, "lr": lr},
|
||||
{"params": lora_B_params, "lr": lr * lora_plus_ratio},
|
||||
], weight_decay=1e-2)
|
||||
if lora_plus_ratio != 1.0:
|
||||
print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True)
|
||||
|
||||
def lr_lambda(s):
|
||||
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
||||
@@ -518,8 +546,15 @@ class SelvaLoraTrainer:
|
||||
"steps": steps,
|
||||
"timestep_mode": timestep_mode,
|
||||
"logit_normal_sigma": logit_normal_sigma,
|
||||
"curriculum_switch": curriculum_switch,
|
||||
"lora_dropout": lora_dropout,
|
||||
"lora_plus_ratio": lora_plus_ratio,
|
||||
}
|
||||
|
||||
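# For curriculum mode: compute the step at which we switch from logit_normal to uniform.
# NOTE(review): the switch point is a fraction of the steps remaining after start_step,
# not of the total `steps`; when start_step > 0 (presumably a resumed run) the switch
# lands later than curriculum_switch * steps, which differs from the tooltip's
# "60% of total steps" — confirm this is intended.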
|
||||
curriculum_switch_step = start_step + int((steps - start_step) * curriculum_switch)
|
||||
_curriculum_switched = False
|
||||
|
||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||
@@ -538,11 +573,17 @@ class SelvaLoraTrainer:
|
||||
|
||||
generator.normalize(x1)
|
||||
|
||||
if timestep_mode == "logit_normal":
|
||||
if timestep_mode == "logit_normal" or (
|
||||
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||
):
|
||||
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
|
||||
t = torch.sigmoid(u)
|
||||
else:
|
||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
|
||||
if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
|
||||
print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True)
|
||||
_curriculum_switched = True
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
@@ -552,7 +593,7 @@ class SelvaLoraTrainer:
|
||||
running_loss += loss.item() * grad_accum
|
||||
|
||||
if step % grad_accum == 0:
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
||||
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
Reference in New Issue
Block a user