From a5014e49ebe3b776d8066abaa05cfd4925314608 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Mon, 6 Apr 2026 00:35:42 +0200 Subject: [PATCH] feat: add logit-normal timestep sampling to reduce white noise artifacts Uniform timestep sampling undertrained t>0.8 (the final denoising steps), leaving residual noise that CFG amplifies at inference. Logit-normal sampling concentrates training near t=0.5 while still covering the full range, improving high-t coverage and reducing noise floor in generated audio. Default changed from uniform to logit_normal (sigma=1.0). Previous behavior available with timestep_mode=uniform. Co-Authored-By: Claude Sonnet 4.6 --- LORA_TRAINING.md | 18 ++++++++++++++++++ nodes/selva_lora_trainer.py | 38 +++++++++++++++++++++++++++++-------- train_lora.py | 38 ++++++++++++++++++++++++------------- 3 files changed, 73 insertions(+), 21 deletions(-) diff --git a/LORA_TRAINING.md b/LORA_TRAINING.md index cac7fc8..e557d57 100644 --- a/LORA_TRAINING.md +++ b/LORA_TRAINING.md @@ -127,6 +127,8 @@ The script will: | `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) | | `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` | | `--seed` | `42` | Random seed | +| `--timestep_mode` | `logit_normal` | Timestep sampling: `logit_normal` (recommended) or `uniform` | +| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` | --- @@ -241,6 +243,22 @@ Add `linear1` to also adapt post-attention projections for large-scale domain sh Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets. +### Timestep sampling mode + +The default `logit_normal` mode samples training timesteps from a bell-shaped distribution centered at t=0.5 (via `sigmoid(N(0, σ))`). This gives more training budget to the middle of the noise schedule — the semantically rich region where the model learns what the sound should sound like — while still covering the full range. + +The alternative `uniform` mode samples all timesteps equally. This is mathematically valid but undertrains the high-t region (t > 0.8), which is where final audio quality is determined. Undertraining there leaves residual noise that is then amplified by CFG at inference. + +| Mode | When to use | +|---|---| +| `logit_normal` (default, σ=1.0) | Recommended for all cases — reduces white noise artifacts | +| `uniform` | Baseline / comparison; equivalent to original MMAudio training | + +The `logit_normal_sigma` parameter controls the width of the distribution: +- σ=1.0: moderate peak at t=0.5, balanced coverage (default) +- σ=0.5: sharper peak, less coverage of extremes +- σ=2.0: broader, approaches uniform + ### Adapter strength at inference | Strength | Effect | diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index dddbf0f..49ad7a7 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -271,6 +271,18 @@ class SelvaLoraTrainer: "tooltip": "Path to a step checkpoint (.pt) to resume training from.", }), "seed": ("INT", {"default": 42}), + "timestep_mode": (["logit_normal", "uniform"], { + "default": "logit_normal", + "tooltip": "How to sample training timesteps. " + "logit_normal concentrates steps near t=0.5 (recommended — reduces white noise artifacts). " + "uniform samples all timesteps equally (original behavior).", + }), + "logit_normal_sigma": ("FLOAT", { + "default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1, + "tooltip": "Spread of the logit-normal distribution. " + "1.0 = moderate peak at t=0.5. Higher approaches uniform. " + "Only used when timestep_mode=logit_normal.", + }), }, } @@ -292,7 +304,8 @@ class SelvaLoraTrainer: def train(self, model, data_dir, output_dir, steps, rank, lr, alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100, - grad_accum=1, save_every=500, resume_path="", seed=42): + grad_accum=1, save_every=500, resume_path="", seed=42, + timestep_mode="logit_normal", logit_normal_sigma=1.0): torch.manual_seed(seed) random.seed(seed) @@ -396,6 +409,7 @@ class SelvaLoraTrainer: data_dir, output_dir, steps, rank, lr, alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, + timestep_mode, logit_normal_sigma, ) def _train_inner( @@ -404,6 +418,7 @@ class SelvaLoraTrainer: data_dir, output_dir, steps, rank, lr, alpha_val, target_suffixes, batch_size, warmup_steps, grad_accum, save_every, resume_path, seed, + timestep_mode="logit_normal", logit_normal_sigma=1.0, ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) @@ -463,15 +478,18 @@ class SelvaLoraTrainer: running_loss = 0.0 meta = { - "variant": variant, - "rank": rank, - "alpha": alpha_val, - "target": list(target_suffixes), - "steps": steps, + "variant": variant, + "rank": rank, + "alpha": alpha_val, + "target": list(target_suffixes), + "steps": steps, + "timestep_mode": timestep_mode, + "logit_normal_sigma": logit_normal_sigma, } print(f"\n[LoRA Trainer] Training {remaining} steps " - f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True) + f"(step {start_step + 1} → {steps}, batch_size={batch_size}, " + f"timestep_mode={timestep_mode})\n", flush=True) for step in range(start_step + 1, steps + 1): batch = random.choices(dataset, k=batch_size) @@ -484,7 +502,11 @@ class SelvaLoraTrainer: generator.normalize(x1) - t = torch.rand(batch_size, device=device, dtype=dtype) + if timestep_mode == "logit_normal": + u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma + t = torch.sigmoid(u) + else: + t = torch.rand(batch_size, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) diff --git a/train_lora.py b/train_lora.py index 4a5595e..49e069a 100644 --- a/train_lora.py +++ b/train_lora.py @@ -165,8 +165,12 @@ def main(): parser.add_argument("--save_every", type=int, default=500) parser.add_argument("--resume", default=None, help="Path to a step checkpoint (.pt) to resume training from.") - parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) - parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--timestep_mode", default="logit_normal", choices=["logit_normal", "uniform"], + help="Timestep sampling distribution. logit_normal reduces white noise artifacts.") + parser.add_argument("--logit_normal_sigma", type=float, default=1.0, + help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).") args = parser.parse_args() torch.manual_seed(args.seed) @@ -342,7 +346,11 @@ def main(): net_generator.normalize(x1) - t = torch.rand(args.batch_size, device=device, dtype=dtype) + if args.timestep_mode == "logit_normal": + u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma + t = torch.sigmoid(u) + else: + t = torch.rand(args.batch_size, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) @@ -372,11 +380,13 @@ def main(): "scheduler": scheduler.state_dict(), "step": step, "meta": { - "variant": args.variant, - "rank": args.rank, - "alpha": args.alpha if args.alpha is not None else float(args.rank), - "target": args.target, - "steps": args.steps, + "variant": args.variant, + "rank": args.rank, + "alpha": args.alpha if args.alpha is not None else float(args.rank), + "target": args.target, + "steps": args.steps, + "timestep_mode": args.timestep_mode, + "logit_normal_sigma": args.logit_normal_sigma, }, }, ckpt_path) print(f"[LoRA] Saved {ckpt_path}") @@ -390,11 +400,13 @@ def main(): i += 1 final = output_dir / f"adapter_final_{i:03d}.pt" meta = { - "variant": args.variant, - "rank": args.rank, - "alpha": args.alpha if args.alpha is not None else float(args.rank), - "target": args.target, - "steps": args.steps, + "variant": args.variant, + "rank": args.rank, + "alpha": args.alpha if args.alpha is not None else float(args.rank), + "target": args.target, + "steps": args.steps, + "timestep_mode": args.timestep_mode, + "logit_normal_sigma": args.logit_normal_sigma, } torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final) (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))