chore: vendor selva_core from jnwnlee/selva@d7d40a9

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 15:18:09 +02:00
parent 762b19fd3a
commit 6bc3fd6443
106 changed files with 11323 additions and 0 deletions
View File
+190
View File
@@ -0,0 +1,190 @@
import logging
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional
import av
import numpy as np
import torch
from av import AudioFrame
log = logging.getLogger()
@dataclass
class VideoInfo:
    """Container describing one video clip and its pre-processed frames.

    ``height`` and ``width`` are read from the first entry of ``all_frames``,
    so they are only usable when ``all_frames`` is a non-empty list.
    """
    duration_sec: float
    fps: Fraction
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    # May be None; the height/width properties index into it unconditionally.
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        """Size of axis 0 of the first frame in ``all_frames``."""
        first_frame = self.all_frames[0]
        return first_frame.shape[0]

    @property
    def width(self):
        """Size of axis 1 of the first frame in ``all_frames``."""
        first_frame = self.all_frames[0]
        return first_frame.shape[1]

    @classmethod
    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
                        fps: Fraction) -> 'VideoInfo':
        """Build a still-image "video" by repeating one frame.

        The image's ``original_frame`` is repeated ``int(duration_sec * fps)``
        times (the same object is referenced each time, it is not copied);
        ``clip_frames`` and ``sync_frames`` are taken over unchanged.
        """
        frame_count = int(duration_sec * fps)
        still_frame = image_info.original_frame
        repeated_frames = [still_frame for _ in range(frame_count)]
        return cls(
            duration_sec=duration_sec,
            fps=fps,
            clip_frames=image_info.clip_frames,
            sync_frames=image_info.sync_frames,
            all_frames=repeated_frames,
        )
@dataclass
class ImageInfo:
    """Container for a single image and its pre-processed frame tensors.

    ``height`` and ``width`` are read from ``original_frame`` and therefore
    require it to be set (it is declared Optional).
    """
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    original_frame: Optional[np.ndarray]

    @property
    def height(self):
        """Size of axis 0 of ``original_frame``."""
        rows = self.original_frame.shape[0]
        return rows

    @property
    def width(self):
        """Size of axis 1 of ``original_frame``."""
        cols = self.original_frame.shape[1]
        return cols
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode a video once and resample its frames to several frame rates.

    Only frames whose timestamp lies in ``[start_sec, end_sec]`` are used.
    For every requested rate a sampling clock advances in steps of
    ``1 / fps``; each decoded frame is appended once for every clock tick it
    has reached, so a source slower than the target rate gets frames repeated.

    Args:
        video_path: File opened with ``av.open``.
        list_of_fps: Target frame rates; one output array is produced per entry.
        start_sec: Frames with an earlier timestamp are skipped.
        end_sec: Decoding stops at the first frame with a later timestamp.
        need_all_frames: When true, every in-range frame is also collected at
            the native rate.

    Returns:
        ``(output_frames, all_frames, fps)`` where ``output_frames[i]`` is the
        ``np.stack`` of the frames sampled for ``list_of_fps[i]``,
        ``all_frames`` is the native-rate list (empty unless
        ``need_all_frames``) and ``fps`` is the stream's ``guessed_rate``.

    Raises:
        ValueError: From ``np.stack`` when no frame fell inside the window.
    """
    output_frames = [[] for _ in list_of_fps]
    # Start the sampling clocks at the beginning of the requested window.
    # Starting them at 0.0 regardless of start_sec made the first kept frame
    # get appended once per tick between 0 and start_sec.
    first_tick = max(float(start_sec), 0.0)
    next_frame_time_for_each_fps = [first_tick for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        past_end = False
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    past_end = True
                    break

                # Convert lazily: only when somebody actually needs the pixels.
                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)

                for i, _ in enumerate(list_of_fps):
                    while frame_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            frame_np = frame.to_ndarray(format='rgb24')
                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
            if past_end:
                # Frames are assumed to arrive in presentation order, so
                # nothing after this point can be inside the window; stop
                # instead of decoding the rest of the file.
                break

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps
def normalize_video_chunk(video_chunk: torch.Tensor,
                          expected_length: int,
                          *,
                          n_tolerance_frame: int = 1,
                          desc: str = "") \
        -> torch.Tensor:
    """Force a video chunk to have exactly ``expected_length`` frames.

    Chunks that are too long are truncated along the first axis.  Chunks that
    are short by at most ``n_tolerance_frame`` frames are padded by repeating
    the last frame (a warning is logged); anything shorter is rejected.

    Args:
        video_chunk: Frames stacked along axis 0 (``[T, H, W, C]`` per the
            original comment; the padding repeats along four axes).
        expected_length: Required number of frames.
        n_tolerance_frame: Maximum number of missing frames that is padded.
        desc: Text inserted into log and error messages.

    Raises:
        RuntimeError: If too many frames are missing, or the final length
            still differs from ``expected_length``.
    """
    n_missing = expected_length - video_chunk.shape[0]
    if n_missing > 0:
        if n_missing > n_tolerance_frame:
            raise RuntimeError(
                f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
            )
        log.warning(f'Video too short {desc}, padding {n_missing} frames with the last frame')
        # duplicate the final frame as many times as needed
        filler = video_chunk[-1:].repeat(n_missing, 1, 1, 1)
        video_chunk = torch.cat([video_chunk, filler])
        assert video_chunk.shape[0] == expected_length

    trimmed = video_chunk[:expected_length]
    if trimmed.shape[0] != expected_length:
        raise RuntimeError(f'Video wrong length {desc}, '
                           f'expected {expected_length}, '
                           f'got {trimmed.shape[0]}')
    return trimmed
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Encode ``video_info.all_frames`` plus an audio track into a new file.

    The video is written as h264 / yuv420p at ``video_info.fps`` with a
    10 Mbps target bit rate; the audio is written as mono AAC.

    Args:
        video_info: Supplies the frames, frame rate, width and height.
        output_path: Destination file, opened for writing with ``av.open``.
        audio: Float waveform handed to ``AudioFrame.from_ndarray`` with
            ``format='flt'`` and ``layout='mono'`` (presumably shaped
            ``(1, num_samples)`` -- TODO confirm against callers).
        sampling_rate: Sample rate of ``audio`` and of the output AAC stream.
    """
    # The context manager closes the container even if encoding fails, so a
    # failed export does not leak the file handle.
    with av.open(output_path, 'w') as container:
        output_video_stream = container.add_stream('h264', video_info.fps)
        output_video_stream.codec_context.bit_rate = 10 * 1e6  # 10 Mbps
        output_video_stream.width = video_info.width
        output_video_stream.height = video_info.height
        output_video_stream.pix_fmt = 'yuv420p'

        output_audio_stream = container.add_stream('aac', sampling_rate)

        # encode video
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            container.mux(output_video_stream.encode(video_frame))

        # calling encode() with no frame flushes the encoder
        for packet in output_video_stream.encode():
            container.mux(packet)

        # convert float tensor audio to numpy array; detach()/cpu() make this
        # work for tensors that live on an accelerator or carry gradients
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)

        for packet in output_audio_stream.encode():
            container.mux(packet)
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """Copy the video stream of ``video_path`` and add ``audio`` as mono AAC.

    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference

    Args:
        video_path: Source file whose first video stream is copied packet by
            packet (no re-encoding).
        audio: Float waveform encoded with ``format='flt'``, ``layout='mono'``.
        output_path: Destination file.
        sampling_rate: Sample rate of ``audio`` and of the output AAC stream.
    """
    # Both containers are closed by the with-statement, also on errors.
    with av.open(video_path) as video, av.open(output_path, 'w') as output:
        input_video_stream = video.streams.video[0]
        output_video_stream = output.add_stream(template=input_video_stream)
        output_audio_stream = output.add_stream('aac', sampling_rate)

        for packet in video.demux(input_video_stream):
            # We need to skip the "flushing" packets that `demux` generates.
            if packet.dts is None:
                continue

            # We need to assign the packet to the new stream.
            packet.stream = output_video_stream
            output.mux(packet)

        # convert float tensor audio to numpy array
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            output.mux(packet)

        # calling encode() with no frame flushes the encoder
        for packet in output_audio_stream.encode():
            output.mux(packet)
+227
View File
@@ -0,0 +1,227 @@
import logging
import random
from typing import Optional
import numpy as np
import torch
from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from selva_core.data.vgg_sound import VGGSound
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
from selva_core.data.eval.audiocaps import AudioCapsData
from selva_core.data.mm_dataset import MultiModalDataset
from selva_core.data.mixup import DataMixupCollate
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
def worker_init_fn(worker_id: int):
    """Re-seed NumPy and ``random`` every time a DataLoader worker starts.

    The seed combines torch's initial seed (reduced modulo 2**31), the worker
    id and ``local_rank * 1000``, so it differs per worker and per rank.
    """
    base_seed = torch.initial_seed() % (2**31)
    worker_seed = base_seed + worker_id + local_rank * 1000
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
                    ) -> Dataset:
    """Instantiate the ``VGGSound`` dataset described by ``data_cfg``.

    The sample rate (16 kHz) and clip duration (8 s) are fixed here; paths
    come from ``data_cfg`` and the data dimensions from ``cfg.data_dim``.

    NOTE(review): the keyword set (``mmap_dir``, ``tsv_tsynch_path``,
    ``mmap_tsync_dir``, ``data_dim``) must match the constructor of the
    imported ``VGGSound`` -- confirm against ``selva_core.data.vgg_sound``.
    """
    dataset_kwargs = dict(
        root=data_cfg.root,
        tsv_path=data_cfg.subset_name,
        sample_rate=16_000,
        duration_sec=8.0,
        normalize_audio=normalize_audio,
        mmap_dir=data_cfg.memmap_dir,
        tsv_tsynch_path=data_cfg.tsv_tsynch,
        mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
        data_dim=cfg.data_dim,
    )
    return VGGSound(**dataset_kwargs)
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Placeholder for audio-only dataset loading; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = 'Audio data loading is not implemented yet'
    raise NotImplementedError(message)
def setup_training_datasets(cfg: DictConfig,
                            generator: torch.Generator,
                            ) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Build the training dataset, its distributed sampler and its loader.

    Dataset selection:
      * ``cfg.example_train`` -> ``cfg.data.Example_video``
      * ``cfg.mini_train``    -> ``cfg.data.VGGSound_val``
      * otherwise             -> ``cfg.data.VGGSound``

    Previously the two flags were tested with two independent ``if``
    statements, so with only ``mini_train`` set the trailing ``else`` replaced
    the mini dataset with the full one.  ``example_train`` still wins when
    both flags are set, as before.

    When ``cfg.mixup.domain == 'data'`` batches are collated through
    ``DataMixupCollate`` built from ``cfg.mixup.params``.

    Returns:
        ``(dataset, sampler, loader)``; the loader shuffles and drops the
        last incomplete batch.
    """
    if cfg.example_train:
        data_cfg = cfg.data.Example_video
    elif cfg.mini_train:
        data_cfg = cfg.data.VGGSound_val
    else:
        # load the largest one first
        # you can add more video/audio data upon demand, such as
        # clotho = load_audio_data(cfg, cfg.data.Clotho)
        data_cfg = cfg.data.VGGSound

    video_data = load_video_data(cfg, data_cfg, normalize_audio=True)
    dataset = MultiModalDataset([video_data], [])

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory

    if cfg.mixup.domain == 'data':
        mixup_params = cfg.mixup.params
        collate_fn = DataMixupCollate(generator=generator,
                                      **mixup_params)
    else:
        collate_fn = None

    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=pin_memory,
                                       collate_fn=collate_fn)
    return dataset, sampler, loader
def setup_test_datasets(cfg: DictConfig,
                        generator: torch.Generator,
                        ) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Build the test dataset, its distributed sampler and its loader.

    ``cfg.example_train`` selects ``cfg.data.Example_video``; otherwise
    ``cfg.dataset`` must start with ``'vggsound'`` and
    ``cfg.data.VGGSound_test`` is used.

    The calls used to pass ``split='test'``, a keyword ``load_video_data``
    does not accept, so every call raised ``TypeError``.  The subset is
    already chosen through the config node, so the keyword is dropped.

    Raises:
        NotImplementedError: For any other ``cfg.dataset``.

    Returns:
        ``(dataset, sampler, loader)``; the loader neither shuffles nor
        drops the last batch.
    """
    if cfg.example_train:
        dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
    elif cfg.dataset.startswith('vggsound'):
        dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False)
    else:
        raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')

    batch_size = cfg.batch_size
    # a dedicated validation worker count takes precedence when configured
    num_workers = cfg.get('num_workers_val', cfg.num_workers)
    pin_memory = cfg.pin_memory

    if cfg.mixup.domain == 'data':
        mixup_config = cfg.mixup.params
        collate_fn = DataMixupCollate(generator=generator,
                                      **mixup_config)
    else:
        collate_fn = None

    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=False,
                                       drop_last=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collate_fn)
    return dataset, sampler, loader
def setup_val_datasets(cfg: DictConfig,
                       generator: torch.Generator,
                       ) -> tuple[Dataset, DataLoader, DataLoader]:
    """Build the validation dataset and two loaders over it.

    The dataset is ``cfg.data.Example_video`` when ``cfg.example_train`` is
    set and ``cfg.data.VGGSound_val`` otherwise.  The first loader uses
    ``cfg.batch_size``, the second ``cfg.eval_batch_size``; both keep the
    data order and the last incomplete batch.

    Returns:
        ``(dataset, val_loader, eval_loader)``.
    """
    data_cfg = cfg.data.Example_video if cfg.example_train else cfg.data.VGGSound_val
    dataset = load_video_data(cfg, data_cfg, normalize_audio=False)

    val_batch_size = cfg.batch_size
    val_eval_batch_size = cfg.eval_batch_size
    num_workers = cfg.get('num_workers_val', cfg.num_workers)
    pin_memory = cfg.pin_memory

    collate_fn = None
    if cfg.mixup.domain == 'data':
        collate_fn = DataMixupCollate(generator=generator, **cfg.mixup.params)

    # everything except the batch size is shared by the two loaders
    shared_kwargs = dict(shuffle=False,
                         drop_last=False,
                         pin_memory=pin_memory,
                         collate_fn=collate_fn)
    _, val_loader = construct_loader(dataset, val_batch_size, num_workers, **shared_kwargs)
    _, eval_loader = construct_loader(dataset, val_eval_batch_size, num_workers, **shared_kwargs)
    return dataset, val_loader, eval_loader
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
    """Build an evaluation dataset by name together with its loader.

    Recognised names (prefix match unless noted):
      * ``audiocaps_full`` / ``audiocaps`` -> ``AudioCapsData``
      * ``vggsound`` / ``example_video``   -> ``VGGSoundEval``
      * ``infer_video``                    -> ``InferenceVideoData``; this
        branch also forces ``cfg.batch_size`` to 1
      * ``vgg_monoaudio_intra`` / ``vgg_monoaudio_inter`` (exact match)
                                           -> ``VGGMonoAudioBench``

    The ``vggsound`` branch used the training ``VGGSound`` class, whose
    second parameter is keyword-only, so the positional CSV path raised
    ``TypeError``; it now uses the evaluation class like ``example_video``.

    The loader is built with ``error_avoidance=True`` so samples that failed
    to load (returned as ``None``) are dropped from the batch.

    Raises:
        ValueError: If ``dataset_name`` matches none of the above.
    """
    # 'audiocaps_full' must be tested before its prefix 'audiocaps'
    if dataset_name.startswith('audiocaps_full'):
        dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
                                cfg.eval_data.audiocaps_full.csv_path)
    elif dataset_name.startswith('audiocaps'):
        dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
                                cfg.eval_data.audiocaps.csv_path)
    elif dataset_name.startswith('vggsound'):
        dataset = VGGSoundEval(cfg.eval_data.vggsound.video_path,
                               cfg.eval_data.vggsound.csv_path,
                               duration_sec=cfg.duration_s)
    elif dataset_name.startswith('infer_video'):
        dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
                                     cfg.eval_data.infer_video.jsonl_path,
                                     duration_sec=cfg.duration_s)
        cfg.batch_size = 1
    elif dataset_name.startswith('example_video'):
        dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
                               cfg.eval_data.Example_video.csv_path,
                               duration_sec=cfg.duration_s)
    elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
        dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
                                    cfg.eval_data[dataset_name].csv_path,
                                    duration_sec=cfg.duration_s)
    else:
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    _, loader = construct_loader(dataset,
                                 batch_size,
                                 num_workers,
                                 shuffle=False,
                                 drop_last=False,
                                 pin_memory=pin_memory,
                                 error_avoidance=True)
    return dataset, loader
def error_avoidance_collate(batch):
    """Collate a batch after dropping samples that are ``None``.

    Returns ``None`` when nothing is left, otherwise the result of
    ``default_collate`` on the remaining samples.
    """
    # Filter out None values
    usable = list(filter(lambda item: item is not None, batch))
    if not usable:
        return None
    return default_collate(usable)
def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False,
                     collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
    """Create a ``DistributedSampler`` and a ``DataLoader`` that uses it.

    Args:
        dataset: Dataset to iterate.
        batch_size: Batch size of the loader.
        num_workers: Worker processes; workers are kept alive between epochs
            whenever this is greater than zero.
        shuffle: Forwarded to the sampler.
        drop_last: Forwarded to the loader.
        pin_memory: Forwarded to the loader.
        error_avoidance: When true, ``error_avoidance_collate`` is used and
            ``collate_fn`` is ignored.
        collate_fn: Collate function used when ``error_avoidance`` is false.

    Returns:
        ``(sampler, loader)``.
    """
    if error_avoidance:
        chosen_collate = error_avoidance_collate
    else:
        chosen_collate = collate_fn

    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    keep_workers_alive = num_workers > 0
    loader = DataLoader(dataset,
                        batch_size,
                        sampler=sampler,
                        num_workers=num_workers,
                        worker_init_fn=worker_init_fn,
                        drop_last=drop_last,
                        persistent_workers=keep_workers_alive,
                        pin_memory=pin_memory,
                        collate_fn=chosen_collate)
    return sampler, loader
View File
+39
View File
@@ -0,0 +1,39 @@
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
log = logging.getLogger()
class AudioCapsData(Dataset):
    """Name/caption pairs read from an AudioCaps-style CSV file.

    Every CSV row becomes one item ``{'name': ..., 'caption': ...}``.
    No audio is loaded by this class.
    """

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        """Read the CSV at ``csv_path``; it must have ``name`` and ``caption`` columns.

        Args:
            audio_path: Directory expected to hold the audio files.
            csv_path: CSV file with one row per item.
        """
        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # NOTE(review): the directory listing was never used to filter rows
        # (the set of file stems built from it was discarded), so the count
        # logged below is the number of CSV rows.  The listing itself is kept
        # so a missing audio directory still fails at construction time.
        os.listdir(audio_path)

        self.data = [{
            'name': row['name'],
            'caption': row['caption'],
        } for row in rows]

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name', 'caption'}`` dict stored at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Number of CSV rows."""
        return len(self.data)
+237
View File
@@ -0,0 +1,237 @@
import json
import logging
import os
from pathlib import Path
from typing import Union
import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VideoDataset(Dataset):
    """Base class for the evaluation video datasets.

    Subclasses are expected to fill ``self.captions`` (video id -> caption),
    optionally ``self.negative_captions``, and ``self.videos`` (ordered list
    of video ids); the base constructor leaves all three empty.  A video is
    read from ``video_root / (video_id + '.mp4')``.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Store the settings and build the frame transforms.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            duration_sec: Length of video decoded per sample.
            clip_video_required: Also decode a second stream at ``_CLIP_FPS``
                and return it as ``'clip_video'``.
        """
        self.video_root = Path(video_root)
        self.duration_sec = duration_sec
        self.clip_video_required = clip_video_required

        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        sync_steps = [
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.sync_transform = v2.Compose(sync_steps)

        if self.clip_video_required:
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
            clip_steps = [
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ]
            self.clip_transform = v2.Compose(clip_steps)

        # to be implemented by subclasses
        self.captions = {}
        self.negative_captions = {}
        self.videos = sorted(list(self.captions.keys()))

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and transform the video at position ``idx``.

        Returns a dict with ``'name'``, ``'caption'`` and ``'sync_video'``,
        plus ``'clip_video'`` when requested and ``'negative_caption'`` when
        one is stored for the video.

        Raises:
            RuntimeError: If a stream yields no chunk, or a chunk cannot be
                brought to the expected number of frames.
        """
        video_id = self.videos[idx]
        caption = self.captions[video_id]
        negative_caption = self.negative_captions.get(video_id, None)
        want_clip = self.clip_video_required

        decoder = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
        # Output stream 0: frames at the sync rate.
        decoder.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if want_clip:
            # Output stream 1: frames at the CLIP rate.
            decoder.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )

        decoder.fill_buffer()
        chunks = decoder.pop_chunks()

        sync_chunk = chunks[0]
        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                           n_tolerance_frame=3, desc=video_id)
        sync_chunk = self.sync_transform(sync_chunk)

        if want_clip:
            clip_chunk = chunks[1]
            if clip_chunk is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                               n_tolerance_frame=1, desc=video_id)
            clip_chunk = self.clip_transform(clip_chunk)

        data = {
            'name': video_id,
            'caption': caption,
            'sync_video': sync_chunk,
        }
        if want_clip:
            data['clip_video'] = clip_chunk
        if negative_caption is not None:
            data['negative_caption'] = negative_caption
        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``sample(idx)``; on any exception log it and return ``None``."""
        item = None
        try:
            item = self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
        return item

    def __len__(self):
        """Number of entries in ``self.captions``."""
        return len(self.captions)
class VGGSound(VideoDataset):
    """VGGSound test split for evaluation.

    The CSV has no header and the columns ``id, sec, caption, split``.  Only
    rows whose split is ``'test'`` are used, and only when the file
    ``{id}_{sec:06d}.mp4`` exists in ``video_root``.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Index the test-split videos that are present on disk.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            csv_path: VGGSound CSV file (no header row).
            duration_sec: Length of video decoded per sample.
            clip_video_required: Also decode the CLIP-rate stream.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        # A set keeps the per-row existence test O(1); the former list made
        # the loop below quadratic in the number of files.
        available_files = set(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(available_files)} videos found in {video_root}')
        self.captions = {}

        df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
                                                       'split']).to_dict(orient='records')

        videos_no_found = []
        for row in df:
            if row['split'] != 'test':
                continue
            start_sec = int(row['sec'])
            video_id = str(row['id'])
            # this is how our videos are named
            video_name = f'{video_id}_{start_sec:06d}'
            if video_name + '.mp4' not in available_files:
                videos_no_found.append(video_name)
                continue

            self.captions[video_name] = row['caption']

        if local_rank == 0:
            # the "videos found" line is logged once, above
            log.info(f'{len(self.captions)} useable videos found')
            if videos_no_found:
                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
                log.info(
                    'A small amount is expected, as not all videos are still available on YouTube')

        self.videos = sorted(list(self.captions.keys()))
class InferenceVideoData(VideoDataset):
    """Videos paired with per-video prompt files for inference.

    For every ``<name>.mp4`` in ``video_root`` a file ``<name>.jsonl`` is read
    from ``jsonl_root``.  Each prompt file is parsed as a single JSON
    document that must contain ``'audio_prompt'`` and may contain
    ``'negative_audio_prompt'``.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
        clip_video_required: bool = False,
    ):
        """Index the videos and load their prompts.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            jsonl_root: Directory containing one prompt file per video.
            duration_sec: Length of video decoded per sample.
            clip_video_required: Also decode the CLIP-rate stream.

        Raises:
            FileNotFoundError: If a video has no prompt file.
            KeyError: If a prompt file lacks ``'audio_prompt'``.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.jsonl_root = Path(jsonl_root)

        # Samples are always opened as '<id>.mp4', so only .mp4 files can be
        # used.  Slicing off the last four characters mangled names with a
        # longer extension and made any stray file in the directory abort
        # construction with a missing prompt file.
        videos = [
            Path(file_name).stem
            for file_name in sorted(os.listdir(self.video_root))
            if Path(file_name).suffix.lower() == '.mp4'
        ]

        self.captions = {}
        for v in videos:
            with open(self.jsonl_root / (v + '.jsonl'), encoding='utf-8') as f:
                data = json.load(f)
            self.captions[v] = data['audio_prompt']
            self.negative_captions[v] = data.get('negative_audio_prompt', None)

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.videos = videos
class VGGMonoAudioBench(VideoDataset):
    """Benchmark videos with a label and a paired (negative) label each.

    The CSV must have a header with the columns ``file_name``, ``label`` and
    ``paired_label``.  A row is used when ``<stem of file_name>.mp4`` exists
    in ``video_root``; ``label`` becomes the caption and ``paired_label`` the
    negative caption.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Index the benchmark videos that are present on disk.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            csv_path: Benchmark CSV file.
            duration_sec: Length of video decoded per sample.
            clip_video_required: Also decode the CLIP-rate stream.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        # A set keeps the per-row existence test O(1); the former list made
        # the loop below quadratic in the number of files.
        available_files = set(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(available_files)} videos found in {video_root}')
        self.captions = {}
        self.negative_captions = {}

        df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
                         ).to_dict(orient='records')

        videos_no_found = []
        for row in df:
            video_name = str(Path(row['file_name']).stem)
            if video_name + '.mp4' not in available_files:
                videos_no_found.append(video_name)
                continue

            self.captions[video_name] = row['label']
            self.negative_captions[video_name] = row['paired_label']

        if local_rank == 0:
            # the "videos found" line is logged once, above
            log.info(f'{len(self.captions)} useable videos found')
            if videos_no_found:
                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')

        self.videos = sorted(list(self.captions.keys()))
+194
View File
@@ -0,0 +1,194 @@
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
    """VGGSound dataset that decodes video (and optionally audio) on the fly.

    The TSV at ``tsv_path`` lists the subset to use (columns ``id`` and
    ``label``); only ids with a matching file stem in ``root`` are kept.
    A sample is read from ``root / (id + '.mp4')``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
        audio_required: bool = True,
        sample_rate: int = 16_000,
        duration_sec: float = 8.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
        clip_video_required: bool = True,
    ):
        """Index the videos and build the frame transforms.

        Args:
            root: Directory containing the ``.mp4`` files.
            tsv_path: Tab-separated subset file.
            audio_required: Decode and return the audio track.
            sample_rate: Sample rate the audio is resampled to.
            duration_sec: Length of video/audio used per sample.
            audio_samples: Exact number of audio samples to return; defaults
                to ``int(sample_rate * duration_sec)`` and must agree with
                ``duration_sec`` to within 15 ms.
            normalize_audio: Peak-normalise the audio to 0.95.
            clip_video_required: Also decode the CLIP-rate video stream.
        """
        self.root = Path(root)
        self.audio_required = audio_required
        if audio_required:
            self.normalize_audio = normalize_audio
            if audio_samples is None:
                self.audio_samples = int(sample_rate * duration_sec)
            else:
                self.audio_samples = audio_samples
                effective_duration = audio_samples / sample_rate
                # make sure the duration is close enough, within 15ms
                assert abs(effective_duration - duration_sec) < 0.015, \
                    f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
        self.clip_video_required = clip_video_required

        # file stems (extensions removed) available on disk
        videos = {Path(v).stem for v in os.listdir(self.root)}
        self.labels = {}
        self.videos = []
        missing_videos = []

        # read the tsv for subset information
        df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            video_id = record['id']
            label = record['label']
            if video_id in videos:
                self.labels[video_id] = label
                self.videos.append(video_id)
            else:
                missing_videos.append(video_id)

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {root}')
            log.info(f'{len(self.videos)} videos found in {tsv_path}')
            log.info(f'{len(missing_videos)} videos missing in {root}')

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        if audio_required:
            self.expected_audio_length = self.audio_samples
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        if clip_video_required:
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)

        self.sync_transform = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if clip_video_required:
            self.clip_transform = v2.Compose([
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ])

        if audio_required:
            # one Resample module per source sample rate, created on demand
            self.resampler = {}

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and transform the sample at position ``idx``.

        Returns a dict with ``'id'``, ``'caption'`` and ``'sync_video'``, plus
        ``'audio'`` and ``'clip_video'`` when the respective stream is
        required.

        Raises:
            RuntimeError: If a stream yields no chunk, a video chunk cannot
                be brought to the expected length, the audio is silent (when
                normalising) or the audio is too short.
        """
        video_id = self.videos[idx]
        label = self.labels[video_id]

        # Streams are addressed by the order they are added:
        # 0 = sync video, then audio (if required), then CLIP video.
        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if self.audio_required:
            reader.add_basic_audio_stream(frames_per_chunk=2**30, )
        if self.clip_video_required:
            reader.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        sync_chunk = data_chunk[0]
        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                           n_tolerance_frame=3, desc=video_id)
        sync_chunk = self.sync_transform(sync_chunk)

        if self.audio_required:
            audio_chunk = data_chunk[1]
            if audio_chunk is None:
                raise RuntimeError(f'Audio returned None {video_id}')

        if self.clip_video_required:
            clip_chunk = data_chunk[2 if self.audio_required else 1]
            if clip_chunk is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                               n_tolerance_frame=1, desc=video_id)
            clip_chunk = self.clip_transform(clip_chunk)

        # process audio
        if self.audio_required:
            sample_rate = int(reader.get_out_stream_info(1).sample_rate)
            audio_chunk = audio_chunk.transpose(0, 1)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            if self.normalize_audio:
                abs_max = audio_chunk.abs().max()
                # reject silence before dividing by the peak
                if abs_max <= 1e-6:
                    raise RuntimeError(f'Audio is silent {video_id}')
                audio_chunk = audio_chunk * (0.95 / abs_max)

            # resample
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.expected_audio_length:
                raise RuntimeError(f'Audio too short {video_id}')
            audio_chunk = audio_chunk[:self.expected_audio_length]

        data = {
            'id': video_id,
            'caption': label,
            'sync_video': sync_chunk,
        }
        if self.audio_required:
            data['audio'] = audio_chunk
        if self.clip_video_required:
            data['clip_video'] = clip_chunk

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``sample(idx)``; on any exception log it and return ``None``."""
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        """Number of indexable samples.

        Items are looked up through ``self.videos``, so its length is used.
        Counting ``self.labels`` under-reported the size whenever the TSV
        listed an id more than once, leaving trailing entries unreachable.
        """
        return len(self.videos)
+129
View File
@@ -0,0 +1,129 @@
import logging
import os
from pathlib import Path
from typing import Union
import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
log = logging.getLogger()
class WavTextClipsDataset(Dataset):
    """Fixed-length audio clips with captions and tokenised captions.

    ``captions_tsv`` maps an audio ``id`` to its ``caption``.  ``clips_tsv``
    lists the clips; each row is kept as a dict and must provide ``id``,
    ``name`` (the audio file stem, also the key into the captions),
    ``start_sample`` and ``end_sample``.  Audio is read from
    ``root/<name>.flac`` or, failing that, ``root/<name>.wav``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        """Load both TSV files and prepare the tokenizer.

        Args:
            root: Directory containing the audio files.
            captions_tsv: Tab-separated file with ``id`` and ``caption``.
            clips_tsv: Tab-separated file describing the clips.
            sample_rate: Sample rate the audio is resampled to.
            num_samples: Number of samples returned per clip.
            normalize_audio: Peak-normalise each file to 0.95.
            reject_silent: Return ``None`` for files whose peak is below 1e-6.
            tokenizer_id: Name passed to ``open_clip.get_tokenizer``.

        Raises:
            KeyError: If a clip's ``name`` has no caption.
        """
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # NOTE(review): the set of audio stems built here was never used; the
        # listing is kept only so a missing root still fails at construction.
        os.listdir(self.root)

        # read the caption tsv
        caption_rows = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        self.captions = {row['id']: row['caption'] for row in caption_rows}

        # read the clip tsv
        clip_rows = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in clip_rows:
            record['caption'] = self.captions[record['name']]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # one Resample module per source sample rate, created on demand
        self.resampler = {}

    def __getitem__(self, idx: int) -> Union[dict, None]:
        """Load, crop, resample and tokenise the clip at ``idx``.

        Returns:
            A dict with ``'waveform'``, ``'id'``, ``'caption'`` and
            ``'tokens'``, or ``None`` when the clip is rejected as silent or
            any error occurs (the error is logged).
        """
        # Bound before the try block: the handler below reports it, and a
        # failure before the path was computed used to raise
        # UnboundLocalError from inside the handler.
        audio_path = None
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    raise FileNotFoundError(f'No .flac or .wav file for {audio_name}')

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            abs_max = audio_chunk.abs().max()

            if self.reject_silent and abs_max < 1e-6:
                log.warning('Rejecting silent audio')
                return None
            # An all-zero signal stays as it is: dividing by a zero peak
            # would fill the waveform with NaNs.
            if self.normalize_audio and abs_max > 0:
                audio_chunk = audio_chunk / abs_max * 0.95

            audio_chunk = audio_chunk[start_sample:end_sample]

            # resample
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            output = {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }

            return output
        except Exception as e:
            source = audio_path if audio_path is not None else f'clip #{idx}'
            log.error(f'Error reading {source}: {e}')
            return None

    def __len__(self):
        """Number of clips listed in the clip TSV."""
        return len(self.clips)
+338
View File
@@ -0,0 +1,338 @@
""" Embedding Mixup
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
"""
from typing import Literal, Tuple, Union, List, Optional
from functools import partial
import gc
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import v2
from einops import rearrange
from omegaconf import DictConfig
from selva_core.data.vgg_sound import _SYNC_SIZE
class MixupBase:
    """ Base class for mixup on either data or feature domain.
    Applies different params to each element or whole batch.

    Exactly one of the two lambda sources must be active: a constant
    ``mixup_lambda < 1`` (with ``mixup_alpha == 0``) or a sampled lambda
    (``mixup_alpha > 0`` with ``mixup_lambda == 1``).  Any other combination,
    including the default values ``mixup_lambda=0.5, mixup_alpha=1.``, is
    rejected with ``ValueError``.

    Args:
        generator (torch.Generator): Random number generator; required (its
            ``device`` is read).  For a CUDA generator, a CPU generator
            seeded with the same initial seed is used for sampling and the
            original is kept as ``generator_cuda``.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda in [0., 1.]; used if < 1.
        mixup_alpha (float): Beta(alpha, alpha) parameter; used if > 0.
        prob (float): Probability of applying mixup per element
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
        eps (float): Lower bound applied to lambda to avoid zero lambda

    Raises:
        ValueError: If the lambda/alpha combination is invalid.
    """
    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both'],
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        self.modality = modality
        self.mixup_lambda:float = mixup_lambda
        self.mixup_alpha:float = mixup_alpha
        self.mix_prob:float = prob
        self.mode:str = mode
        self.eps:float = eps
        self.mixup_enabled:bool = True  # set to false to disable mixing (intended to be set by train loop)

        if generator.device.type == 'cuda':
            # sampling below runs on the CPU, so mirror the seed on a CPU generator
            self.generator_cuda = generator
            generator_seed = generator.initial_seed()
            self.generator = torch.Generator(device='cpu')
            self.generator.manual_seed(generator_seed)
        else:
            self.generator = generator

        if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
            raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
        if not self.mixup_alpha >= 0.:
            raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
        # both sources active, or neither: ambiguous / nothing to do
        if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
            raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

    def _params_per_elem(self, batch_size:int) -> np.ndarray:
        """Return one lambda per element as a float32 array of length ``batch_size``.

        Elements that are not selected for mixing (or all of them when
        ``mixup_enabled`` is false) get lambda 1.
        """
        lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
        if self.mixup_enabled:
            if self.mixup_lambda < 1.:  # constant lambda
                lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
            elif self.mixup_alpha > 0.:  # sampled lambda
                # NOTE: Beta sampling is not passed self.generator (it uses
                # torch's default RNG); only the selection draw below is
                # controlled by the generator.
                lam_mix = torch.distributions.Beta(
                    torch.tensor([self.mixup_alpha]),
                    torch.tensor([self.mixup_alpha]),
                ).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
            else:
                # Explicit raise instead of `assert False`: asserts vanish
                # under `python -O`, which would leave lam_mix unbound.
                raise AssertionError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
            lam_mix[lam_mix < self.eps] = self.eps

            # Use torch's random with generator for the random comparison
            rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
            lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
        return lam

    def _params_per_batch(self) -> float:
        """Return a single lambda for the whole batch (1. when mixing is disabled)."""
        lam:float = 1.
        if self.mixup_enabled:
            if self.mixup_lambda < 1.:  # constant lambda
                lam = self.mixup_lambda
            elif self.mixup_alpha > 0.:  # sampled lambda
                # NOTE: not driven by self.generator, see _params_per_elem
                lam = torch.distributions.Beta(
                    torch.tensor([self.mixup_alpha]),
                    torch.tensor([self.mixup_alpha]),
                ).sample().item()
            else:
                raise AssertionError(f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
            if lam < self.eps:
                lam = self.eps
            lam = float(lam)
        return lam
class DataMixupCollate(MixupBase):
    """Collate function that mixes samples in the data domain before batching.

    Video: the frames of a sample and of its mirror-position partner
    (``batch[i]`` with ``batch[len(batch) - 1 - i]``) are resized and placed
    side by side. Audio: the two waveforms are linearly blended. Mixed data is
    stored under new keys next to the untouched originals, and the list of
    samples is then passed to ``default_collate``.

    Only ``mode='batch'`` is supported: one lambda is drawn per call.

    Args:
        generator (torch.Generator): Random number generator, forwarded to the base class.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixing weight, forwarded to the base class.
        mixup_alpha (float): Beta concentration for sampled weights, forwarded to the base class.
        prob (float): Forwarded to the base class. This class reads ``self.mix_prob``
            as the fraction of the batch (its leading samples) that receives
            mixed data -- presumably the base class stores ``prob`` there; confirm.
        mode (Literal['elem','pair','batch', 'half']): Must be ``'batch'``.
        eps (float): Small epsilon value to avoid zero lambda.

    Raises:
        ValueError: If ``mode`` is not ``'batch'``.

    NOTE(review): an initializer visible earlier in this file rejects
    ``mixup_alpha > 0`` combined with ``mixup_lambda < 1``. If that is
    ``MixupBase.__init__``, the defaults here (0.5 / 1.0) cannot be used
    together and one of them must be overridden -- confirm.
    """

    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        # Keys read from / written to every sample dict.
        self.source_video_key = 'sync_video'
        self.source_audio_key = 'audio'
        self.target_video_key = 'sync_video_mixed'
        self.target_audio_key = 'audio_mixed'
        if mode != 'batch':
            raise ValueError(f"Mode {mode} is not supported for data domain.")
        # Applied to the concatenated frames only; samples that are passed
        # through unmixed are copied from the source key as they are.
        self.sync_transform = v2.Compose([
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
        """Write side-by-side mixed videos into ``sample[target_key]``.

        Args:
            batch: List of sample dicts; each must hold a video tensor under
                ``source_key`` (assumed ``(T, C, H, W)`` per sample -- confirm).
            target_key: Key that receives the mixed (or passed-through) video.
            source_key: Key of the original video.

        Returns:
            The lambda drawn for this batch (1.0 means nothing was mixed).
        """
        # Only batch mode is supported: one lambda for the whole batch.
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()
        if lam == 1.:
            # No mixup: the target is simply the source.
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        # Share of the _SYNC_SIZE-wide canvas given to the original video.
        # Clamped so that neither strip collapses to zero pixels, which the
        # resize would reject.
        orig_size = min(max(int(lam * _SYNC_SIZE), 1), _SYNC_SIZE - 1)
        # Only the horizontal (side-by-side) layout is enabled; the vertical
        # branch is kept so a random choice between the two can be restored.
        is_horizontal = True
        if is_horizontal:
            resize_shape_orig = (_SYNC_SIZE, orig_size)
            resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE - orig_size)
        else:
            resize_shape_orig = (orig_size, _SYNC_SIZE)
            resize_shape_pair = (_SYNC_SIZE - orig_size, _SYNC_SIZE)
        sync_resize_orig = v2.Compose([
            v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
        ])
        sync_resize_pair = v2.Compose([
            v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
        ])

        # Sample i is paired with the sample at the mirrored batch position.
        batch_videos_orig = torch.stack([sample[source_key] for sample in batch], dim=0)
        batch_videos_pair = torch.stack([sample[source_key] for sample in reversed(batch)], dim=0)
        # Resize both, glue them along width (or height), then transform.
        batch_videos_orig = sync_resize_orig(batch_videos_orig)
        batch_videos_pair = sync_resize_pair(batch_videos_pair)
        batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair),
                                        dim=-1 if is_horizontal else -2)
        batch_videos_concat = self.sync_transform(batch_videos_concat)

        # Only the leading ``mix_prob`` fraction of the batch gets mixed data.
        num_mixup = int(self.mix_prob * batch_size)
        for i, sample in enumerate(batch):
            sample[target_key] = batch_videos_concat[i] if i < num_mixup else sample[source_key]

        del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
        gc.collect()
        return lam

    def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
                           normalize:bool = True) -> float:
        """Write linearly blended waveforms into ``sample[target_key]``.

        Args:
            batch: List of sample dicts holding a waveform tensor under ``source_key``
                (assumed to be already normalized).
            target_key: Key that receives the mixed (or passed-through) audio.
            source_key: Key of the original audio.
            normalize: Rescale each mix so that its peak equals the peak of
                the sample's own source waveform.

        Returns:
            The lambda drawn for this batch (1.0 means nothing was mixed).
        """
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()
        if lam == 1.:
            # No mixup: the target is simply the source.
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        num_mixup = int(self.mix_prob * batch_size)
        for i, sample in enumerate(batch):
            source = sample[source_key]
            if i >= num_mixup:
                sample[target_key] = source  # no mixup
                continue
            partner = batch[batch_size - i - 1][source_key]
            mixed = source * lam + partner * (1 - lam)
            if normalize:
                source_abs_max = source.abs().max()
                target_abs_max = mixed.abs().max()
                # A silent mix (e.g. two waveforms that cancel out) has a zero
                # peak; dividing by it would turn the whole sample into NaN.
                if target_abs_max > 0:
                    mixed = mixed * (source_abs_max / target_abs_max)
            sample[target_key] = mixed
        return lam

    def __call__(self, batch:list, _=None) -> torch.tensor:
        """Mix the requested modalities in place, then collate.

        Args:
            batch: List of sample dicts; its length must be even.
            _: Ignored.

        Returns:
            The result of ``default_collate(batch)``.
        """
        batch_size:int = len(batch)
        assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
        # With modality 'both', video and audio each draw their own lambda.
        if self.modality in ('video', 'both'):
            self._concat_video_frames(batch, target_key=self.target_video_key,
                                      source_key=self.source_video_key)
        if self.modality in ('audio', 'both'):
            self._mix_audio_samples(batch, target_key=self.target_audio_key,
                                    source_key=self.source_audio_key)
        return default_collate(batch)
class FeatureMixup(MixupBase):
    """Mixup in feature domain, applied in place to an already collated batch.

    Row ``i`` of a feature tensor is blended with the row at the mirrored
    batch position ``batch_size - 1 - i``. Results are written to the target
    keys; the source keys are left untouched.

    Args:
        generator (torch.Generator): Random number generator, forwarded to the base class.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixing weight, forwarded to the base class.
        mixup_alpha (float): Beta concentration for sampled weights, forwarded to the base class.
        prob (float): Probability of applying mixup per batch or element.
        mode (Literal['elem','pair','batch', 'half']): How lambdas are assigned:
            'elem' draws one per row, 'half' draws one per row of the first half
            only (second half stays unmixed), 'pair' draws one per mirrored pair
            and mixes both members, anything else uses one lambda for the batch.
        eps (float): Small epsilon value to avoid zero lambda.

    NOTE(review): an initializer visible earlier in this file rejects
    ``mixup_alpha > 0`` combined with ``mixup_lambda < 1``. If that is
    ``MixupBase.__init__``, the defaults here (0.5 / 1.0) cannot be used
    together and one of them must be overridden -- confirm.
    """

    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        # Keys read from / written to the collated batch dict.
        self.source_video_key = 'sync_f_vid_orig'
        self.source_audio_key = 'sync_f_aud_orig'
        self.target_video_key = 'sync_f_vid_mixed'
        self.target_audio_key = 'sync_f_aud_mixed'

    @staticmethod
    def _blend_rows(source:torch.Tensor, rows:torch.Tensor, partner_rows:torch.Tensor,
                    lambdas:torch.Tensor) -> torch.Tensor:
        """Return ``lam * source[rows] + (1 - lam) * source[partner_rows]``, one lam per row."""
        # Append singleton axes so each lambda scales its whole row: (k, 1)
        # for 2-D features, (k, 1, 1) for 3-D features, and so on.
        weights = lambdas.reshape(-1, *([1] * (source.dim() - 1)))
        return source[rows] * weights + source[partner_rows] * (1 - weights)

    def _mix_elem_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
                          half:bool=False) -> torch.tensor:
        """Mix rows with an independent lambda per row.

        Args:
            batch: Collated batch; ``batch[target_key]`` must already exist
                because it is written in place.
            target_keys: Keys that receive the mixed features.
            source_keys: Keys of the original features (same length as ``target_keys``).
            half: Only draw lambdas for the first half of the batch; the second
                half is copied through unmixed.

        Returns:
            Tensor of shape ``(batch_size, 1)`` with the lambda used for each row
            (1 for unmixed rows).
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        num_elem:int = batch_size // 2 if half else batch_size
        lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_elem)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]

        # Every row that is not mixed must carry the source features. This is
        # a boolean mask on purpose: applying ``~`` to an integer index tensor
        # is a bitwise NOT (-i - 1), i.e. it addresses rows from the end of the
        # batch instead of selecting the complement.
        passthrough = torch.ones(batch_size, dtype=torch.bool)
        passthrough[active_indices] = False

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            target = batch[target_key]
            mixed = self._blend_rows(source, active_indices, active_mix_indices, active_lambdas)
            target[active_indices] = mixed
            target[passthrough] = source[passthrough]

        if half:
            # The second half is never mixed, so its lambda is 1.
            lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
        return lam_batch.unsqueeze(1)

    def _mix_pair_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
        """Mix mirrored pairs: both members of a pair share one lambda.

        Args:
            batch: Collated batch; ``batch[target_key]`` must already exist
                because it is written in place.
            target_keys: Keys that receive the mixed features.
            source_keys: Keys of the original features (same length as ``target_keys``).

        Returns:
            Tensor of shape ``(batch_size, 1)`` with the lambda used for each row
            (1 for unmixed rows).
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(batch_size // 2)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]

        # Rows outside every active pair are copied through (boolean mask, see
        # _mix_elem_collate for why integer ``~`` must not be used here).
        passthrough = torch.ones(batch_size, dtype=torch.bool)
        passthrough[active_indices] = False
        passthrough[active_mix_indices] = False

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            target = batch[target_key]
            # Compute both directions from the source before writing anything.
            mixed_first = self._blend_rows(source, active_indices, active_mix_indices, active_lambdas)
            mixed_second = self._blend_rows(source, active_mix_indices, active_indices, active_lambdas)
            target[active_indices] = mixed_first
            target[active_mix_indices] = mixed_second
            target[passthrough] = source[passthrough]

        # Row batch_size - 1 - i uses the same lambda as row i.
        lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
        return lam_batch.unsqueeze(1)

    def _mix_batch_collate(self, batch:dict,
                           target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
        """Mix with a single lambda for the whole batch.

        The leading ``int(mix_prob * batch_size)`` rows are blended with their
        mirrored partner; the remaining rows are copied through. Unlike the
        other modes, ``batch[target_key]`` is replaced by a new tensor.

        Returns:
            The lambda drawn for this batch.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        lam:float = self._params_per_batch()
        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            num_mixup = int(self.mix_prob * source.shape[0])
            mixed = source * lam + torch.flip(source, dims=[0]) * (1 - lam)
            mixed[num_mixup:] = source[num_mixup:]  # no mixup
            batch[target_key] = mixed
        return lam

    def __call__(self, batch:dict, _=None) -> None:
        """Apply mixup to ``batch`` in place according to ``self.mode`` and ``self.modality``.

        Args:
            batch: Collated batch dict; ``batch['id']`` defines the batch size,
                which must be even.
            _: Ignored.

        Raises:
            ValueError: If ``self.modality`` is not 'video', 'audio' or 'both'.
        """
        batch_size:int = len(batch['id'])
        assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'

        if self.mode == 'elem' or self.mode == 'half':
            collate_fn = partial(self._mix_elem_collate, half=self.mode == 'half')
        elif self.mode == 'pair':
            collate_fn = self._mix_pair_collate
        else:
            collate_fn = self._mix_batch_collate

        if self.modality == 'both':
            target_keys = [self.target_video_key, self.target_audio_key]
            source_keys = [self.source_video_key, self.source_audio_key]
        elif self.modality == 'video':
            target_keys, source_keys = [self.target_video_key], [self.source_video_key]
        elif self.modality == 'audio':
            target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
        else:
            raise ValueError(f"Unknown modality {self.modality}.")

        # The batch is modified in place; the lambdas returned by the collate
        # function are not propagated to the caller.
        collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
+45
View File
@@ -0,0 +1,45 @@
import bisect
import torch
from torch.utils.data.dataset import Dataset
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
class MultiModalDataset(Dataset):
    """Concatenation of video datasets followed by audio datasets.

    Indexing walks the video datasets first, then the audio datasets, in the
    order they were given.
    """
    datasets: list[Dataset]
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(item)`` for each item of ``sequence``."""
        totals = []
        running = 0
        for item in sequence:
            running += len(item)
            totals.append(running)
        return totals

    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
        super().__init__()
        self.video_datasets = list(video_datasets)
        self.audio_datasets = list(audio_datasets)
        self.datasets = self.video_datasets + self.audio_datasets
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Total number of samples across all datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global position ``idx`` (negative indices count from the end).

        Raises:
            ValueError: If a negative ``idx`` reaches past the start of the dataset.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)
        # First dataset whose cumulative size exceeds idx.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[which - 1] if which != 0 else 0
        return self.datasets[which][idx - offset]

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the first video dataset's ``compute_latent_stats``."""
        first_video = self.video_datasets[0]
        return first_video.compute_latent_stats()
+148
View File
@@ -0,0 +1,148 @@
import logging
import os
import random
import tempfile
from pathlib import Path
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
from tensordict import MemoryMappedTensor
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from selva_core.utils.dist_utils import local_rank, world_size
# Shared-memory filesystem, used for in-memory temporary files.
shm_path = Path('/dev/shm')
# Scratch directory: $SLURM_SCRATCH when that variable is set, /dev/shm otherwise.
scratch_path = Path(os.environ.get('SLURM_SCRATCH', '/dev/shm'))

log = logging.getLogger()
def reseed(seed):
    """Seed Python's ``random`` module and torch's default generator with ``seed``."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
def local_scatter_torch(obj: Optional[Any]):
    """Send ``obj`` from the process with ``local_rank == 0`` to every process.

    Args:
        obj: Object to distribute; only the value passed where
            ``local_rank == 0`` is handed to the scatter call.

    Returns:
        ``obj`` itself when ``world_size == 1``, otherwise the object received
        from ``dist.scatter_object_list`` (source rank 0).
    """
    if world_size == 1:
        # Just one worker. Do nothing.
        return obj

    received = [None]
    # Only the sender supplies an input list (one copy per rank).
    outgoing = [obj] * world_size if local_rank == 0 else None
    dist.scatter_object_list(received, scatter_object_input_list=outgoing, src=0)
    return received[0]
class ShardDataset(Dataset):
    """Dataset whose items are the files of a directory, each loaded with ``torch.load``."""

    def __init__(self, root):
        self.root = root
        # Sorted so that index -> file is deterministic.
        shard_names = os.listdir(root)
        shard_names.sort()
        self.shards = shard_names

    def __len__(self):
        return len(self.shards)

    def __getitem__(self, idx):
        shard_file = os.path.join(self.root, self.shards[idx])
        return torch.load(shard_file, weights_only=True)
def get_tmp_dir(in_memory: bool) -> Path:
    """Return ``shm_path`` for in-memory temporaries, ``scratch_path`` otherwise."""
    if in_memory:
        return shm_path
    return scratch_path
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
                          in_memory: bool) -> MemoryMappedTensor:
    """Load shards on local rank 0 and hand the resulting tensor to the other ranks.

    Local rank 0 loads the shards into a tensor backed by a named temporary
    file created under ``get_tmp_dir(in_memory)`` and passes it to
    ``share_tensor_to_all``; every other rank calls ``share_tensor_to_all(None)``
    to receive it. Both paths then wait on ``torch.distributed.barrier()``.

    Args:
        data_path: Directory containing the shard files.
        ids: Ids to load; forwarded to ``load_shards``.
        in_memory: Selects the directory of the temporary file (see ``get_tmp_dir``).

    Returns:
        The tensor returned by ``share_tensor_to_all`` on this rank.
    """
    # NOTE(review): barrier() is called regardless of world_size -- presumably
    # the process group is always initialised when this runs; confirm.
    if local_rank == 0:
        with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
            log.info(f'Loading shards from {data_path} into {f.name}...')
            data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
            data = share_tensor_to_all(data)
            # Wait until the other ranks have received the tensor before the
            # temporary file is closed.
            torch.distributed.barrier()
            # Explicit close kept from the original author, who noted that the
            # context manager alone did not appear to close the file.
            f.close()  # why does the context manager not close the file for me?
    else:
        log.info('Waiting for the data to be shared with me...')
        data = share_tensor_to_all(None)
        torch.distributed.barrier()
    return data
def load_shards(
    data_path: Union[str, Path],
    ids: list[int],
    *,
    tmp_file_path: str,
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Gather the tensors of the requested ids from all shards into one memory-mapped tensor.

    Each shard is read as a mapping from id to tensor. Row ``k`` of the result
    holds the tensor whose id is ``ids[k]``.

    Args:
        data_path: Directory containing the shard files.
        ids: Ids to collect; their order defines the row order of the result.
        tmp_file_path: File that backs the memory-mapped tensor.

    Returns:
        A float32 ``MemoryMappedTensor`` of shape ``(len(ids), *item_shape)``,
        where ``item_shape`` is taken from the first item of the first shard.

    Raises:
        ValueError: If a requested id occurs more than once in the shards.
        AssertionError: If the number of loaded tensors differs from ``len(ids)``.
    """
    wanted = set(ids)
    shards = sorted(os.listdir(data_path))
    log.info(f'Found {len(shards)} shards in {data_path}.')

    # Peek at one item to learn the per-item shape.
    probe = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
    log.info(f'Rank {local_rank} created file {tmp_file_path}')
    first_item = next(iter(probe.values()))
    log.info(f'First item shape: {first_item.shape}')

    mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
                                         dtype=torch.float32,
                                         filename=tmp_file_path,
                                         existsok=True)

    # id -> destination row; a dict lookup instead of ids.index(i) per item.
    row_of_id = {item_id: row for row, item_id in enumerate(ids)}
    filled_rows = set()
    total_count = 0

    # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
    loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
    for shard in tqdm(loader, desc='Loading shards'):
        for i, value in shard.items():
            if i not in wanted:
                continue
            row = row_of_id[i]
            if row in filled_rows:
                raise ValueError(f'Duplicate id {i} found in {data_path}.')
            filled_rows.add(row)
            mm_tensor[row] = value
            total_count += 1

    assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
    log.info(f'Loaded {total_count} tensors from {data_path}.')
    return mm_tensor
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
    """Make the memory-mapped tensor held by local rank 0 available on every rank.

    Local rank 0 distributes the tensor's file name, shape and dtype through
    ``local_scatter_torch``; the other ranks open the same file with
    ``MemoryMappedTensor.from_filename``.

    Args:
        x: The tensor to share on local rank 0; must be None on every other rank.

    Returns:
        ``x`` itself on local rank 0 (or when ``world_size == 1``), otherwise
        a ``MemoryMappedTensor`` opened from the shared file.
    """
    # there is no need to share your stuff with anyone if you are alone; must be in memory
    if world_size == 1:
        return x

    is_sender = local_rank == 0
    if is_sender:
        assert x is not None, 'x must not be None if local_rank == 0'
        meta_information = (x.filename, x.shape, x.dtype)
    else:
        assert x is None, 'x must be None if local_rank != 0'
        meta_information = None

    filename, data_shape, data_type = local_scatter_torch(meta_information)
    if is_sender:
        return x
    return MemoryMappedTensor.from_filename(filename=filename,
                                            dtype=data_type,
                                            shape=data_shape)
+299
View File
@@ -0,0 +1,299 @@
import logging
import os
from pathlib import Path
from typing import Optional, Union
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from tensordict import TensorDict
from selva_core.data.av_utils import normalize_video_chunk
from selva_core.utils.dist_utils import local_rank
log = logging.getLogger()  # root logger
# Square frame size (pixels) and frame rate requested for the CLIP video stream.
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
# Square frame size (pixels) and frame rate requested for the sync video stream.
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
class VGGSound(Dataset):
    """VGGSound clips paired with precomputed, memory-mapped features.

    Each item decodes the sync-rate video (and optionally the audio and the
    CLIP-rate video) of ``root/<id>.mp4`` and attaches the matching rows of
    the feature tensors loaded from ``mmap_dir`` and ``mmap_tsync_dir``.

    Args:
        root: Directory containing the ``<id>.mp4`` files.
        tsv_path: Tab-separated file with ``id`` and ``label`` columns. Its row
            order is also used as the row index into the tensors of ``mmap_dir``.
        for_generator: Also load ``mean``, ``std`` and the CLIP text features.
        audio_required: Decode the waveform and return it under ``'audio'``.
        sample_rate: Target audio sample rate.
        duration_sec: Clip duration in seconds.
        audio_samples: Number of audio samples to keep; defaults to
            ``int(sample_rate * duration_sec)`` and must correspond to
            ``duration_sec`` within 15 ms.
        normalize_audio: Peak-normalise the waveform to 0.95.
        clip_video_required: Decode the CLIP-rate video and load ``clip_features``.
        mmap_dir: Directory of the memory-mapped TensorDict (required despite
            the ``None`` default).
        tsv_tsynch_path: Tab-separated file whose row order indexes the tensors
            of ``mmap_tsync_dir`` (required despite the ``None`` default).
        mmap_tsync_dir: Directory of the memory-mapped TensorDict holding
            ``text_features`` and ``text_masks`` (required).
        data_dim: Expected sequence lengths / feature sizes used to validate
            the loaded tensors (required).
    """

    # Optional caption override read by ``sample``: when a video id is present
    # in ``captions``, its caption replaces the class label with probability
    # ``autoacd_sample_prob``. ``__init__`` never sets these, so the defaults
    # below keep ``sample`` working; assign them on the instance to enable.
    captions: Optional[dict[str, str]] = None
    autoacd_sample_prob: float = 0.0

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
        for_generator: bool = True,
        audio_required: bool = False,
        sample_rate: int = 16_000,
        duration_sec: float = 8.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
        clip_video_required: bool = False,
        mmap_dir: Union[str, Path] = None,
        tsv_tsynch_path: Union[str, Path] = None,
        mmap_tsync_dir: Union[str, Path] = None,
        data_dim: dict[str, int] = None,
    ):
        self.root = Path(root)
        self.audio_required = audio_required
        if audio_required:
            self.normalize_audio = normalize_audio
            if audio_samples is None:
                self.audio_samples = int(sample_rate * duration_sec)
            else:
                self.audio_samples = audio_samples
                effective_duration = audio_samples / sample_rate
                # make sure the duration is close enough, within 15ms
                assert abs(effective_duration - duration_sec) < 0.015, \
                    f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
        self.clip_video_required = clip_video_required
        self.for_generator = for_generator

        # File names without extension, for membership tests.
        videos = {Path(v).stem for v in os.listdir(self.root)}
        self.labels = {}
        self.videos = []
        missing_videos = []
        # read the tsv for subset information
        df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            video_id = record['id']
            if video_id in videos:
                self.labels[video_id] = record['label']
                self.videos.append(video_id)
            else:
                missing_videos.append(video_id)
        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {root}')
            log.info(f'{len(self.videos)} videos found in {tsv_path}')
            log.info(f'{len(missing_videos)} videos missing in {root}')

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec
        if audio_required:
            self.expected_audio_length = self.audio_samples
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        if clip_video_required:
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)

        self.sync_transform = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            # v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        if clip_video_required:
            self.clip_transform = v2.Compose([
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ])
        if audio_required:
            # Resamplers are created lazily, one per source sample rate.
            self.resampler = {}

        # mmap
        log.info(f'Loading precomputed mmap from {mmap_dir}')
        mmap_dir = Path(mmap_dir)
        td = TensorDict.load_memmap(mmap_dir)
        log.info(f'Loaded precomputed mmap from {mmap_dir}')
        self.sync_features = td['sync_features']
        if for_generator:
            self.mean = td['mean']
            self.std = td['std']
            self.text_clip_features = td['text_features']
        if clip_video_required:
            self.clip_features = td['clip_features']
        else:
            self.clip_features = None
        # Rows follow the TSV order, including ids whose video file is missing.
        self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}

        mmap_tsync_dir = Path(mmap_tsync_dir)
        td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
        log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
        self.text_features = td_tsync['text_features']
        self.text_masks = td_tsync['text_masks']
        # Read ids as strings, exactly like the main TSV: otherwise
        # numeric-looking ids become ints and the string lookups in
        # ``sample`` raise KeyError.
        df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t',
                                    dtype={'id': str}).to_dict('records')
        self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}

        if local_rank == 0:
            log.info(f'Loaded {len(self)} samples.')
            log.info(f'Loaded sync_features: {self.sync_features.shape}.')
            log.info(f'Loaded text_features: {self.text_features.shape}.')
            log.info(f'Loaded text_masks: {self.text_masks.shape}.')
            if for_generator:
                log.info(f'Loaded mean: {self.mean.shape}.')
                log.info(f'Loaded std: {self.std.shape}.')
                log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
            if clip_video_required:
                log.info(f'Loaded clip_features: {self.clip_features.shape}.')

        # Validate the loaded tensors against the expected dimensions.
        assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
            f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
        assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
            f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
        assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
            f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
        assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
            f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
        assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
            f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
        if for_generator:
            assert self.mean.shape[1] == data_dim['latent_seq_len'], \
                f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
            assert self.std.shape[1] == data_dim['latent_seq_len'], \
                f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
            assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
                f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
            assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
                f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
        if clip_video_required:
            assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
                f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
            assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
                f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'

        self.video_exist = torch.tensor(1, dtype=torch.bool)
        self.text_exist = torch.tensor(1, dtype=torch.bool)

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:  # mmap
        """Return mean and std of ``self.mean`` over its first two axes.

        Requires ``for_generator=True`` (``self.mean`` is only loaded then).
        """
        latents = self.mean
        return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))

    def get_memory_mapped_tensor(self) -> 'TensorDict':
        """Bundle the loaded feature tensors into a single TensorDict."""
        td = TensorDict({
            'sync_features': self.sync_features,
            'text_features': self.text_features,
            'text_masks': self.text_masks,
        })
        if self.for_generator:
            td['mean'] = self.mean
            td['std'] = self.std
            td['text_clip_features'] = self.text_clip_features
        if self.clip_video_required:
            td['clip_features'] = self.clip_features
        return td

    def _select_label(self, video_id: str) -> str:
        """Return the caption for ``video_id``: an override from ``captions`` or the class label."""
        captions = self.captions
        if captions and video_id in captions \
                and torch.rand(1).item() < self.autoacd_sample_prob:
            return captions[video_id]
        return self.labels[video_id]

    def _prepare_audio(self, audio_chunk: torch.Tensor, sample_rate: int,
                       video_id: str) -> torch.Tensor:
        """Down-mix, optionally normalise, resample and truncate a decoded waveform.

        Args:
            audio_chunk: Decoded audio; axis 1 is averaged away as the channel axis.
            sample_rate: Sample rate of ``audio_chunk``.
            video_id: Only used in error messages.

        Raises:
            RuntimeError: If the audio is silent (when normalising) or shorter
                than ``self.expected_audio_length`` after resampling.
        """
        audio_chunk = audio_chunk.transpose(0, 1)
        audio_chunk = audio_chunk.mean(dim=0)  # mono
        if self.normalize_audio:
            abs_max = audio_chunk.abs().max()
            # Reject silence before dividing by the peak.
            if abs_max <= 1e-6:
                raise RuntimeError(f'Audio is silent {video_id}')
            audio_chunk = audio_chunk * (0.95 / abs_max)

        # resample
        if sample_rate != self.sample_rate:
            if sample_rate not in self.resampler:
                # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                self.resampler[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate,
                    self.sample_rate,
                    lowpass_filter_width=64,
                    rolloff=0.9475937167399596,
                    resampling_method='sinc_interp_kaiser',
                    beta=14.769656459379492,
                )
            audio_chunk = self.resampler[sample_rate](audio_chunk)

        if audio_chunk.shape[0] < self.expected_audio_length:
            raise RuntimeError(f'Audio too short {video_id}')
        return audio_chunk[:self.expected_audio_length]

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and assemble the sample at position ``idx``.

        Raises:
            RuntimeError: If a required stream could not be decoded, or the
                audio is silent / too short.
        """
        video_id = self.videos[idx]
        label = self._select_label(video_id)

        # Streams are popped in the order they are added:
        # sync video, then audio (if required), then CLIP video (if required).
        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if self.audio_required:
            reader.add_basic_audio_stream(frames_per_chunk=2**30, )
        if self.clip_video_required:
            reader.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )
        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        sync_chunk = data_chunk[0]
        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                           n_tolerance_frame=3, desc=video_id)
        sync_chunk = self.sync_transform(sync_chunk)

        if self.audio_required:
            audio_chunk = data_chunk[1]
        if self.clip_video_required:
            clip_chunk = data_chunk[2 if self.audio_required else 1]
            if clip_chunk is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                               n_tolerance_frame=1, desc=video_id)
            clip_chunk = self.clip_transform(clip_chunk)

        # process audio
        if self.audio_required:
            sample_rate = int(reader.get_out_stream_info(1).sample_rate)
            audio_chunk = self._prepare_audio(audio_chunk, sample_rate, video_id)

        row = self.id2idx_mmap[video_id]
        row_tsync = self.id2idx_mmap_tsync[video_id]
        data = {
            'id': video_id,
            'caption': label,
            'sync_video': sync_chunk,
            'sync_f_vid_orig': self.sync_features[row],
            'text_features': self.text_features[row_tsync],
            'text_masks': self.text_masks[row_tsync],
            'video_exist': self.video_exist,
            'text_exist': self.text_exist,
        }
        if self.for_generator:
            data['a_mean'] = self.mean[row]
            data['a_std'] = self.std[row]
            data['text_clip_features'] = self.text_clip_features[row]
        if self.audio_required:
            data['audio'] = audio_chunk
        if self.clip_video_required:
            data['clip_video'] = clip_chunk
            # No trailing comma here: it would wrap the tensor in a 1-tuple.
            data['clip_features'] = self.clip_features[row]
        return data

    def __getitem__(self, idx: int) -> Optional[dict[str, torch.Tensor]]:
        """Return ``sample(idx)``, or None (after logging) if loading fails."""
        try:
            return self.sample(idx)
        except Exception as e:
            # Best effort: a broken clip must not abort the whole epoch.
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        return len(self.labels)