chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,129 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import open_clip
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class WavTextClipsDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[str, Path],
|
||||
*,
|
||||
captions_tsv: Union[str, Path],
|
||||
clips_tsv: Union[str, Path],
|
||||
sample_rate: int,
|
||||
num_samples: int,
|
||||
normalize_audio: bool = False,
|
||||
reject_silent: bool = False,
|
||||
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
||||
):
|
||||
self.root = Path(root)
|
||||
self.sample_rate = sample_rate
|
||||
self.num_samples = num_samples
|
||||
self.normalize_audio = normalize_audio
|
||||
self.reject_silent = reject_silent
|
||||
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
||||
|
||||
audios = sorted(os.listdir(self.root))
|
||||
audios = set([
|
||||
Path(audio).stem for audio in audios
|
||||
if audio.endswith('.wav') or audio.endswith('.flac')
|
||||
])
|
||||
self.captions = {}
|
||||
|
||||
# read the caption tsv
|
||||
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
for record in df_list:
|
||||
id = record['id']
|
||||
caption = record['caption']
|
||||
self.captions[id] = caption
|
||||
|
||||
# read the clip tsv
|
||||
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
||||
'id': str,
|
||||
'name': str
|
||||
}).to_dict('records')
|
||||
self.clips = []
|
||||
for record in df_list:
|
||||
record['id'] = record['id']
|
||||
record['name'] = record['name']
|
||||
id = record['id']
|
||||
name = record['name']
|
||||
record['caption'] = self.captions[name]
|
||||
self.clips.append(record)
|
||||
|
||||
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
||||
|
||||
self.resampler = {}
|
||||
|
||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||
try:
|
||||
clip = self.clips[idx]
|
||||
audio_name = clip['name']
|
||||
audio_id = clip['id']
|
||||
caption = clip['caption']
|
||||
start_sample = clip['start_sample']
|
||||
end_sample = clip['end_sample']
|
||||
|
||||
audio_path = self.root / f'{audio_name}.flac'
|
||||
if not audio_path.exists():
|
||||
audio_path = self.root / f'{audio_name}.wav'
|
||||
assert audio_path.exists()
|
||||
|
||||
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||
abs_max = audio_chunk.abs().max()
|
||||
if self.normalize_audio:
|
||||
audio_chunk = audio_chunk / abs_max * 0.95
|
||||
|
||||
if self.reject_silent and abs_max < 1e-6:
|
||||
log.warning(f'Rejecting silent audio')
|
||||
return None
|
||||
|
||||
audio_chunk = audio_chunk[start_sample:end_sample]
|
||||
|
||||
# resample
|
||||
if sample_rate == self.sample_rate:
|
||||
audio_chunk = audio_chunk
|
||||
else:
|
||||
if sample_rate not in self.resampler:
|
||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||
sample_rate,
|
||||
self.sample_rate,
|
||||
lowpass_filter_width=64,
|
||||
rolloff=0.9475937167399596,
|
||||
resampling_method='sinc_interp_kaiser',
|
||||
beta=14.769656459379492,
|
||||
)
|
||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||
|
||||
if audio_chunk.shape[0] < self.num_samples:
|
||||
raise ValueError('Audio is too short')
|
||||
audio_chunk = audio_chunk[:self.num_samples]
|
||||
|
||||
tokens = self.tokenizer([caption])[0]
|
||||
|
||||
output = {
|
||||
'waveform': audio_chunk,
|
||||
'id': audio_id,
|
||||
'caption': caption,
|
||||
'tokens': tokens,
|
||||
}
|
||||
|
||||
return output
|
||||
except Exception as e:
|
||||
log.error(f'Error reading {audio_path}: {e}')
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.clips)
|
||||
Reference in New Issue
Block a user