7eb9bd5745
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
190 lines
6.3 KiB
Python
190 lines
6.3 KiB
Python
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.

Typical chain:

    SelvaDatasetLoader
        ↓ AUDIO_DATASET
    SelvaDatasetResampler (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetLUFSNormalizer (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetInspector (optional)
        ↓ AUDIO_DATASET + STRING report
    SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torchaudio
|
|
|
|
from .utils import SELVA_CATEGORY
|
|
|
|
# Custom ComfyUI socket type shared by every node of the dataset pipeline.
AUDIO_DATASET = "AUDIO_DATASET"

# Lower-case, dot-prefixed file suffixes that SelvaDatasetLoader accepts as audio.
_AUDIO_EXTS = set(".wav .flac .mp3 .ogg .aac .m4a".split())
|
|
|
|
|
|
class SelvaDatasetLoader:
    """Load all audio files in a folder into an in-memory AUDIO_DATASET.

    The dataset is a list of dicts, one per successfully decoded file:
    ``{"waveform": [1, C, L] float tensor, "sample_rate": int, "name": file stem}``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to folder containing audio files. Searched recursively.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "load"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."

    def load(self, folder: str):
        """Recursively load every audio file under ``folder``.

        Args:
            folder: directory path; surrounding whitespace is ignored.

        Returns:
            1-tuple holding the dataset list, ordered by sorted file path.

        Raises:
            ValueError: ``folder`` is empty or whitespace only.
            FileNotFoundError: the path does not exist.
            NotADirectoryError: the path exists but is not a directory.
            RuntimeError: no audio files were found, or none could be decoded.
        """
        raw = folder.strip()
        # Path("") means the current working directory; without this guard the
        # empty default would silently scan the whole working tree recursively.
        if not raw:
            raise ValueError("[DatasetLoader] No folder specified")

        folder = Path(raw)
        if not folder.exists():
            raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
        if not folder.is_dir():
            raise NotADirectoryError(f"[DatasetLoader] Not a folder: {folder}")

        # is_file() excludes directories whose names happen to end in an audio suffix.
        files = sorted(
            f for f in folder.rglob("*")
            if f.suffix.lower() in _AUDIO_EXTS and f.is_file()
        )
        if not files:
            raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")

        dataset = []
        seen_names = set()
        for f in files:
            try:
                wav, sr = torchaudio.load(str(f))  # [C, L]
            except Exception as e:
                # Best effort: one undecodable file should not abort the whole load.
                print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
                continue

            # The search is recursive, so equal stems in different subfolders collide
            # on "name"; warn rather than rename to keep names stable for callers.
            if f.stem in seen_names:
                print(f"[DatasetLoader] Warning: duplicate clip name '{f.stem}' ({f})", flush=True)
            seen_names.add(f.stem)

            dataset.append({"waveform": wav.unsqueeze(0).float(), "sample_rate": sr, "name": f.stem})  # [1, C, L]

        if not dataset:
            raise RuntimeError(
                f"[DatasetLoader] None of the {len(files)} audio files in {folder} could be loaded"
            )

        print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
        return (dataset,)
|
|
|
|
|
|
class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ.

    Clips already at the target rate are passed through as the same dict object.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."

    def resample(self, dataset, target_sr: int):
        """Return a new dataset list with every waveform at ``target_sr``.

        Args:
            dataset: list of item dicts with ``"waveform"`` ([1, C, L] tensor),
                ``"sample_rate"`` and ``"name"`` keys.
            target_sr: output sample rate in Hz.

        Returns:
            1-tuple holding the new list. Resampled items are new dicts that keep
            every key of the input item, with ``waveform`` and ``sample_rate``
            replaced. The output waveform is a contiguous float32 CPU tensor.
        """
        import soxr

        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                out.append(item)
                continue

            # soxr expects a time-first [L, C] float64 array. .cpu() makes .numpy()
            # work for CUDA tensors; .contiguous() turns the strided permute() view
            # into a plain C-ordered buffer.
            wav_np = item["waveform"][0].detach().cpu().permute(1, 0).contiguous().double().numpy()  # [L, C]
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")  # [L', C]
            wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).contiguous().unsqueeze(0)  # [1, C, L']

            out.append({**item, "waveform": wav_t, "sample_rate": target_sr})
            changed += 1

        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)
|
|
|
|
|
|
class SelvaDatasetLUFSNormalizer:
    """Normalize each clip to a target integrated LUFS level + true peak limit.

    The peak step is a static gain, not a limiter: a clip whose true peak would
    exceed the ceiling is turned down as a whole and ends up below target_lufs.
    """

    # Oversampling factor used to estimate true peak (BS.1770-4 Annex 2 uses 4x).
    _TP_OVERSAMPLE = 4

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_lufs": ("FLOAT", {
                    "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
                    "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
                }),
                "true_peak_dbtp": ("FLOAT", {
                    "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
                    "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "normalize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
        "Skips clips that are too short for LUFS measurement (< 0.4 s)."
    )

    @classmethod
    def _true_peak(cls, wav, sr):
        """Estimate the linear true peak of a [C, L] waveform sampled at ``sr`` Hz.

        Upsamples by ``_TP_OVERSAMPLE`` with torchaudio's sinc resampler so that
        inter-sample peaks, which a plain sample peak misses, are measured. The
        result is never below the sample peak, because the resampler's low-pass
        filter can slightly attenuate the original sample values. Note that this
        temporarily allocates an oversampled copy of the clip.
        """
        sample_peak = wav.abs().max().item()
        sr = int(sr)
        upsampled = torchaudio.functional.resample(wav, sr, sr * cls._TP_OVERSAMPLE)
        return max(sample_peak, upsampled.abs().max().item())

    def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
        """Loudness-normalize every clip, then enforce the true-peak ceiling.

        Args:
            dataset: list of item dicts with ``"waveform"`` ([1, C, L] tensor),
                ``"sample_rate"`` and ``"name"`` keys.
            target_lufs: target integrated loudness in LUFS.
            true_peak_dbtp: true-peak ceiling in dBTP.

        Returns:
            1-tuple holding the new list. Clips whose loudness cannot be measured
            are passed through unchanged (same dict object). Normalized items are
            new dicts that keep every key of the input item, with ``waveform``
            and ``sample_rate`` set.
        """
        import pyloudnorm as pyln

        tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
        meters = {}  # one pyloudnorm Meter per sample rate, reused across clips
        out = []
        skipped = 0

        for item in dataset:
            wav = item["waveform"][0]  # [C, L]
            sr = item["sample_rate"]
            name = item.get("name", "?")

            # pyloudnorm wants [L] mono or [L, C] multichannel, float64, as a CPU array
            wav_np = wav.detach().cpu().permute(1, 0).double().numpy()  # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0]  # [L] mono

            meter = meters.get(sr)
            if meter is None:
                meter = meters[sr] = pyln.Meter(sr)

            try:
                loudness = meter.integrated_loudness(wav_np)
            except Exception as e:
                # Best effort: e.g. clips shorter than one 0.4 s gating block are
                # rejected by pyloudnorm; keep such clips unchanged but say why.
                print(f"[LUFSNormalizer] Skipping {name}: {e}", flush=True)
                skipped += 1
                out.append(item)
                continue

            if not np.isfinite(loudness):
                # Loudness is unmeasurable (e.g. digital silence); no finite gain exists.
                print(f"[LUFSNormalizer] Skipping {name}: loudness is {loudness}", flush=True)
                skipped += 1
                out.append(item)
                continue

            gain_linear = 10.0 ** ((target_lufs - loudness) / 20.0)
            wav_norm = wav * gain_linear

            # True peak limit. Scaling is linear, so one corrective gain moves the
            # measured true peak exactly onto the ceiling.
            peak = self._true_peak(wav_norm, sr)
            if peak > tp_linear:
                wav_norm = wav_norm * (tp_linear / peak)

            out.append({**item, "waveform": wav_norm.unsqueeze(0), "sample_rate": sr})

        print(
            f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
            f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
            flush=True,
        )
        return (out,)
|