Files
Ethanfel 6bc3fd6443 chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 15:18:09 +02:00

195 lines
7.0 KiB
Python

import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
# Module logger: an argument-less getLogger() returns the root logger.
log = logging.getLogger()

# Square frame size (pixels) and sampling rate (fps) of the CLIP video stream.
_CLIP_SIZE, _CLIP_FPS = 384, 8.0
# Square frame size (pixels) and sampling rate (fps) of the sync video stream.
_SYNC_SIZE, _SYNC_FPS = 224, 25.0
class VGGSound(Dataset):
    """Video/audio dataset over a directory of video files indexed by a TSV.

    The TSV (tab-separated, read with ``pd.read_csv``) must provide ``id`` and
    ``label`` columns. Only ids whose file stem exists in ``root`` are kept.
    Each item is decoded from ``<root>/<id>.mp4``.

    Args:
        root: Directory containing the video files.
        tsv_path: Path to the TSV listing ``id`` / ``label`` rows.
        audio_required: If True, decode the audio track, mix it to mono,
            resample it to ``sample_rate`` and crop it to a fixed length.
        sample_rate: Target audio sample rate in Hz.
        duration_sec: Clip duration in seconds; drives the expected frame and
            sample counts.
        audio_samples: Explicit audio length in samples. It must match
            ``duration_sec`` within 15 ms. Defaults to
            ``int(sample_rate * duration_sec)``.
        normalize_audio: If True, peak-normalize audio to 0.95 and reject
            (near-)silent clips.
        clip_video_required: If True, also decode a second video stream at
            ``_CLIP_FPS`` for CLIP features.

    Raises:
        ValueError: If ``audio_samples`` does not match ``duration_sec``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
        audio_required: bool = True,
        sample_rate: int = 16_000,
        duration_sec: float = 8.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
        clip_video_required: bool = True,
    ):
        self.root = Path(root)
        self.audio_required = audio_required
        self.clip_video_required = clip_video_required
        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        if audio_required:
            self.normalize_audio = normalize_audio
            if audio_samples is None:
                self.audio_samples = int(sample_rate * duration_sec)
            else:
                self.audio_samples = audio_samples
                effective_duration = audio_samples / sample_rate
                # Make sure the duration is close enough, within 15 ms. An
                # explicit raise (not `assert`) so the check survives `python -O`.
                if not abs(effective_duration - duration_sec) < 0.015:
                    raise ValueError(
                        f'audio_samples {audio_samples} does not match duration_sec {duration_sec}')

        # Stems of every file in `root` (extensions stripped). NOTE(review):
        # sample() always opens `<id>.mp4`, so a stem found under a different
        # extension is indexed here but will fail to load later.
        available = {Path(name).stem for name in os.listdir(self.root)}

        self.labels = {}
        self.videos = []
        missing_videos = []
        # Read the TSV for subset information.
        records = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
        for record in records:
            video_id = record['id']
            label = record['label']
            if video_id in available:
                self.labels[video_id] = label
                self.videos.append(video_id)
            else:
                missing_videos.append(video_id)

        if local_rank == 0:
            log.info(f'{len(available)} videos found in {root}')
            log.info(f'{len(self.videos)} videos found in {tsv_path}')
            log.info(f'{len(missing_videos)} videos missing in {root}')

        if audio_required:
            self.expected_audio_length = self.audio_samples
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        if clip_video_required:
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)

        # Sync frames: resize directly to a square (no center crop), scale to
        # [0, 1], then normalize with mean/std 0.5 to [-1, 1].
        self.sync_transform = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if clip_video_required:
            # CLIP frames: resize to a square and scale to [0, 1] (no mean/std step).
            self.clip_transform = v2.Compose([
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ])
        if audio_required:
            # Cache of Resample transforms keyed by source sample rate.
            self.resampler = {}

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and preprocess item ``idx``.

        Returns:
            A dict with ``id``, ``caption`` (the TSV label) and ``sync_video``.
            It also has ``audio`` if ``audio_required`` and ``clip_video`` if
            ``clip_video_required``.

        Raises:
            RuntimeError: If a requested stream yields no data, the audio is
                silent (with ``normalize_audio``), or the audio is too short.
        """
        video_id = self.videos[idx]
        label = self.labels[video_id]

        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
        # Output streams are numbered in registration order:
        #   0 = sync video, 1 = audio (if requested), then CLIP video (if requested).
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if self.audio_required:
            # Very large chunk size, presumably so the whole audio track comes
            # back as one chunk; verify against the decoder docs.
            reader.add_basic_audio_stream(frames_per_chunk=2**30, )
        if self.clip_video_required:
            reader.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )
        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        sync_chunk = data_chunk[0]
        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                           n_tolerance_frame=3, desc=video_id)
        sync_chunk = self.sync_transform(sync_chunk)

        audio_chunk = None
        if self.audio_required:
            audio_chunk = data_chunk[1]
            # Fail explicitly, like the video streams, instead of with an
            # AttributeError later on.
            if audio_chunk is None:
                raise RuntimeError(f'Audio returned None {video_id}')

        clip_chunk = None
        if self.clip_video_required:
            clip_chunk = data_chunk[2 if self.audio_required else 1]
            if clip_chunk is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                               n_tolerance_frame=1, desc=video_id)
            clip_chunk = self.clip_transform(clip_chunk)

        if self.audio_required:
            source_rate = int(reader.get_out_stream_info(1).sample_rate)
            audio_chunk = self._process_audio(audio_chunk, source_rate, video_id)

        data = {
            'id': video_id,
            'caption': label,
            'sync_video': sync_chunk,
        }
        if self.audio_required:
            data['audio'] = audio_chunk
        if self.clip_video_required:
            data['clip_video'] = clip_chunk
        return data

    def _process_audio(self, audio_chunk: torch.Tensor, sample_rate: int,
                       video_id: str) -> torch.Tensor:
        """Mix to mono, optionally peak-normalize, resample, then crop to length.

        ``audio_chunk`` is transposed and averaged over dim 0. This presumes the
        decoder returns (frames, channels); TODO confirm the layout.

        Raises:
            RuntimeError: If the audio is silent (only checked when
                ``normalize_audio``) or shorter than ``expected_audio_length``
                after resampling.
        """
        audio_chunk = audio_chunk.transpose(0, 1).mean(dim=0)  # mono
        if self.normalize_audio:
            abs_max = audio_chunk.abs().max()
            # Reject silence before dividing so inf/NaN samples are never produced.
            if abs_max <= 1e-6:
                raise RuntimeError(f'Audio is silent {video_id}')
            audio_chunk = audio_chunk * (0.95 / abs_max)

        if sample_rate != self.sample_rate:
            if sample_rate not in self.resampler:
                # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                self.resampler[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate,
                    self.sample_rate,
                    lowpass_filter_width=64,
                    rolloff=0.9475937167399596,
                    resampling_method='sinc_interp_kaiser',
                    beta=14.769656459379492,
                )
            audio_chunk = self.resampler[sample_rate](audio_chunk)

        if audio_chunk.shape[0] < self.expected_audio_length:
            raise RuntimeError(f'Audio too short {video_id}')
        return audio_chunk[:self.expected_audio_length]

    def __getitem__(self, idx: int) -> Optional[dict[str, torch.Tensor]]:
        """Return ``self.sample(idx)``, or ``None`` after logging if loading fails.

        Any exception raised while decoding is logged and swallowed, so broken
        files yield ``None``. Presumably the caller's collate step filters these
        out; verify. An out-of-range ``idx`` still raises ``IndexError``.
        """
        # Looked up outside the try so an invalid index propagates as IndexError.
        video_id = self.videos[idx]
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {video_id}: {e}')
            return None

    def __len__(self):
        # Items are addressed through self.videos, which may hold duplicate ids
        # when the TSV repeats a row; self.labels would under-count in that case.
        return len(self.videos)