diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index bce709f..e2ec812 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -271,17 +271,33 @@ class SelvaLoraTrainer: "tooltip": "Path to a step checkpoint (.pt) to resume training from.", }), "seed": ("INT", {"default": 42}), - "timestep_mode": (["uniform", "logit_normal"], { + "timestep_mode": (["uniform", "logit_normal", "curriculum"], { "default": "uniform", "tooltip": "How to sample training timesteps. " - "uniform samples all timesteps equally (default, matches original MMAudio training). " - "logit_normal concentrates steps near t=0.5 — reaches lower loss but perceptual improvement is dataset-dependent.", + "uniform: all timesteps equally (matches original MMAudio). " + "logit_normal: concentrates near t=0.5. " + "curriculum: logit_normal for first curriculum_switch% of steps then uniform (recommended for small datasets).", }), "logit_normal_sigma": ("FLOAT", { "default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1, "tooltip": "Spread of the logit-normal distribution. " "1.0 = moderate peak at t=0.5. Higher approaches uniform. " - "Only used when timestep_mode=logit_normal.", + "Used with logit_normal and curriculum modes.", + }), + "curriculum_switch": ("FLOAT", { + "default": 0.6, "min": 0.1, "max": 0.9, "step": 0.05, + "tooltip": "Fraction of steps to run logit_normal before switching to uniform. " + "0.6 = switch at 60% of total steps. Only used with timestep_mode=curriculum.", + }), + "lora_dropout": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 0.3, "step": 0.01, + "tooltip": "Dropout applied to the LoRA path only (not the frozen base weights). " + "0=disabled. 0.05–0.1 helps regularize on small datasets (arXiv:2404.09610).", + }), + "lora_plus_ratio": ("FLOAT", { + "default": 1.0, "min": 1.0, "max": 32.0, "step": 1.0, + "tooltip": "LoRA+ LR ratio: lr_B = lr × ratio. " + "1.0 = standard LoRA. 
16.0 = LoRA+ (arXiv:2402.12354).", }), }, } @@ -305,7 +321,8 @@ class SelvaLoraTrainer: def train(self, model, data_dir, output_dir, steps, rank, lr, alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, grad_accum=1, save_every=500, resume_path="", seed=42, - timestep_mode="uniform", logit_normal_sigma=1.0): + timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, + lora_dropout=0.0, lora_plus_ratio=1.0): torch.manual_seed(seed) random.seed(seed) @@ -442,7 +459,8 @@ class SelvaLoraTrainer: data_dir, output_dir, steps, rank, lr, alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, - timestep_mode, logit_normal_sigma, + timestep_mode, logit_normal_sigma, curriculum_switch, + lora_dropout, lora_plus_ratio, ) def _train_inner( @@ -451,19 +469,21 @@ class SelvaLoraTrainer: data_dir, output_dir, steps, rank, lr, alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, - timestep_mode="uniform", logit_normal_sigma=1.0, + timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, + lora_dropout=0.0, lora_plus_ratio=1.0, ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) n_lora = apply_lora(generator, rank=rank, alpha=alpha_val, - target_suffixes=target_suffixes) + target_suffixes=target_suffixes, dropout=lora_dropout) if n_lora == 0: raise RuntimeError( f"[LoRA Trainer] No layers matched target={target_suffixes}. " "Check the 'target' field." 
) - print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True) + print(f"[LoRA Trainer] Wrapped {n_lora} layers " + f"(rank={rank}, alpha={alpha_val}, dropout={lora_dropout})", flush=True) for name, p in generator.named_parameters(): p.requires_grad_("lora_" in name) @@ -475,8 +495,16 @@ class SelvaLoraTrainer: ) # --- Optimizer + scheduler --- - lora_params = [p for p in generator.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2) + # LoRA+: split A and B into separate param groups with different LRs. + # ratio=1.0 = standard LoRA (same LR for both). ratio=16 = LoRA+. + lora_A_params = [p for n, p in generator.named_parameters() if "lora_A" in n and p.requires_grad] + lora_B_params = [p for n, p in generator.named_parameters() if "lora_B" in n and p.requires_grad] + optimizer = torch.optim.AdamW([ + {"params": lora_A_params, "lr": lr}, + {"params": lora_B_params, "lr": lr * lora_plus_ratio}, + ], weight_decay=1e-2) + if lora_plus_ratio != 1.0: + print(f"[LoRA Trainer] LoRA+: lr_A={lr:.2e} lr_B={lr * lora_plus_ratio:.2e}", flush=True) def lr_lambda(s): return s / max(1, warmup_steps) if s < warmup_steps else 1.0 @@ -518,8 +546,15 @@ class SelvaLoraTrainer: "steps": steps, "timestep_mode": timestep_mode, "logit_normal_sigma": logit_normal_sigma, + "curriculum_switch": curriculum_switch, + "lora_dropout": lora_dropout, + "lora_plus_ratio": lora_plus_ratio, } + # Curriculum mode: the switch step is a fraction of the TOTAL step budget (not of the steps remaining after a resume), so it stays fixed across checkpoint resumes + curriculum_switch_step = int(steps * curriculum_switch) + _curriculum_switched = False + print(f"\n[LoRA Trainer] Training {remaining} steps " f"(step {start_step + 1} → {steps}, batch_size={batch_size}, " f"timestep_mode={timestep_mode})\n", flush=True) @@ -538,11 +573,17 @@ class SelvaLoraTrainer: generator.normalize(x1) - if timestep_mode == "logit_normal": + if timestep_mode == "logit_normal" or ( +
timestep_mode == "curriculum" and step <= curriculum_switch_step + ): u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma t = torch.sigmoid(u) else: t = torch.rand(batch_size, device=device, dtype=dtype) + + if timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched: + print(f"[LoRA Trainer] Curriculum switch: logit_normal → uniform at step {step}", flush=True) + _curriculum_switched = True x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) @@ -552,7 +593,7 @@ class SelvaLoraTrainer: running_loss += loss.item() * grad_accum if step % grad_accum == 0: - torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0) optimizer.step() scheduler.step() optimizer.zero_grad() diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py index f726cdb..86dd729 100644 --- a/selva_core/model/lora.py +++ b/selva_core/model/lora.py @@ -25,13 +25,14 @@ import torch.nn as nn class LoRALinear(nn.Module): """nn.Linear with a frozen base weight and trainable low-rank A/B matrices. - Output: base(x) + (x @ A.T @ B.T) * (alpha / rank) + Output: base(x) + (dropout(x) @ A.T @ B.T) * (alpha / rank) A is initialised with Kaiming uniform; B is initialised to zero so the adapter contribution starts at zero and does not disturb pretrained behaviour. + Dropout is applied only to the LoRA path, not the base linear. 
""" - def __init__(self, linear: nn.Linear, rank: int, alpha: float): + def __init__(self, linear: nn.Linear, rank: int, alpha: float, dropout: float = 0.0): super().__init__() in_f = linear.in_features out_f = linear.out_features @@ -46,16 +47,18 @@ class LoRALinear(nn.Module): self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device)) self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device)) self.scale = alpha / rank + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale + return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale def extra_repr(self) -> str: rank = self.lora_A.shape[0] + p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0 return (f"in={self.linear.in_features}, out={self.linear.out_features}, " - f"rank={rank}, scale={self.scale:.4f}") + f"rank={rank}, scale={self.scale:.4f}, dropout={p}") def apply_lora( @@ -63,6 +66,7 @@ def apply_lora( rank: int = 16, alpha: float = None, target_suffixes: tuple = ("attn.qkv",), + dropout: float = 0.0, ) -> int: """Replace matching nn.Linear layers with LoRALinear in-place. @@ -74,6 +78,8 @@ def apply_lora( ("attn.qkv",) which targets all SelfAttention QKV projections in the MM-DiT generator. Add "linear1" to also wrap post-attention output projections. + dropout: Dropout probability on the LoRA path (not the base linear). + 0.05–0.1 helps regularize on small datasets. Returns: Number of linear layers wrapped. 
@@ -92,7 +98,7 @@ def apply_lora( parent = model for part in parts[:-1]: parent = getattr(parent, part) - setattr(parent, parts[-1], LoRALinear(module, rank, alpha)) + setattr(parent, parts[-1], LoRALinear(module, rank, alpha, dropout=dropout)) count += 1 return count diff --git a/train_lora.py b/train_lora.py index 8203d67..9ba6c63 100644 --- a/train_lora.py +++ b/train_lora.py @@ -167,10 +167,16 @@ def main(): help="Path to a step checkpoint (.pt) to resume training from.") parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal"], - help="Timestep sampling distribution. uniform matches original MMAudio training. logit_normal reaches lower loss but perceptual improvement is dataset-dependent.") + parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"], + help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.") parser.add_argument("--logit_normal_sigma", type=float, default=1.0, - help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).") + help="Spread of logit-normal distribution.") + parser.add_argument("--curriculum_switch", type=float, default=0.6, + help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).") + parser.add_argument("--lora_dropout", type=float, default=0.0, + help="Dropout on the LoRA path only. 0.05–0.1 helps on small datasets.") + parser.add_argument("--lora_plus_ratio", type=float, default=1.0, + help="LoRA+ LR ratio: lr_B = lr * ratio. 
1.0=standard, 16.0=LoRA+.") args = parser.parse_args() torch.manual_seed(args.seed) @@ -234,8 +240,9 @@ def main(): rank=args.rank, alpha=args.alpha, target_suffixes=tuple(args.target), + dropout=args.lora_dropout, ) - print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target})") + print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})") if n_lora == 0: print("[LoRA] ERROR: no layers were wrapped — check --target names.") sys.exit(1) @@ -315,8 +322,16 @@ def main(): print(f"[LoRA] {len(dataset)} clip(s) ready.") # --- Optimizer + LR scheduler --- - lora_params = [p for p in net_generator.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=1e-2) + # LoRA+: separate param groups for A and B with different LRs. + # ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354). + lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad] + lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad] + optimizer = torch.optim.AdamW([ + {"params": lora_A_params, "lr": args.lr}, + {"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio}, + ], weight_decay=1e-2) + if args.lora_plus_ratio != 1.0: + print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}") def lr_lambda(step): if step < args.warmup_steps: @@ -351,6 +366,9 @@ def main(): f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}") print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n") + curriculum_switch_step = int(args.steps * args.curriculum_switch)  # fraction of TOTAL steps, so the switch point does not move when resuming + _curriculum_switched = False + total_loss = 0.0 for step in range(start_step + 1, args.steps + 1): batch = random.choices(dataset, k=args.batch_size) @@ -363,11 +381,18 @@ def main(): net_generator.normalize(x1) - if args.timestep_mode ==
"logit_normal": + if args.timestep_mode == "logit_normal" or ( + args.timestep_mode == "curriculum" and step <= curriculum_switch_step + ): u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma t = torch.sigmoid(u) else: t = torch.rand(args.batch_size, device=device, dtype=dtype) + + if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched: + print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}") + _curriculum_switched = True + x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) @@ -378,7 +403,7 @@ def main(): total_loss += loss.item() * args.grad_accum if step % args.grad_accum == 0: - torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0) + torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0) optimizer.step() scheduler.step() optimizer.zero_grad() @@ -404,6 +429,9 @@ def main(): "steps": args.steps, "timestep_mode": args.timestep_mode, "logit_normal_sigma": args.logit_normal_sigma, + "curriculum_switch": args.curriculum_switch, + "lora_dropout": args.lora_dropout, + "lora_plus_ratio": args.lora_plus_ratio, }, }, ckpt_path) print(f"[LoRA] Saved {ckpt_path}") @@ -424,6 +452,9 @@ def main(): "steps": args.steps, "timestep_mode": args.timestep_mode, "logit_normal_sigma": args.logit_normal_sigma, + "curriculum_switch": args.curriculum_switch, + "lora_dropout": args.lora_dropout, + "lora_plus_ratio": args.lora_plus_ratio, } torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final) (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))