feat: add dataset pipeline nodes + latent augmentation for LoRA trainer

New dataset pipeline nodes:
- SelvaDatasetSpectralMatcher: batch spectral EQ toward VAE distribution
- SelvaDatasetHfSmoother: batch HF attenuation for codec compatibility
- SelvaDatasetAugmenter: gain/pitch/time-stretch variants with npz origin tracking

Improvements:
- Inspector: silence detection (max_silence_fraction param)
- Saver: origin_name lookup for augmented clips' npz pairing
- LoRA trainer: latent_mixup_alpha + latent_noise_sigma regularization
- LoRA trainer: one-time SR mismatch warning in _load_audio

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 11:32:00 +02:00
parent 30127c13ca
commit af6c225f53
3 changed files with 299 additions and 3 deletions
+263 -2
View File
@@ -9,6 +9,12 @@ Typical chain:
↓ AUDIO_DATASET
SelvaDatasetCompressor (optional)
↓ AUDIO_DATASET
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
↓ AUDIO_DATASET
SelvaDatasetHfSmoother (optional — batch HF attenuation)
↓ AUDIO_DATASET
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
@@ -342,6 +348,11 @@ class SelvaDatasetInspector:
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
"max_silence_fraction": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
"(< -60 dBFS). Set to 0 to disable silence detection.",
}),
}
}
@@ -355,7 +366,8 @@ class SelvaDatasetInspector:
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
@@ -380,6 +392,16 @@ class SelvaDatasetInspector:
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
# Silence detection
if max_silence_fraction > 0:
mono = wav[0].mean(0)
if mono.shape[0] >= 2048:
frames = mono.unfold(0, 2048, 512)
rms = frames.pow(2).mean(-1).sqrt()
silent_frac = (rms < 1e-3).float().mean().item()
if silent_frac > max_silence_fraction:
issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
@@ -507,7 +529,9 @@ class SelvaDatasetSaver:
saved += 1
if npz_src is not None:
npz_path = npz_src / f"{name}.npz"
# Augmented clips store their origin name — use it to find the .npz
lookup = item.get("origin_name", name)
npz_path = npz_src / f"{lookup}.npz"
if npz_path.exists():
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
npz_copied += 1
@@ -525,3 +549,240 @@ class SelvaDatasetSaver:
report = "\n".join(lines)
print(report, flush=True)
return (report,)
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
class SelvaDatasetSpectralMatcher:
"""Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.
Wraps SelvaSpectralMatcher so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply adaptive spectral matching to every clip in a dataset. "
"Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
"VAE's expected distribution."
)
def process(self, dataset, mode: str, strength: float, max_gain_db: float):
from .selva_audio_preprocessors import SelvaSpectralMatcher
matcher = SelvaSpectralMatcher()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = matcher.process(audio, mode, strength, max_gain_db)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
f"mode={mode} strength={strength}", flush=True)
return (out,)
class SelvaDatasetHfSmoother:
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
Wraps SelvaHfSmoother so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply soft HF attenuation to every clip in a dataset. "
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
"with the original to tame extreme HF content."
)
def process(self, dataset, cutoff_hz: float, blend: float):
from .selva_audio_preprocessors import SelvaHfSmoother
smoother = SelvaHfSmoother()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = smoother.process(audio, cutoff_hz, blend)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetHfSmoother] {len(out)} clips processed "
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
return (out,)
# ── Dataset augmenter ────────────────────────────────────────────────────────
class SelvaDatasetAugmenter:
"""Create augmented variants of each clip to expand a small dataset.
Supports gain variation (always available) and optionally pitch shift
and time stretch via audiomentations. Install audiomentations for the
full feature set: ``pip install audiomentations``
Each original clip produces ``variants_per_clip`` augmented copies.
Originals are kept by default (toggle ``keep_originals``).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"variants_per_clip": ("INT", {
"default": 2, "min": 1, "max": 20,
"tooltip": "Number of augmented copies per original clip.",
}),
"gain_range_db": ("FLOAT", {
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
}),
"seed": ("INT", {"default": 42}),
},
"optional": {
"pitch_range_semitones": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
}),
"time_stretch_range": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
"tooltip": "Random time stretch ±fraction (0.1 = 90%110% speed). "
"Requires audiomentations. 0 = disabled.",
}),
"keep_originals": ("BOOLEAN", {
"default": True,
"tooltip": "Include the original unaugmented clips in the output.",
}),
},
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "augment"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Create augmented variants of each clip (gain, pitch, time stretch) "
"to expand small training datasets. Gain is always available; pitch and "
"time stretch require audiomentations (pip install audiomentations)."
)
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
seed: int, pitch_range_semitones: float = 0.0,
time_stretch_range: float = 0.0, keep_originals: bool = True):
rng = np.random.RandomState(seed)
# Try audiomentations for pitch/stretch
use_am = False
am_compose = None
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
if needs_am:
try:
import audiomentations as am
transforms = []
if pitch_range_semitones > 0:
transforms.append(am.PitchShift(
min_semitones=-pitch_range_semitones,
max_semitones=pitch_range_semitones,
p=0.5,
))
if time_stretch_range > 0:
transforms.append(am.TimeStretch(
min_rate=1.0 - time_stretch_range,
max_rate=1.0 + time_stretch_range,
leave_length_unchanged=True,
p=0.5,
))
am_compose = am.Compose(transforms)
use_am = True
except ImportError:
print("[DatasetAugmenter] audiomentations not installed — "
"pitch_shift and time_stretch disabled. "
"Install: pip install audiomentations", flush=True)
out = []
if keep_originals:
out.extend(dataset)
for item in dataset:
wav = item["waveform"] # [1, C, L]
sr = item["sample_rate"]
name = item["name"]
for v in range(variants_per_clip):
# Gain variation (always applied)
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
gain_lin = 10.0 ** (gain_db / 20.0)
wav_aug = wav * gain_lin
# Pitch/stretch via audiomentations
if use_am and am_compose is not None:
wav_np = wav_aug[0].numpy() # [C, L] float32
if wav_np.shape[0] == 1:
wav_np = wav_np[0] # [L] mono for audiomentations
wav_np = am_compose(samples=wav_np, sample_rate=sr)
if wav_np.ndim == 1:
wav_np = wav_np[np.newaxis, :] # back to [1, L]
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
# Prevent clipping
peak = wav_aug.abs().max()
if peak > 1.0:
wav_aug = wav_aug / peak
out.append({
"waveform": wav_aug,
"sample_rate": sr,
"name": f"{name}_aug{v:02d}",
"origin_name": name,
})
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
f"gain=±{gain_range_db:.1f}dB"
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
flush=True)
return (out,)