feat: add logit-normal timestep sampling to reduce white noise artifacts
Uniform timestep sampling undertrained t>0.8 (the final denoising steps), leaving residual noise that CFG amplifies at inference. Logit-normal sampling concentrates training near t=0.5 while still covering the full range, improving high-t coverage and reducing noise floor in generated audio. Default changed from uniform to logit_normal (sigma=1.0). Previous behavior available with timestep_mode=uniform. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -127,6 +127,8 @@ The script will:
|
|||||||
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
|
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
|
||||||
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
||||||
| `--seed` | `42` | Random seed |
|
| `--seed` | `42` | Random seed |
|
||||||
|
| `--timestep_mode` | `logit_normal` | Timestep sampling: `logit_normal` (recommended) or `uniform` |
|
||||||
|
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -241,6 +243,22 @@ Add `linear1` to also adapt post-attention projections for large-scale domain sh
|
|||||||
|
|
||||||
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
|
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
|
||||||
|
|
||||||
|
### Timestep sampling mode
|
||||||
|
|
||||||
|
The default `logit_normal` mode samples training timesteps from a bell-shaped distribution centered at t=0.5 (via `sigmoid(N(0, σ))`). This gives more training budget to the middle of the noise schedule — the semantically rich region where the model learns what the sound should sound like — while still covering the full range.
|
||||||
|
|
||||||
|
The alternative `uniform` mode samples all timesteps equally. This is mathematically valid but undertrains the high-t region (t > 0.8), which is where final audio quality is determined. Undertraining there leaves residual noise that is then amplified by CFG at inference.
|
||||||
|
|
||||||
|
| Mode | When to use |
|
||||||
|
|---|---|
|
||||||
|
| `logit_normal` (default, σ=1.0) | Recommended for all cases — reduces white noise artifacts |
|
||||||
|
| `uniform` | Baseline / comparison; equivalent to original MMAudio training |
|
||||||
|
|
||||||
|
The `logit_normal_sigma` parameter controls the width of the distribution:
|
||||||
|
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
|
||||||
|
- σ=0.5: sharper peak, less coverage of extremes
|
||||||
|
- σ=2.0: broader, approaches uniform
|
||||||
|
|
||||||
### Adapter strength at inference
|
### Adapter strength at inference
|
||||||
|
|
||||||
| Strength | Effect |
|
| Strength | Effect |
|
||||||
|
|||||||
@@ -271,6 +271,18 @@ class SelvaLoraTrainer:
|
|||||||
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||||||
}),
|
}),
|
||||||
"seed": ("INT", {"default": 42}),
|
"seed": ("INT", {"default": 42}),
|
||||||
|
"timestep_mode": (["logit_normal", "uniform"], {
|
||||||
|
"default": "logit_normal",
|
||||||
|
"tooltip": "How to sample training timesteps. "
|
||||||
|
"logit_normal concentrates steps near t=0.5 (recommended — reduces white noise artifacts). "
|
||||||
|
"uniform samples all timesteps equally (original behavior).",
|
||||||
|
}),
|
||||||
|
"logit_normal_sigma": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
|
||||||
|
"tooltip": "Spread of the logit-normal distribution. "
|
||||||
|
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
|
||||||
|
"Only used when timestep_mode=logit_normal.",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,7 +304,8 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||||
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
grad_accum=1, save_every=500, resume_path="", seed=42):
|
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||||
|
timestep_mode="logit_normal", logit_normal_sigma=1.0):
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -396,6 +409,7 @@ class SelvaLoraTrainer:
|
|||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
timestep_mode, logit_normal_sigma,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _train_inner(
|
def _train_inner(
|
||||||
@@ -404,6 +418,7 @@ class SelvaLoraTrainer:
|
|||||||
data_dir, output_dir, steps, rank, lr,
|
data_dir, output_dir, steps, rank, lr,
|
||||||
alpha_val, target_suffixes, batch_size, warmup_steps,
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
grad_accum, save_every, resume_path, seed,
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
timestep_mode="logit_normal", logit_normal_sigma=1.0,
|
||||||
):
|
):
|
||||||
# --- Prepare generator copy with LoRA ---
|
# --- Prepare generator copy with LoRA ---
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
@@ -468,10 +483,13 @@ class SelvaLoraTrainer:
|
|||||||
"alpha": alpha_val,
|
"alpha": alpha_val,
|
||||||
"target": list(target_suffixes),
|
"target": list(target_suffixes),
|
||||||
"steps": steps,
|
"steps": steps,
|
||||||
|
"timestep_mode": timestep_mode,
|
||||||
|
"logit_normal_sigma": logit_normal_sigma,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size}, "
|
||||||
|
f"timestep_mode={timestep_mode})\n", flush=True)
|
||||||
|
|
||||||
for step in range(start_step + 1, steps + 1):
|
for step in range(start_step + 1, steps + 1):
|
||||||
batch = random.choices(dataset, k=batch_size)
|
batch = random.choices(dataset, k=batch_size)
|
||||||
@@ -484,6 +502,10 @@ class SelvaLoraTrainer:
|
|||||||
|
|
||||||
generator.normalize(x1)
|
generator.normalize(x1)
|
||||||
|
|
||||||
|
if timestep_mode == "logit_normal":
|
||||||
|
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
|
||||||
|
t = torch.sigmoid(u)
|
||||||
|
else:
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|||||||
@@ -167,6 +167,10 @@ def main():
|
|||||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
parser.add_argument("--timestep_mode", default="logit_normal", choices=["logit_normal", "uniform"],
|
||||||
|
help="Timestep sampling distribution. logit_normal reduces white noise artifacts.")
|
||||||
|
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
||||||
|
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
@@ -342,6 +346,10 @@ def main():
|
|||||||
|
|
||||||
net_generator.normalize(x1)
|
net_generator.normalize(x1)
|
||||||
|
|
||||||
|
if args.timestep_mode == "logit_normal":
|
||||||
|
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
|
||||||
|
t = torch.sigmoid(u)
|
||||||
|
else:
|
||||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||||
x0 = torch.randn_like(x1)
|
x0 = torch.randn_like(x1)
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
@@ -377,6 +385,8 @@ def main():
|
|||||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||||
"target": args.target,
|
"target": args.target,
|
||||||
"steps": args.steps,
|
"steps": args.steps,
|
||||||
|
"timestep_mode": args.timestep_mode,
|
||||||
|
"logit_normal_sigma": args.logit_normal_sigma,
|
||||||
},
|
},
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f"[LoRA] Saved {ckpt_path}")
|
print(f"[LoRA] Saved {ckpt_path}")
|
||||||
@@ -395,6 +405,8 @@ def main():
|
|||||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||||
"target": args.target,
|
"target": args.target,
|
||||||
"steps": args.steps,
|
"steps": args.steps,
|
||||||
|
"timestep_mode": args.timestep_mode,
|
||||||
|
"logit_normal_sigma": args.logit_normal_sigma,
|
||||||
}
|
}
|
||||||
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
||||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
|
|||||||
Reference in New Issue
Block a user