feat: add dataset pipeline nodes + latent augmentation for LoRA trainer
New dataset pipeline nodes: - SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution - SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility - SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking Improvements: - Inspector: silence detection (max_silence_fraction param) - Saver: origin_name lookup for augmented clips' npz pairing - LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization - LoRA trainer: one-time SR mismatch warning in _load_audio Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -72,6 +72,10 @@ def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
|
||||
waveform = waveform.mean(0, keepdim=True)
|
||||
waveform = waveform.squeeze(0).float()
|
||||
if sr != target_sr:
|
||||
if not getattr(_load_audio, "_sr_warned", False):
|
||||
print(f"[LoRA Trainer] WARNING: audio sr={sr} ≠ target {target_sr}, resampling on-the-fly. "
|
||||
f"Pre-resample with SelVA Dataset Resampler for faster loading.", flush=True)
|
||||
_load_audio._sr_warned = True
|
||||
waveform = torchaudio.functional.resample(
|
||||
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
|
||||
target_len = int(duration * target_sr)
|
||||
@@ -557,6 +561,17 @@ class SelvaLoraTrainer:
|
||||
"cosine: decay from lr to ~0 following a cosine curve — "
|
||||
"prevents oscillation when LR is slightly too high.",
|
||||
}),
|
||||
"latent_mixup_alpha": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||
"tooltip": "Beta distribution alpha for latent mixup (MusicLDM, arXiv:2308.01546). "
|
||||
"0 = disabled. 0.4 recommended. Mixes pairs of training latents "
|
||||
"to reduce memorization on small datasets.",
|
||||
}),
|
||||
"latent_noise_sigma": ("FLOAT", {
|
||||
"default": 0.0, "min": 0.0, "max": 0.1, "step": 0.005,
|
||||
"tooltip": "Additive Gaussian noise on training latents, scaled by x1.std(). "
|
||||
"0 = disabled. 0.01–0.03 adds mild regularization against overfitting.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -581,7 +596,8 @@ class SelvaLoraTrainer:
|
||||
grad_accum=1, save_every=500, resume_path="", seed=42,
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0,
|
||||
init_mode="pissa", use_rslora=True, lr_schedule="constant"):
|
||||
init_mode="pissa", use_rslora=True, lr_schedule="constant",
|
||||
latent_mixup_alpha=0.0, latent_noise_sigma=0.0):
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
@@ -633,6 +649,7 @@ class SelvaLoraTrainer:
|
||||
timestep_mode, logit_normal_sigma, curriculum_switch,
|
||||
lora_dropout, lora_plus_ratio, lr_schedule,
|
||||
init_mode, use_rslora,
|
||||
latent_mixup_alpha, latent_noise_sigma,
|
||||
)
|
||||
return (r["patched_model"], r["adapter_path"], r["loss_curve"])
|
||||
|
||||
@@ -645,6 +662,7 @@ class SelvaLoraTrainer:
|
||||
timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6,
|
||||
lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant",
|
||||
init_mode="pissa", use_rslora=True,
|
||||
latent_mixup_alpha=0.0, latent_noise_sigma=0.0,
|
||||
):
|
||||
# --- Prepare generator copy with LoRA ---
|
||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||
@@ -748,6 +766,8 @@ class SelvaLoraTrainer:
|
||||
"lr_schedule": lr_schedule,
|
||||
"init_mode": init_mode,
|
||||
"use_rslora": use_rslora,
|
||||
"latent_mixup_alpha": latent_mixup_alpha,
|
||||
"latent_noise_sigma": latent_noise_sigma,
|
||||
}
|
||||
|
||||
# For curriculum mode: compute the step at which we switch from logit_normal to uniform
|
||||
@@ -775,6 +795,18 @@ class SelvaLoraTrainer:
|
||||
|
||||
x1 = generator.normalize(x1)
|
||||
|
||||
# Latent mixup (MusicLDM, arXiv:2308.01546)
|
||||
if latent_mixup_alpha > 0 and x1.shape[0] > 1:
|
||||
lam = torch.distributions.Beta(
|
||||
latent_mixup_alpha, latent_mixup_alpha
|
||||
).sample().to(device)
|
||||
idx = torch.randperm(x1.shape[0], device=device)
|
||||
x1 = lam * x1 + (1 - lam) * x1[idx]
|
||||
|
||||
# Latent noise injection
|
||||
if latent_noise_sigma > 0:
|
||||
x1 = x1 + torch.randn_like(x1) * latent_noise_sigma * x1.std()
|
||||
|
||||
if timestep_mode == "logit_normal" or (
|
||||
timestep_mode == "curriculum" and step <= curriculum_switch_step
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user