From af6c225f53a7b1df7428992ca0f971a81577c025 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 11:32:00 +0200 Subject: [PATCH] feat: add dataset pipeline nodes + latent augmentation for LoRA trainer New dataset pipeline nodes: - SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution - SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility - SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking Improvements: - Inspector: silence detection (max_silence_fraction param) - Saver: origin_name lookup for augmented clips' npz pairing - LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization - LoRA trainer: one-time SR mismatch warning in _load_audio Co-Authored-By: Claude Opus 4.6 --- nodes/__init__.py | 3 + nodes/selva_dataset_pipeline.py | 265 +++++++++++++++++++++++++++++++- nodes/selva_lora_trainer.py | 34 +++- 3 files changed, 299 insertions(+), 3 deletions(-) diff --git a/nodes/__init__.py b/nodes/__init__.py index ef5a391..68a292d 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -32,6 +32,9 @@ _NODES = { "SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"), "SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"), "SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"), + "SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"), + "SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"), + "SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"), } for key, (module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_dataset_pipeline.py b/nodes/selva_dataset_pipeline.py index 9d41391..85586ff 100644 --- a/nodes/selva_dataset_pipeline.py +++ b/nodes/selva_dataset_pipeline.py @@ -9,6 +9,12 @@ Typical chain: ↓ AUDIO_DATASET SelvaDatasetCompressor (optional) ↓ AUDIO_DATASET + SelvaDatasetSpectralMatcher (optional — batch spectral EQ) + ↓ AUDIO_DATASET + SelvaDatasetHfSmoother (optional — batch HF attenuation) + ↓ AUDIO_DATASET + SelvaDatasetAugmenter (optional — gain/pitch/stretch variants) + ↓ AUDIO_DATASET SelvaDatasetInspector (optional) ↓ AUDIO_DATASET + STRING report SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes) @@ -342,6 +348,11 @@ class SelvaDatasetInspector: "default": True, "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).", }), + "max_silence_fraction": ("FLOAT", { + "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "Flag clips where more than this fraction of frames are near-silent " + "(< -60 dBFS). Set to 0 to disable silence detection.", + }), } } @@ -355,7 +366,8 @@ class SelvaDatasetInspector: "Connect report to a ShowText node to preview in the UI." ) - def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool): + def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, + check_codec_artifacts: bool, max_silence_fraction: float = 0.5): clean = [] flagged = [] lines = ["SelVA Dataset Inspector Report", "=" * 40] @@ -380,6 +392,16 @@ class SelvaDatasetInspector: if check_codec_artifacts and _check_hf_shelf(wav, sr): issues.append("codec artifact (HF shelf > 15 kHz)") + # Silence detection + if max_silence_fraction > 0: + mono = wav[0].mean(0) + if mono.shape[0] >= 2048: + frames = mono.unfold(0, 2048, 512) + rms = frames.pow(2).mean(-1).sqrt() + silent_frac = (rms < 1e-3).float().mean().item() + if silent_frac > max_silence_fraction: + issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)") + if issues: flagged.append(name) lines.append(f" FLAGGED {name}: {', '.join(issues)}") @@ -507,7 +529,9 @@ class SelvaDatasetSaver: saved += 1 if npz_src is not None: - npz_path = npz_src / f"{name}.npz" + # Augmented clips store their origin name — use it to find the .npz + lookup = item.get("origin_name", name) + npz_path = npz_src / f"{lookup}.npz" if npz_path.exists(): shutil.copy2(str(npz_path), str(out / f"{name}.npz")) npz_copied += 1 @@ -525,3 +549,240 @@ class SelvaDatasetSaver: report = "\n".join(lines) print(report, flush=True) return (report,) + + +# ── Batch wrappers for audio preprocessors ─────────────────────────────────── + +class SelvaDatasetSpectralMatcher: + """Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET. + + Wraps SelvaSpectralMatcher so it works on batch datasets instead of + individual AUDIO items. Same parameters — see that node for details. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "dataset": (AUDIO_DATASET,), + "mode": (["44k", "16k"], { + "tooltip": "Must match the SelVA model you are training. " + "44k = large model, 16k = small model.", + }), + "strength": ("FLOAT", { + "default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "0 = no correction, 1 = full match to VAE distribution.", + }), + "max_gain_db": ("FLOAT", { + "default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0, + "tooltip": "Clamps per-band gain to ±dB.", + }), + } + } + + RETURN_TYPES = (AUDIO_DATASET,) + RETURN_NAMES = ("dataset",) + FUNCTION = "process" + CATEGORY = SELVA_CATEGORY + DESCRIPTION = ( + "Apply adaptive spectral matching to every clip in a dataset. " + "Batch version of SelVA Spectral Matcher — same per-band EQ toward the " + "VAE's expected distribution." + ) + + def process(self, dataset, mode: str, strength: float, max_gain_db: float): + from .selva_audio_preprocessors import SelvaSpectralMatcher + matcher = SelvaSpectralMatcher() + out = [] + for item in dataset: + audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]} + (result,) = matcher.process(audio, mode, strength, max_gain_db) + new_item = dict(item) # preserve origin_name and any extra keys + new_item["waveform"] = result["waveform"] + new_item["sample_rate"] = result["sample_rate"] + out.append(new_item) + print(f"[DatasetSpectralMatcher] {len(out)} clips processed " + f"mode={mode} strength={strength}", flush=True) + return (out,) + + +class SelvaDatasetHfSmoother: + """Apply SelVA HF Smoother to every clip in an AUDIO_DATASET. + + Wraps SelvaHfSmoother so it works on batch datasets instead of + individual AUDIO items. Same parameters — see that node for details. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "dataset": (AUDIO_DATASET,), + "cutoff_hz": ("FLOAT", { + "default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0, + "tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.", + }), + "blend": ("FLOAT", { + "default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "0 = original, 1 = fully filtered.", + }), + } + } + + RETURN_TYPES = (AUDIO_DATASET,) + RETURN_NAMES = ("dataset",) + FUNCTION = "process" + CATEGORY = SELVA_CATEGORY + DESCRIPTION = ( + "Apply soft HF attenuation to every clip in a dataset. " + "Batch version of SelVA HF Smoother — blends a low-pass filtered copy " + "with the original to tame extreme HF content." + ) + + def process(self, dataset, cutoff_hz: float, blend: float): + from .selva_audio_preprocessors import SelvaHfSmoother + smoother = SelvaHfSmoother() + out = [] + for item in dataset: + audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]} + (result,) = smoother.process(audio, cutoff_hz, blend) + new_item = dict(item) # preserve origin_name and any extra keys + new_item["waveform"] = result["waveform"] + new_item["sample_rate"] = result["sample_rate"] + out.append(new_item) + print(f"[DatasetHfSmoother] {len(out)} clips processed " + f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True) + return (out,) + + +# ── Dataset augmenter ──────────────────────────────────────────────────────── + +class SelvaDatasetAugmenter: + """Create augmented variants of each clip to expand a small dataset. + + Supports gain variation (always available) and optionally pitch shift + and time stretch via audiomentations. Install audiomentations for the + full feature set: ``pip install audiomentations`` + + Each original clip produces ``variants_per_clip`` augmented copies. + Originals are kept by default (toggle ``keep_originals``). + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "dataset": (AUDIO_DATASET,), + "variants_per_clip": ("INT", { + "default": 2, "min": 1, "max": 20, + "tooltip": "Number of augmented copies per original clip.", + }), + "gain_range_db": ("FLOAT", { + "default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5, + "tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.", + }), + "seed": ("INT", {"default": 42}), + }, + "optional": { + "pitch_range_semitones": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25, + "tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.", + }), + "time_stretch_range": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05, + "tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). " + "Requires audiomentations. 0 = disabled.", + }), + "keep_originals": ("BOOLEAN", { + "default": True, + "tooltip": "Include the original unaugmented clips in the output.", + }), + }, + } + + RETURN_TYPES = (AUDIO_DATASET,) + RETURN_NAMES = ("dataset",) + FUNCTION = "augment" + CATEGORY = SELVA_CATEGORY + DESCRIPTION = ( + "Create augmented variants of each clip (gain, pitch, time stretch) " + "to expand small training datasets. Gain is always available; pitch and " + "time stretch require audiomentations (pip install audiomentations)." + ) + + def augment(self, dataset, variants_per_clip: int, gain_range_db: float, + seed: int, pitch_range_semitones: float = 0.0, + time_stretch_range: float = 0.0, keep_originals: bool = True): + rng = np.random.RandomState(seed) + + # Try audiomentations for pitch/stretch + use_am = False + am_compose = None + needs_am = pitch_range_semitones > 0 or time_stretch_range > 0 + if needs_am: + try: + import audiomentations as am + transforms = [] + if pitch_range_semitones > 0: + transforms.append(am.PitchShift( + min_semitones=-pitch_range_semitones, + max_semitones=pitch_range_semitones, + p=0.5, + )) + if time_stretch_range > 0: + transforms.append(am.TimeStretch( + min_rate=1.0 - time_stretch_range, + max_rate=1.0 + time_stretch_range, + leave_length_unchanged=True, + p=0.5, + )) + am_compose = am.Compose(transforms) + use_am = True + except ImportError: + print("[DatasetAugmenter] audiomentations not installed — " + "pitch_shift and time_stretch disabled. " + "Install: pip install audiomentations", flush=True) + + out = [] + if keep_originals: + out.extend(dataset) + + for item in dataset: + wav = item["waveform"] # [1, C, L] + sr = item["sample_rate"] + name = item["name"] + + for v in range(variants_per_clip): + # Gain variation (always applied) + gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0 + gain_lin = 10.0 ** (gain_db / 20.0) + wav_aug = wav * gain_lin + + # Pitch/stretch via audiomentations + if use_am and am_compose is not None: + wav_np = wav_aug[0].numpy() # [C, L] float32 + if wav_np.shape[0] == 1: + wav_np = wav_np[0] # [L] mono for audiomentations + wav_np = am_compose(samples=wav_np, sample_rate=sr) + if wav_np.ndim == 1: + wav_np = wav_np[np.newaxis, :] # back to [1, L] + wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L] + + # Prevent clipping + peak = wav_aug.abs().max() + if peak > 1.0: + wav_aug = wav_aug / peak + + out.append({ + "waveform": wav_aug, + "sample_rate": sr, + "name": f"{name}_aug{v:02d}", + "origin_name": name, + }) + + print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips " + f"gain=±{gain_range_db:.1f}dB" + + (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "") + + (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""), + flush=True) + return (out,) diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 413e8c6..e60a157 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -72,6 +72,10 @@ def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor: waveform = waveform.mean(0, keepdim=True) waveform = waveform.squeeze(0).float() if sr != target_sr: + if not getattr(_load_audio, "_sr_warned", False): + print(f"[LoRA Trainer] WARNING: audio sr={sr} ≠ target {target_sr}, resampling on-the-fly. " + f"Pre-resample with SelVA Dataset Resampler for faster loading.", flush=True) + _load_audio._sr_warned = True waveform = torchaudio.functional.resample( waveform.unsqueeze(0), sr, target_sr).squeeze(0) target_len = int(duration * target_sr) @@ -557,6 +561,17 @@ class SelvaLoraTrainer: "cosine: decay from lr to ~0 following a cosine curve — " "prevents oscillation when LR is slightly too high.", }), + "latent_mixup_alpha": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "Beta distribution alpha for latent mixup (MusicLDM, arXiv:2308.01546). " + "0 = disabled. 0.4 recommended. Mixes pairs of training latents " + "to reduce memorization on small datasets.", + }), + "latent_noise_sigma": ("FLOAT", { + "default": 0.0, "min": 0.0, "max": 0.1, "step": 0.005, + "tooltip": "Additive Gaussian noise on training latents, scaled by x1.std(). " + "0 = disabled. 0.01–0.03 adds mild regularization against overfitting.", + }), }, } @@ -581,7 +596,8 @@ class SelvaLoraTrainer: grad_accum=1, save_every=500, resume_path="", seed=42, timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, lora_dropout=0.0, lora_plus_ratio=1.0, - init_mode="pissa", use_rslora=True, lr_schedule="constant"): + init_mode="pissa", use_rslora=True, lr_schedule="constant", + latent_mixup_alpha=0.0, latent_noise_sigma=0.0): torch.manual_seed(seed) random.seed(seed) @@ -633,6 +649,7 @@ class SelvaLoraTrainer: timestep_mode, logit_normal_sigma, curriculum_switch, lora_dropout, lora_plus_ratio, lr_schedule, init_mode, use_rslora, + latent_mixup_alpha, latent_noise_sigma, ) return (r["patched_model"], r["adapter_path"], r["loss_curve"]) @@ -645,6 +662,7 @@ class SelvaLoraTrainer: timestep_mode="uniform", logit_normal_sigma=1.0, curriculum_switch=0.6, lora_dropout=0.0, lora_plus_ratio=1.0, lr_schedule="constant", init_mode="pissa", use_rslora=True, + latent_mixup_alpha=0.0, latent_noise_sigma=0.0, ): # --- Prepare generator copy with LoRA --- generator = copy.deepcopy(model["generator"]).to(device, dtype) @@ -748,6 +766,8 @@ class SelvaLoraTrainer: "lr_schedule": lr_schedule, "init_mode": init_mode, "use_rslora": use_rslora, + "latent_mixup_alpha": latent_mixup_alpha, + "latent_noise_sigma": latent_noise_sigma, } # For curriculum mode: compute the step at which we switch from logit_normal to uniform @@ -775,6 +795,18 @@ class SelvaLoraTrainer: x1 = generator.normalize(x1) + # Latent mixup (MusicLDM, arXiv:2308.01546) + if latent_mixup_alpha > 0 and x1.shape[0] > 1: + lam = torch.distributions.Beta( + latent_mixup_alpha, latent_mixup_alpha + ).sample().to(device) + idx = torch.randperm(x1.shape[0], device=device) + x1 = lam * x1 + (1 - lam) * x1[idx] + + # Latent noise injection + if latent_noise_sigma > 0: + x1 = x1 + torch.randn_like(x1) * latent_noise_sigma * x1.std() + if timestep_mode == "logit_normal" or ( timestep_mode == "curriculum" and step <= curriculum_switch_step ):