feat: add logit-normal timestep sampling to reduce white noise artifacts
Uniform timestep sampling undertrained t>0.8 (the final denoising steps), leaving residual noise that CFG amplifies at inference. Logit-normal sampling concentrates training near t=0.5 while still covering the full range, improving high-t coverage and reducing noise floor in generated audio. Default changed from uniform to logit_normal (sigma=1.0). Previous behavior available with timestep_mode=uniform. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+25
-13
@@ -165,8 +165,12 @@ def main():
|
||||
parser.add_argument("--save_every", type=int, default=500)
|
||||
parser.add_argument("--resume", default=None,
|
||||
help="Path to a step checkpoint (.pt) to resume training from.")
|
||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--timestep_mode", default="logit_normal", choices=["logit_normal", "uniform"],
|
||||
help="Timestep sampling distribution. logit_normal reduces white noise artifacts.")
|
||||
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
|
||||
help="Spread of logit-normal distribution (only used with --timestep_mode logit_normal).")
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
@@ -342,7 +346,11 @@ def main():
|
||||
|
||||
net_generator.normalize(x1)
|
||||
|
||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||
if args.timestep_mode == "logit_normal":
|
||||
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
|
||||
t = torch.sigmoid(u)
|
||||
else:
|
||||
t = torch.rand(args.batch_size, device=device, dtype=dtype)
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
@@ -372,11 +380,13 @@ def main():
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"step": step,
|
||||
"meta": {
|
||||
"variant": args.variant,
|
||||
"rank": args.rank,
|
||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||
"target": args.target,
|
||||
"steps": args.steps,
|
||||
"variant": args.variant,
|
||||
"rank": args.rank,
|
||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||
"target": args.target,
|
||||
"steps": args.steps,
|
||||
"timestep_mode": args.timestep_mode,
|
||||
"logit_normal_sigma": args.logit_normal_sigma,
|
||||
},
|
||||
}, ckpt_path)
|
||||
print(f"[LoRA] Saved {ckpt_path}")
|
||||
@@ -390,11 +400,13 @@ def main():
|
||||
i += 1
|
||||
final = output_dir / f"adapter_final_{i:03d}.pt"
|
||||
meta = {
|
||||
"variant": args.variant,
|
||||
"rank": args.rank,
|
||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||
"target": args.target,
|
||||
"steps": args.steps,
|
||||
"variant": args.variant,
|
||||
"rank": args.rank,
|
||||
"alpha": args.alpha if args.alpha is not None else float(args.rank),
|
||||
"target": args.target,
|
||||
"steps": args.steps,
|
||||
"timestep_mode": args.timestep_mode,
|
||||
"logit_normal_sigma": args.logit_normal_sigma,
|
||||
}
|
||||
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
|
||||
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||
|
||||
Reference in New Issue
Block a user