feat: add logit-normal timestep sampling to reduce white noise artifacts

Uniform timestep sampling undertrained t>0.8 (the final denoising steps),
leaving residual noise that CFG amplifies at inference. Logit-normal sampling
concentrates training near t=0.5 while still covering the full range, improving
high-t coverage and reducing noise floor in generated audio.

Default changed from uniform to logit_normal (sigma=1.0). Previous behavior
available with timestep_mode=uniform.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-06 00:35:42 +02:00
parent 8ae0ba3c7d
commit a5014e49eb
3 changed files with 73 additions and 21 deletions
+30 -8
View File
@@ -271,6 +271,18 @@ class SelvaLoraTrainer:
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
}),
"seed": ("INT", {"default": 42}),
"timestep_mode": (["logit_normal", "uniform"], {
"default": "logit_normal",
"tooltip": "How to sample training timesteps. "
"logit_normal concentrates steps near t=0.5 (recommended — reduces white noise artifacts). "
"uniform samples all timesteps equally (original behavior).",
}),
"logit_normal_sigma": ("FLOAT", {
"default": 1.0, "min": 0.1, "max": 3.0, "step": 0.1,
"tooltip": "Spread of the logit-normal distribution. "
"1.0 = moderate peak at t=0.5. Higher approaches uniform. "
"Only used when timestep_mode=logit_normal.",
}),
},
}
@@ -292,7 +304,8 @@ class SelvaLoraTrainer:
def train(self, model, data_dir, output_dir, steps, rank, lr,
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
grad_accum=1, save_every=500, resume_path="", seed=42):
grad_accum=1, save_every=500, resume_path="", seed=42,
timestep_mode="logit_normal", logit_normal_sigma=1.0):
torch.manual_seed(seed)
random.seed(seed)
@@ -396,6 +409,7 @@ class SelvaLoraTrainer:
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
timestep_mode, logit_normal_sigma,
)
def _train_inner(
@@ -404,6 +418,7 @@ class SelvaLoraTrainer:
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup_steps,
grad_accum, save_every, resume_path, seed,
timestep_mode="logit_normal", logit_normal_sigma=1.0,
):
# --- Prepare generator copy with LoRA ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
@@ -463,15 +478,18 @@ class SelvaLoraTrainer:
running_loss = 0.0
meta = {
"variant": variant,
"rank": rank,
"alpha": alpha_val,
"target": list(target_suffixes),
"steps": steps,
"variant": variant,
"rank": rank,
"alpha": alpha_val,
"target": list(target_suffixes),
"steps": steps,
"timestep_mode": timestep_mode,
"logit_normal_sigma": logit_normal_sigma,
}
print(f"\n[LoRA Trainer] Training {remaining} steps "
f"(step {start_step + 1}{steps}, batch_size={batch_size})\n", flush=True)
f"(step {start_step + 1}{steps}, batch_size={batch_size}, "
f"timestep_mode={timestep_mode})\n", flush=True)
for step in range(start_step + 1, steps + 1):
batch = random.choices(dataset, k=batch_size)
@@ -484,7 +502,11 @@ class SelvaLoraTrainer:
generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
if timestep_mode == "logit_normal":
u = torch.randn(batch_size, device=device, dtype=dtype) * logit_normal_sigma
t = torch.sigmoid(u)
else:
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)