class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ.

    The class follows a node-style declaration (``INPUT_TYPES``,
    ``RETURN_TYPES``, ``FUNCTION``, ``CATEGORY``, ``DESCRIPTION``);
    presumably these are read by the host node framework -- confirm against
    the other nodes in this package.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: the dataset and the target sample rate."""
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."

    def resample(self, dataset, target_sr: int):
        """Return a new dataset in which every clip is at ``target_sr``.

        Args:
            dataset: Sized iterable of clip dicts. Each clip is read through
                the keys ``"waveform"``, ``"sample_rate"`` and ``"name"``.
                ``waveform[0]`` is treated as a 2-D ``[C, L]`` tensor, so the
                waveform is assumed to be ``[1, C, L]`` -- TODO confirm
                against the loader.
            target_sr: Sample rate, in Hz, that every returned clip will have.

        Returns:
            A 1-tuple holding the new list of clips. Clips already at
            ``target_sr`` are passed through as the same dict object; the
            others are replaced by a new dict with a float32 ``[1, C, L]``
            waveform, ``sample_rate == target_sr`` and the original name.

        Raises:
            KeyError: If a clip lacks ``"sample_rate"``, or a clip that needs
                resampling lacks ``"waveform"`` or ``"name"``.
        """
        # Imported lazily so the module still loads when soxr is not installed.
        import soxr

        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                # Already at the target rate: reuse the clip untouched.
                out.append(item)
                continue

            wav = item["waveform"][0]  # [C, L]
            # Tensor.numpy() raises for tensors that live on an accelerator or
            # that require grad, so detach and move to host memory first.
            # soxr takes time-first [L, C] input; resample in float64.
            wav_np = wav.detach().cpu().permute(1, 0).double().numpy()  # [L, C]
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
            # Back to float32 [1, C, L]. The permute yields a strided view, so
            # make it contiguous for downstream .view() calls, then return it
            # to the device the source clip came from.
            wav_t = (
                torch.from_numpy(wav_rs)
                .float()
                .permute(1, 0)
                .unsqueeze(0)
                .contiguous()
                .to(wav.device)
            )
            out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
            changed += 1

        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)