chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from av import AudioFrame
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
@dataclass
class VideoInfo:
    """Container for a decoded video and the frame tensors derived from it."""
    # Clip length in seconds.
    duration_sec: float
    # Frame rate of the video.
    fps: Fraction
    # Frames prepared for the CLIP branch.
    clip_frames: torch.Tensor
    # Frames prepared for the sync branch.
    sync_frames: torch.Tensor
    # Full-resolution frames; may be None when they were not kept.
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        """Size of axis 0 of the first entry in ``all_frames``."""
        first_frame = self.all_frames[0]
        return first_frame.shape[0]

    @property
    def width(self):
        """Size of axis 1 of the first entry in ``all_frames``."""
        first_frame = self.all_frames[0]
        return first_frame.shape[1]

    @classmethod
    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
                        fps: Fraction) -> 'VideoInfo':
        """Build a still video by repeating one image for ``duration_sec`` seconds.

        The number of frames is ``int(duration_sec * fps)``; every entry of
        ``all_frames`` refers to the same ``image_info.original_frame`` object.
        """
        frame_count = int(duration_sec * fps)
        repeated_frames = [image_info.original_frame for _ in range(frame_count)]
        return cls(
            duration_sec=duration_sec,
            fps=fps,
            clip_frames=image_info.clip_frames,
            sync_frames=image_info.sync_frames,
            all_frames=repeated_frames,
        )
|
||||
|
||||
|
||||
@dataclass
class ImageInfo:
    """Container for a single image and the frame tensors derived from it."""
    # Frames prepared for the CLIP branch.
    clip_frames: torch.Tensor
    # Frames prepared for the sync branch.
    sync_frames: torch.Tensor
    # The source image as an array; may be None when it was not kept.
    original_frame: Optional[np.ndarray]

    @property
    def height(self):
        """Size of axis 0 of ``original_frame``."""
        frame = self.original_frame
        return frame.shape[0]

    @property
    def width(self):
        """Size of axis 1 of ``original_frame``."""
        frame = self.original_frame
        return frame.shape[1]
|
||||
|
||||
|
||||
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode a video once and resample it to several target frame rates.

    Frames whose timestamp lies in ``[start_sec, end_sec]`` are visited in the
    order the decoder yields them. For every target rate a frame is emitted each
    time the frame timestamp reaches that rate's next sampling instant, so a
    source frame may be emitted several times or not at all.

    Args:
        video_path: Video file to open with PyAV.
        list_of_fps: Target frame rates; one output array is produced per entry.
        start_sec: Frames with an earlier timestamp are ignored.
        end_sec: Decoding stops at the first frame with a later timestamp.
        need_all_frames: When true, every in-range frame is also collected.

    Returns:
        ``(output_frames, all_frames, fps)``: one stacked array per target rate,
        the list of all in-range frames (empty unless ``need_all_frames``), and
        the frame rate PyAV guessed for the stream.

    Raises:
        ValueError: From ``np.stack`` if a target rate received no frame at all.
    """
    output_frames = [[] for _ in list_of_fps]
    # NOTE(review): the sampling clocks start at 0.0, not at start_sec, so with
    # start_sec > 0 the first in-range frame is emitted repeatedly until the
    # clock catches up - confirm whether callers ever pass start_sec > 0.
    next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        reached_end = False
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    # Nothing after this point is used; stop demuxing as well
                    # instead of decoding the remainder of the file.
                    reached_end = True
                    break

                # Convert lazily: only frames that are actually kept are converted.
                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)

                for i, time_delta in enumerate(time_delta_for_each_fps):
                    while frame_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            frame_np = frame.to_ndarray(format='rgb24')

                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta
            if reached_end:
                break

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps
|
||||
|
||||
|
||||
def normalize_video_chunk(video_chunk: torch.Tensor,
                          expected_length: int,
                          *,
                          n_tolerance_frame: int = 1,
                          desc: str = "") \
        -> torch.Tensor:
    """Return ``video_chunk`` with exactly ``expected_length`` frames along axis 0.

    A chunk that is too long is truncated. A chunk that is short by at most
    ``n_tolerance_frame`` frames is padded by repeating its last frame (a warning
    is logged); a larger shortfall is an error.

    Args:
        video_chunk: Frames laid out as [T, H, W, C].
        expected_length: Required number of frames.
        n_tolerance_frame: Largest shortfall that is padded rather than rejected.
        desc: Identifier inserted into log and error messages.

    Raises:
        RuntimeError: If the chunk is too short to pad, or the final length is
            still not ``expected_length``.
    """
    shortfall = expected_length - video_chunk.shape[0]
    if shortfall > 0:
        if shortfall > n_tolerance_frame:
            raise RuntimeError(
                f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
            )
        # Within tolerance: duplicate the last frame to reach the right length.
        log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
        filler = video_chunk[-1:].repeat(shortfall, 1, 1, 1)
        video_chunk = torch.cat([video_chunk, filler])
        assert video_chunk.shape[0] == expected_length

    trimmed = video_chunk[:expected_length]
    if trimmed.shape[0] != expected_length:
        raise RuntimeError(f'Video wrong length {desc}, '
                           f'expected {expected_length}, '
                           f'got {trimmed.shape[0]}')

    return trimmed
|
||||
|
||||
|
||||
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Write ``video_info.all_frames`` and ``audio`` to ``output_path``.

    The frames are encoded as H.264 (yuv420p, 10 Mbps) at ``video_info.fps`` and
    the audio as mono AAC at ``sampling_rate``.

    Args:
        video_info: Supplies the frames, frame rate, width and height.
        output_path: Destination file, opened for writing with PyAV.
        audio: Float audio tensor; it is passed to ``AudioFrame.from_ndarray``
            with format 'flt' and layout 'mono' (presumably shaped
            (1, num_samples) - confirm against callers).
        sampling_rate: Sample rate of ``audio`` in Hz.
    """
    # The context manager closes the container even when encoding fails part-way.
    with av.open(output_path, 'w') as container:
        output_video_stream = container.add_stream('h264', video_info.fps)
        output_video_stream.codec_context.bit_rate = 10_000_000  # 10 Mbps
        output_video_stream.width = video_info.width
        output_video_stream.height = video_info.height
        output_video_stream.pix_fmt = 'yuv420p'

        output_audio_stream = container.add_stream('aac', sampling_rate)

        # encode video
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            container.mux(output_video_stream.encode(video_frame))

        # flush the video encoder
        for packet in output_video_stream.encode():
            container.mux(packet)

        # convert float tensor audio to numpy array
        audio_np = audio.numpy().astype(np.float32)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)

        # flush the audio encoder
        for packet in output_audio_stream.encode():
            container.mux(packet)
|
||||
|
||||
|
||||
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """Copy the video stream of ``video_path`` unchanged and attach ``audio`` as AAC.

    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference

    Args:
        video_path: Source video; its first video stream is copied packet by packet.
        audio: Float audio tensor, encoded as a single mono 'flt' frame.
        output_path: Destination file.
        sampling_rate: Sample rate of ``audio`` in Hz.
    """
    # Both containers are closed on exit, including when muxing raises.
    with av.open(video_path) as video, av.open(output_path, 'w') as output:
        input_video_stream = video.streams.video[0]
        output_video_stream = output.add_stream(template=input_video_stream)
        output_audio_stream = output.add_stream('aac', sampling_rate)

        for packet in video.demux(input_video_stream):
            # We need to skip the "flushing" packets that `demux` generates.
            if packet.dts is None:
                continue
            # We need to assign the packet to the new stream.
            packet.stream = output_video_stream
            output.mux(packet)

        # convert float tensor audio to numpy array
        audio_np = audio.numpy().astype(np.float32)
        audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            output.mux(packet)

        # flush the audio encoder
        for packet in output_audio_stream.encode():
            output.mux(packet)
|
||||
@@ -0,0 +1,227 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from selva_core.data.vgg_sound import VGGSound
|
||||
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
|
||||
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
|
||||
from selva_core.data.eval.audiocaps import AudioCapsData
|
||||
from selva_core.data.mm_dataset import MultiModalDataset
|
||||
from selva_core.data.mixup import DataMixupCollate
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
|
||||
# Re-seed randomness every time we start a worker
|
||||
def worker_init_fn(worker_id: int):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    The seed combines torch's initial seed (reduced modulo 2**31), the worker id
    and the local rank, so that workers and ranks get distinct seeds.
    """
    base_seed = torch.initial_seed() % (2**31)
    rank_offset = local_rank * 1000
    worker_seed = base_seed + worker_id + rank_offset
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
||||
|
||||
|
||||
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
                    ) -> Dataset:
    """Build a ``VGGSound`` dataset from one entry of the data configuration.

    Args:
        cfg: Global config; only ``cfg.data_dim`` is read here.
        data_cfg: Per-dataset config providing ``root``, ``subset_name``,
            ``memmap_dir``, ``tsv_tsynch`` and ``memmap_dir_tsynch``.
        normalize_audio: Forwarded to the dataset.

    Returns:
        The constructed dataset, fixed to 16 kHz audio and 8-second clips.
    """
    dataset_kwargs = dict(
        root=data_cfg.root,
        tsv_path=data_cfg.subset_name,
        sample_rate=16_000,
        duration_sec=8.0,
        normalize_audio=normalize_audio,
        mmap_dir=data_cfg.memmap_dir,
        tsv_tsynch_path=data_cfg.tsv_tsynch,
        mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
        data_dim=cfg.data_dim,
    )
    return VGGSound(**dataset_kwargs)
|
||||
|
||||
|
||||
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Placeholder for audio-only datasets; always raises ``NotImplementedError``."""
    message = 'Audio data loading is not implemented yet'
    raise NotImplementedError(message)
|
||||
|
||||
|
||||
def setup_training_datasets(cfg: DictConfig,
                            generator: torch.Generator,
                            ) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Create the training dataset together with its sampler and loader.

    Dataset selection, in priority order:
      * ``cfg.example_train``: the example video set,
      * ``cfg.mini_train``: the VGGSound validation set used as a small train set,
      * otherwise: the full VGGSound training set.

    Previously the ``mini_train`` choice was always overwritten by the
    ``example_train`` if/else that followed it, so ``mini_train`` alone silently
    trained on the full dataset. ``example_train`` still wins when both are set.

    Args:
        cfg: Training config (dataset flags, ``batch_size``, ``num_workers``,
            ``pin_memory``, ``mixup``).
        generator: Random generator handed to the mixup collate function.

    Returns:
        ``(dataset, sampler, loader)``; the loader shuffles and drops the last
        incomplete batch.
    """
    if cfg.example_train:
        video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
        dataset = MultiModalDataset([video], [])
    elif cfg.mini_train:
        vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
        dataset = MultiModalDataset([vgg], [])
    else:
        vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
        # load the largest one first
        # you can add more video/audio data upon demand, such as
        # clotho = load_audio_data(cfg, cfg.data.Clotho)
        dataset = MultiModalDataset([vgg], [])

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory

    # Mixup in the data domain is implemented as a custom collate function.
    if cfg.mixup.domain == 'data':
        mixup_params = cfg.mixup.params
        collate_fn = DataMixupCollate(generator=generator,
                                      **mixup_params)
    else:
        collate_fn = None

    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=pin_memory,
                                       collate_fn=collate_fn)

    return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_test_datasets(cfg: DictConfig,
                        generator: torch.Generator,
                        ) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Create the test dataset together with its sampler and loader.

    Args:
        cfg: Config; ``cfg.example_train`` selects the example videos, otherwise
            ``cfg.dataset`` must start with 'vggsound'.
        generator: Random generator handed to the mixup collate function.

    Returns:
        ``(dataset, sampler, loader)``; the loader neither shuffles nor drops
        the last batch.

    Raises:
        NotImplementedError: If ``cfg.dataset`` names an unsupported dataset.
    """
    # load_video_data() accepts no `split` argument (passing one raised
    # TypeError); the test subset is selected through the data config entry.
    if cfg.example_train:
        dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
    elif cfg.dataset.startswith('vggsound'):
        dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False)
    else:
        raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')

    batch_size = cfg.batch_size
    # Fall back to the training worker count when no validation value is set.
    num_workers = cfg.get('num_workers_val', cfg.num_workers)
    pin_memory = cfg.pin_memory

    if cfg.mixup.domain == 'data':
        mixup_config = cfg.mixup.params
        collate_fn = DataMixupCollate(generator=generator,
                                      **mixup_config)
    else:
        collate_fn = None

    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=False,
                                       drop_last=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collate_fn)

    return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_val_datasets(cfg: DictConfig,
                       generator: torch.Generator,
                       ) -> tuple[Dataset, DataLoader, DataLoader]:
    """Create the validation dataset and two loaders over it.

    Args:
        cfg: Config; ``cfg.example_train`` selects the example videos instead of
            the VGGSound validation set.
        generator: Random generator handed to the mixup collate function.

    Returns:
        ``(dataset, val_loader, eval_loader)``. Both loaders read the same
        dataset without shuffling; they differ only in batch size
        (``cfg.batch_size`` versus ``cfg.eval_batch_size``).
    """
    data_cfg = cfg.data.Example_video if cfg.example_train else cfg.data.VGGSound_val
    dataset = load_video_data(cfg, data_cfg, normalize_audio=False)

    val_batch_size = cfg.batch_size
    val_eval_batch_size = cfg.eval_batch_size
    # Fall back to the training worker count when no validation value is set.
    num_workers = cfg.get('num_workers_val', cfg.num_workers)
    pin_memory = cfg.pin_memory

    collate_fn = None
    if cfg.mixup.domain == 'data':
        collate_fn = DataMixupCollate(generator=generator, **cfg.mixup.params)

    # Everything except the batch size is shared by the two loaders.
    shared_options = dict(shuffle=False,
                          drop_last=False,
                          pin_memory=pin_memory,
                          collate_fn=collate_fn)
    _, val_loader = construct_loader(dataset, val_batch_size, num_workers, **shared_options)
    _, eval_loader = construct_loader(dataset, val_eval_batch_size, num_workers, **shared_options)

    return dataset, val_loader, eval_loader
|
||||
|
||||
|
||||
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
    """Create the evaluation dataset selected by ``dataset_name`` and its loader.

    Names are matched by prefix ('audiocaps_full' before 'audiocaps', then
    'vggsound', 'infer_video', 'example_video') or exactly
    ('vgg_monoaudio_intra', 'vgg_monoaudio_inter'). Selecting 'infer_video'
    also sets ``cfg.batch_size`` to 1.

    Returns:
        ``(dataset, loader)``; the loader keeps order and uses the collate
        function that drops samples which failed to load.

    Raises:
        ValueError: If ``dataset_name`` matches none of the known datasets.
    """
    if dataset_name.startswith('audiocaps_full'):
        source = cfg.eval_data.audiocaps_full
        dataset = AudioCapsData(source.audio_path, source.csv_path)
    elif dataset_name.startswith('audiocaps'):
        source = cfg.eval_data.audiocaps
        dataset = AudioCapsData(source.audio_path, source.csv_path)
    elif dataset_name.startswith('vggsound'):
        # NOTE(review): this uses `VGGSound` from selva_core.data.vgg_sound,
        # whereas 'example_video' below uses the eval class imported as
        # `VGGSoundEval` - confirm this is the intended dataset class.
        source = cfg.eval_data.vggsound
        dataset = VGGSound(source.video_path, source.csv_path, duration_sec=cfg.duration_s)
    elif dataset_name.startswith('infer_video'):
        source = cfg.eval_data.infer_video
        dataset = InferenceVideoData(source.video_path,
                                     source.jsonl_path,
                                     duration_sec=cfg.duration_s)
        cfg.batch_size = 1
    elif dataset_name.startswith('example_video'):
        source = cfg.eval_data.Example_video
        dataset = VGGSoundEval(source.video_path, source.csv_path, duration_sec=cfg.duration_s)
    elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
        source = cfg.eval_data[dataset_name]
        dataset = VGGMonoAudioBench(source.video_path,
                                    source.csv_path,
                                    duration_sec=cfg.duration_s)
    else:
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    _, loader = construct_loader(dataset,
                                 cfg.batch_size,
                                 cfg.num_workers,
                                 shuffle=False,
                                 drop_last=False,
                                 pin_memory=cfg.pin_memory,
                                 error_avoidance=True)
    return dataset, loader
|
||||
|
||||
|
||||
def error_avoidance_collate(batch):
    """Collate a batch after dropping samples that are ``None``.

    Returns ``None`` when no sample is left, otherwise the result of
    ``default_collate`` on the remaining samples.
    """
    # Filter out None values
    usable = list(filter(lambda item: item is not None, batch))
    if not usable:
        return None
    return default_collate(usable)
|
||||
|
||||
|
||||
def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False,
                     collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
    """Wrap ``dataset`` in a ``DistributedSampler`` and a ``DataLoader``.

    Args:
        dataset: Dataset to load from.
        batch_size: Batch size of the loader.
        num_workers: Worker process count; workers are kept alive between
            epochs whenever this is positive.
        shuffle: Whether the sampler shuffles.
        drop_last: Whether the loader drops the last incomplete batch.
        pin_memory: Forwarded to the loader.
        error_avoidance: When true, ``error_avoidance_collate`` is used and
            ``collate_fn`` is ignored.
        collate_fn: Collate function used when ``error_avoidance`` is false.

    Returns:
        ``(sampler, loader)``.
    """
    chosen_collate = error_avoidance_collate if error_avoidance else collate_fn
    keep_workers_alive = num_workers > 0

    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    loader = DataLoader(
        dataset,
        batch_size,
        sampler=sampler,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        drop_last=drop_last,
        persistent_workers=keep_workers_alive,
        pin_memory=pin_memory,
        collate_fn=chosen_collate,
    )
    return sampler, loader
|
||||
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
|
||||
class AudioCapsData(Dataset):
    """Name/caption pairs read from a CSV file.

    Every CSV row becomes one sample; rows are NOT filtered by whether an audio
    file exists for them. Audio itself is never loaded by this class.
    """

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        """Read the captions and report how many have an audio file on disk.

        Args:
            audio_path: Directory holding ``.wav`` / ``.flac`` files. It is only
                listed; a missing directory makes ``os.listdir`` raise.
            csv_path: CSV file with at least the columns ``name`` and ``caption``.
        """
        rows = pd.read_csv(csv_path).to_dict(orient='records')

        # File stems of the audio present on disk; used for the log message only.
        audio_stems = {
            Path(f).stem
            for f in os.listdir(audio_path) if f.endswith(('.wav', '.flac'))
        }

        self.data = [{'name': row['name'], 'caption': row['caption']} for row in rows]

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        # The old message called all rows "matching audio files" although no
        # matching was performed; report both numbers instead.
        num_matching = sum(str(item['name']) in audio_stems for item in self.data)
        log.info(f'Found {len(self.data)} captions in {self.csv_path}, '
                 f'{num_matching} with a matching audio file in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict:
        """Return the ``{'name': ..., 'caption': ...}`` record at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Number of caption records."""
        return len(self.data)
|
||||
@@ -0,0 +1,237 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
# Side length (height and width) that CLIP-branch frames are resized to.
_CLIP_SIZE = 384
|
||||
# Frame rate requested from the decoder for the CLIP-branch stream.
_CLIP_FPS = 8.0
|
||||
|
||||
# Side length (height and width) that sync-branch frames are resized to.
_SYNC_SIZE = 224
|
||||
# Frame rate requested from the decoder for the sync-branch stream.
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
    """Base dataset that decodes ``<video_root>/<id>.mp4`` into frame tensors.

    The base class leaves ``captions``, ``negative_captions`` and ``videos``
    empty; subclasses fill them in.
    """

    def __init__(
        self,
        video_root: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Store the settings and build the frame transforms.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            duration_sec: Clip length used to derive the expected frame counts.
            clip_video_required: Whether CLIP-branch frames are produced too.
        """
        self.video_root = Path(video_root)
        self.duration_sec = duration_sec
        self.clip_video_required = clip_video_required

        # Sync branch: resize, convert to float in [0, 1], then normalize with
        # mean 0.5 and std 0.5 per channel.
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        sync_steps = [
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.sync_transform = v2.Compose(sync_steps)

        if self.clip_video_required:
            # CLIP branch: resize and convert to float, without normalization.
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
            clip_steps = [
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ]
            self.clip_transform = v2.Compose(clip_steps)

        # to be implemented by subclasses
        self.captions = {}
        self.negative_captions = {}
        self.videos = sorted(self.captions)

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and preprocess the video at position ``idx`` of ``self.videos``.

        Returns:
            A dict with 'name', 'caption' and 'sync_video', plus 'clip_video'
            when CLIP frames are required and 'negative_caption' when one is
            available for this video.

        Raises:
            RuntimeError: If the decoder returns no chunk for a stream, or a
                chunk cannot be brought to the expected length.
        """
        video_id = self.videos[idx]
        prompt = self.captions[video_id]
        negative_prompt = self.negative_captions.get(video_id, None)

        decoder = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
        # Output stream 0: frames for the sync branch.
        decoder.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if self.clip_video_required:
            # Output stream 1: frames for the CLIP branch.
            decoder.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )

        decoder.fill_buffer()
        chunks = decoder.pop_chunks()

        sync_frames = chunks[0]
        if sync_frames is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_frames = normalize_video_chunk(sync_frames, self.sync_expected_length,
                                            n_tolerance_frame=3, desc=video_id)
        sync_frames = self.sync_transform(sync_frames)

        if self.clip_video_required:
            clip_frames = chunks[1]
            if clip_frames is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_frames = normalize_video_chunk(clip_frames, self.clip_expected_length,
                                                n_tolerance_frame=1, desc=video_id)
            clip_frames = self.clip_transform(clip_frames)

        result = {
            'name': video_id,
            'caption': prompt,
            'sync_video': sync_frames,
        }
        if self.clip_video_required:
            result['clip_video'] = clip_frames
        if negative_prompt is not None:
            result['negative_caption'] = negative_prompt

        return result

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``sample(idx)``, or ``None`` (after logging) if loading fails."""
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        """Number of captioned videos."""
        return len(self.captions)
|
||||
|
||||
|
||||
class VGGSound(VideoDataset):
    """Evaluation dataset over the 'test' rows of a VGGSound CSV file."""

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Index the test videos that are present in ``video_root``.

        Args:
            video_root: Directory with files named ``<id>_<start_sec:06d>.mp4``.
            csv_path: Header-less CSV with the columns id, sec, caption, split.
            duration_sec: Clip length, forwarded to ``VideoDataset``.
            clip_video_required: Forwarded to ``VideoDataset``.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        # A set gives O(1) membership tests in the row loop below; the original
        # list was scanned once per CSV row.
        videos = set(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')
        self.captions = {}

        df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
                                                       'split']).to_dict(orient='records')

        videos_no_found = []
        for row in df:
            if row['split'] != 'test':
                continue
            start_sec = int(row['sec'])
            video_id = str(row['id'])
            # this is how our videos are named
            video_name = f'{video_id}_{start_sec:06d}'
            if video_name + '.mp4' not in videos:
                videos_no_found.append(video_name)
                continue

            self.captions[video_name] = row['caption']

        if local_rank == 0:
            # The directory count was already logged above; it is not repeated.
            log.info(f'{len(self.captions)} useable videos found')
            if videos_no_found:
                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
                log.info(
                    'A small amount is expected, as not all videos are still available on YouTube')

        self.videos = sorted(self.captions)
|
||||
|
||||
|
||||
class InferenceVideoData(VideoDataset):
    """Inference dataset pairing each ``.mp4`` with a prompt file of the same stem."""

    def __init__(
        self,
        video_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
        clip_video_required: bool = False,
    ):
        """Collect the videos and read their prompts.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            jsonl_root: Directory containing ``<stem>.jsonl`` for every video.
                Each file is parsed with ``json.load`` and must provide
                'audio_prompt'; 'negative_audio_prompt' is optional.
            duration_sec: Clip length, forwarded to ``VideoDataset``.
            clip_video_required: Forwarded to ``VideoDataset``.

        Raises:
            FileNotFoundError: If a video has no matching ``.jsonl`` file.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.jsonl_root = Path(jsonl_root)

        # Keep only .mp4 entries and strip the extension via Path.stem. The old
        # `v[:-4]` chopped four characters off every directory entry, producing
        # bogus ids for anything that is not a 3-letter-extension file.
        videos = [
            Path(v).stem
            for v in sorted(os.listdir(self.video_root))
            if Path(v).suffix.lower() == '.mp4'
        ]
        self.captions = {}

        for v in videos:
            with open(self.jsonl_root / (v + '.jsonl')) as f:
                data = json.load(f)
            self.captions[v] = data['audio_prompt']
            self.negative_captions[v] = data.get('negative_audio_prompt', None)

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.videos = videos
|
||||
|
||||
|
||||
class VGGMonoAudioBench(VideoDataset):
    """Benchmark dataset where each video has a label and a paired (negative) label."""

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
        clip_video_required: bool = False,
    ):
        """Index the CSV rows whose video is present in ``video_root``.

        Args:
            video_root: Directory containing the ``.mp4`` files.
            csv_path: CSV with a header row and the columns ``file_name``,
                ``label`` and ``paired_label``.
            duration_sec: Clip length, forwarded to ``VideoDataset``.
            clip_video_required: Forwarded to ``VideoDataset``.
        """
        super().__init__(video_root, duration_sec=duration_sec,
                         clip_video_required=clip_video_required)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        # A set gives O(1) membership tests in the row loop below; the original
        # list was scanned once per CSV row.
        videos = set(os.listdir(self.video_root))
        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')
        self.captions = {}
        self.negative_captions = {}

        df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
                         ).to_dict(orient='records')

        videos_no_found = []
        for row in df:
            video_name = str(Path(row['file_name']).stem)
            if video_name + '.mp4' not in videos:
                videos_no_found.append(video_name)
                continue

            # `label` is used as the caption, `paired_label` as the negative caption.
            self.captions[video_name] = row['label']
            self.negative_captions[video_name] = row['paired_label']

        if local_rank == 0:
            # The directory count was already logged above; it is not repeated.
            log.info(f'{len(self.captions)} useable videos found')
            if videos_no_found:
                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')

        self.videos = sorted(self.captions)
|
||||
@@ -0,0 +1,194 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
# Side length (height and width) that CLIP-branch frames are resized to.
_CLIP_SIZE = 384
|
||||
# Frame rate requested from the decoder for the CLIP-branch stream.
_CLIP_FPS = 8.0
|
||||
|
||||
# Side length (height and width) that sync-branch frames are resized to.
_SYNC_SIZE = 224
|
||||
# Frame rate requested from the decoder for the sync-branch stream.
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VGGSound(Dataset):
    """VGGSound videos decoded from ``.mp4`` into sync frames, audio and CLIP frames."""

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
        audio_required: bool = True,
        sample_rate: int = 16_000,
        duration_sec: float = 8.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
        clip_video_required: bool = True,
    ):
        """Index the videos listed in ``tsv_path`` that exist under ``root``.

        Args:
            root: Directory containing ``<id>.mp4`` files.
            tsv_path: Tab-separated file with the columns ``id`` and ``label``.
            audio_required: Whether samples include an 'audio' entry.
            sample_rate: Target audio sample rate in Hz.
            duration_sec: Clip length used to derive expected lengths.
            audio_samples: Exact audio length in samples; defaults to
                ``int(sample_rate * duration_sec)`` and must agree with
                ``duration_sec`` to within 15 ms when given.
            normalize_audio: Scale audio to a peak of 0.95 and reject silence.
            clip_video_required: Whether samples include a 'clip_video' entry.
        """
        self.root = Path(root)
        self.audio_required = audio_required
        if audio_required:
            self.normalize_audio = normalize_audio
            if audio_samples is None:
                self.audio_samples = int(sample_rate * duration_sec)
            else:
                self.audio_samples = audio_samples
                effective_duration = audio_samples / sample_rate
                # make sure the duration is close enough, within 15ms
                assert abs(effective_duration - duration_sec) < 0.015, \
                    f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
        self.clip_video_required = clip_video_required

        # File stems (extensions removed) of everything present under root.
        videos = {Path(v).stem for v in os.listdir(self.root)}
        self.labels = {}
        self.videos = []
        missing_videos = []

        # read the tsv for subset information
        df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            video_id = record['id']
            if video_id in videos:
                self.labels[video_id] = record['label']
                self.videos.append(video_id)
            else:
                missing_videos.append(video_id)

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {root}')
            log.info(f'{len(self.videos)} videos found in {tsv_path}')
            log.info(f'{len(missing_videos)} videos missing in {root}')

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        if audio_required:
            self.expected_audio_length = self.audio_samples
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
        if clip_video_required:
            self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)

        # Sync branch: resize, convert to float in [0, 1], then normalize with
        # mean 0.5 and std 0.5 per channel.
        self.sync_transform = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        if clip_video_required:
            # CLIP branch: resize and convert to float, without normalization.
            self.clip_transform = v2.Compose([
                v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
            ])
        if audio_required:
            # Resamplers are created lazily, one per source sample rate.
            self.resampler = {}

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        """Decode and preprocess the video at position ``idx`` of ``self.videos``.

        Returns:
            A dict with 'id', 'caption' and 'sync_video', plus 'audio' and
            'clip_video' when the respective stream is required.

        Raises:
            RuntimeError: If a stream yields no data, a chunk cannot be brought
                to the expected length, the audio is silent (only checked when
                normalizing) or the audio is too short.
        """
        video_id = self.videos[idx]

        label = self.labels[video_id]

        # Output streams are indexed in the order they are added:
        # 0 = sync video, then audio (if required), then CLIP video (if required).
        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        if self.audio_required:
            # A huge chunk size so that all decoded audio arrives in one chunk.
            reader.add_basic_audio_stream(frames_per_chunk=2**30, )
        if self.clip_video_required:
            reader.add_basic_video_stream(
                frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
                frame_rate=_CLIP_FPS,
                format='rgb24',
            )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        sync_chunk = data_chunk[0]
        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                           n_tolerance_frame=3, desc=video_id)
        sync_chunk = self.sync_transform(sync_chunk)

        if self.audio_required:
            audio_chunk = data_chunk[1]
            # Same explicit check as for the video streams, instead of failing
            # later with an AttributeError on None.
            if audio_chunk is None:
                raise RuntimeError(f'Audio returned None {video_id}')

        if self.clip_video_required:
            clip_chunk = data_chunk[2 if self.audio_required else 1]
            if clip_chunk is None:
                raise RuntimeError(f'CLIP video returned None {video_id}')
            clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                               n_tolerance_frame=1, desc=video_id)
            clip_chunk = self.clip_transform(clip_chunk)

        # process audio
        if self.audio_required:
            sample_rate = int(reader.get_out_stream_info(1).sample_rate)
            audio_chunk = audio_chunk.transpose(0, 1)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            if self.normalize_audio:
                abs_max = audio_chunk.abs().max()
                # Reject silence before scaling so we never divide by ~0.
                if abs_max <= 1e-6:
                    raise RuntimeError(f'Audio is silent {video_id}')
                audio_chunk = audio_chunk * (0.95 / abs_max)

            # resample only when the source rate differs from the target rate
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.expected_audio_length:
                raise RuntimeError(f'Audio too short {video_id}')
            audio_chunk = audio_chunk[:self.expected_audio_length]

        data = {
            'id': video_id,
            'caption': label,
            'sync_video': sync_chunk,
        }

        if self.audio_required:
            data['audio'] = audio_chunk
        if self.clip_video_required:
            data['clip_video'] = clip_chunk

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``sample(idx)``, or ``None`` (after logging) if loading fails."""
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        """Number of distinct labelled video ids."""
        return len(self.labels)
|
||||
@@ -0,0 +1,129 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import open_clip
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
# Module-level logger; getLogger() called without a name returns the root logger.
log = logging.getLogger()
|
||||
|
||||
|
||||
class WavTextClipsDataset(Dataset):
    """Fixed-length audio clips paired with a caption and its CLIP tokens.

    Every row of ``clips_tsv`` describes one clip: the audio file it comes
    from (``name``; ``<root>/<name>.flac`` is preferred over ``.wav``) and the
    ``[start_sample, end_sample)`` window to cut out of it. The window is
    applied *before* resampling, i.e. it indexes the file at its native
    sample rate. The caption of a clip is looked up in ``captions_tsv`` by the
    clip's ``name``.

    Args:
        root: Directory holding the ``.flac`` / ``.wav`` files.
        captions_tsv: Tab-separated file with ``id`` and ``caption`` columns.
        clips_tsv: Tab-separated file with ``id``, ``name``, ``start_sample``
            and ``end_sample`` columns.
        sample_rate: Sample rate every returned waveform is resampled to.
        num_samples: Exact length of every returned waveform.
        normalize_audio: Peak-normalise each file to 0.95 before slicing.
        reject_silent: Return ``None`` for files whose peak is below 1e-6.
        tokenizer_id: Name handed to ``open_clip.get_tokenizer``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Stems of the audio files actually present; listing the directory
        # also fails early when ``root`` does not exist.
        available = {
            Path(file_name).stem
            for file_name in os.listdir(self.root)
            if file_name.endswith(('.wav', '.flac'))
        }

        # read the caption tsv
        caption_records = pd.read_csv(captions_tsv, sep='\t',
                                      dtype={'id': str}).to_dict('records')
        self.captions = {record['id']: record['caption'] for record in caption_records}

        # read the clip tsv
        clip_records = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        num_missing = 0
        for record in clip_records:
            record['caption'] = self.captions[record['name']]
            if record['name'] not in available:
                num_missing += 1
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')
        if num_missing:
            log.warning(f'{num_missing} clips reference audio missing from {self.root}')

        # one cached resampler per source sample rate
        self.resampler = {}

    @staticmethod
    def _peak_normalize(audio: torch.Tensor, abs_max: torch.Tensor,
                        peak: float = 0.95) -> torch.Tensor:
        """Scale ``audio`` so its peak equals ``peak``; leave all-zero audio untouched."""
        if abs_max <= 0:
            # dividing by a zero peak would fill the waveform with NaNs
            return audio
        return audio / abs_max * peak

    def _get_resampler(self, source_rate: int):
        """Return the cached resampler from ``source_rate`` to ``self.sample_rate``."""
        if source_rate not in self.resampler:
            # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
            self.resampler[source_rate] = torchaudio.transforms.Resample(
                source_rate,
                self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
        return self.resampler[source_rate]

    def __getitem__(self, idx: int) -> Union[dict, None]:
        """Load clip ``idx``.

        Returns:
            A dict with ``waveform`` (1-D tensor of ``num_samples`` samples),
            ``id``, ``caption`` and ``tokens``; or ``None`` when the file is
            rejected as silent or cannot be read (the error is logged).

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Looked up outside the ``try`` so an out-of-range index surfaces as
        # IndexError instead of being swallowed by the best-effort handler.
        clip = self.clips[idx]
        audio_path = None
        try:
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    raise FileNotFoundError(f'No .flac or .wav file for {audio_name}')

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            abs_max = audio_chunk.abs().max()

            # Decide on silence before normalising, so that a silent file is
            # never divided by its (zero) peak.
            if self.reject_silent and abs_max < 1e-6:
                log.warning('Rejecting silent audio')
                return None
            if self.normalize_audio:
                audio_chunk = self._peak_normalize(audio_chunk, abs_max)

            audio_chunk = audio_chunk[start_sample:end_sample]

            # resample
            if sample_rate != self.sample_rate:
                audio_chunk = self._get_resampler(sample_rate)(audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            # ``audio_path`` is still None when the failure happened before the
            # path was built (e.g. a missing column in the clip record).
            source = audio_path if audio_path is not None else f'clip index {idx}'
            log.error(f'Error reading {source}: {e}')
            return None

    def __len__(self):
        return len(self.clips)
|
||||
@@ -0,0 +1,338 @@
|
||||
""" Embedding Mixup
|
||||
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
||||
"""
|
||||
from typing import Literal, Tuple, Union, List, Optional
|
||||
from functools import partial
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from selva_core.data.vgg_sound import _SYNC_SIZE
|
||||
|
||||
|
||||
class MixupBase:
    """ Base class for mixup on either data or feature domain.
    Applies different params to each element or whole batch.

    Exactly one way of choosing lambda must be configured, otherwise the
    constructor raises ``ValueError``:
      * constant lambda: ``mixup_lambda < 1.`` together with ``mixup_alpha == 0.``
      * sampled lambda:  ``mixup_alpha > 0.`` together with ``mixup_lambda == 1.``
        (lambda is then drawn from ``Beta(mixup_alpha, mixup_alpha)``)
    NOTE(review): the default values (``mixup_lambda=0.5``, ``mixup_alpha=1.``)
    select both at once and are therefore rejected; callers have to pass at
    least one of the two explicitly.

    Args:
        generator (Optional[torch.Generator]): Random number generator for reproducibility.
            ``None`` creates a fresh CPU generator. A CUDA generator is kept as
            ``generator_cuda`` and a CPU generator with the same initial seed is used instead.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Mixup lambda value, must be in [0., 1.]; constant lambda is used if < 1.
        mixup_alpha (float): Mixup alpha value, sampled lambda is used if > 0.
        prob (float): Probability of applying mixup per batch or element
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
        eps (float): Lower bound applied to lambda to avoid (near) zero lambda

    Raises:
        ValueError: If ``mixup_lambda`` / ``mixup_alpha`` are out of range or do
            not select exactly one of the two ways of choosing lambda.
    """
    def __init__(self, generator:Optional[torch.Generator],
                 *,
                 modality:Literal['video', 'audio', 'both'],
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        self.modality = modality
        self.mixup_lambda:float = mixup_lambda
        self.mixup_alpha:float = mixup_alpha
        self.mix_prob:float = prob
        self.mode:str = mode
        self.eps:float = eps
        self.mixup_enabled:bool = True  # set to false to disable mixing (intended to be set by train loop)

        if generator is None:
            # documented as optional: fall back to a private CPU generator
            self.generator = torch.Generator(device='cpu')
        elif generator.device.type == 'cuda':
            # random draws below happen on CPU, so mirror the seed on a CPU generator
            self.generator_cuda = generator
            generator_seed = generator.initial_seed()
            self.generator = torch.Generator(device='cpu')
            self.generator.manual_seed(generator_seed)
        else:
            self.generator = generator

        if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
            raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
        if not self.mixup_alpha >= 0.:
            raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
        # reject "both" and "neither": exactly one way of choosing lambda is allowed
        if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
            raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

    def _params_per_elem(self, batch_size:int) -> np.ndarray:
        """Return one lambda per element as a float32 array of length ``batch_size``.

        An entry is 1. (no mixup) when mixup is disabled or when the element's
        random draw is not below ``mix_prob``.
        """
        lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
        if not self.mixup_enabled:
            return lam

        if self.mixup_lambda < 1.:  # constant lambda
            lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
        elif self.mixup_alpha > 0.:  # sampled lambda
            # NOTE: torch.distributions draws from torch's global RNG, not from self.generator
            concentration = torch.tensor([self.mixup_alpha])
            lam_mix = torch.distributions.Beta(
                concentration,
                concentration,
            ).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
        else:
            # unreachable after __init__ validation; raised explicitly so it survives ``python -O``
            raise AssertionError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
        lam_mix[lam_mix < self.eps] = self.eps

        # Use torch's random with generator for the random comparison
        rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
        return np.where(rand_vals < self.mix_prob, lam_mix, lam)

    def _params_per_batch(self) -> float:
        """Return a single lambda for the whole batch (1. when mixup is disabled).

        ``mix_prob`` is not consulted here.
        """
        if not self.mixup_enabled:
            return 1.

        if self.mixup_lambda < 1.:  # constant lambda
            lam = self.mixup_lambda
        elif self.mixup_alpha > 0.:  # sampled lambda
            # NOTE: torch.distributions draws from torch's global RNG, not from self.generator
            concentration = torch.tensor([self.mixup_alpha])
            lam = torch.distributions.Beta(
                concentration,
                concentration,
            ).sample().item()
        else:
            # unreachable after __init__ validation; raised explicitly so it survives ``python -O``
            raise AssertionError(f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
        if lam < self.eps:
            lam = self.eps
        return float(lam)
|
||||
|
||||
|
||||
class DataMixupCollate(MixupBase):
    """ Collate function that applies mixup in the data domain before batching.

    Video mixup resizes a sample and its partner (the sample at the mirrored
    batch position) and concatenates them side by side; audio mixup blends the
    two waveforms with weights ``lam`` / ``1 - lam``. Results are stored under
    ``sync_video_mixed`` / ``audio_mixed`` in every sample and the list is then
    passed to ``default_collate``. Only ``mode='batch'`` is supported.

    Args:
        generator (Optional[torch.Generator]): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
        mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
        prob (float): Fraction of the batch (its first ``int(prob * batch_size)`` samples) that is mixed
        mode (Literal['elem','pair','batch', 'half']): Must be 'batch' for this class
        eps (float): Small epsilon value to avoid zero lambda

    Raises:
        ValueError: If ``mode`` is not 'batch' (or the base class rejects the lambda settings).
    """
    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)

        self.source_video_key = 'sync_video'
        self.source_audio_key = 'audio'
        self.target_video_key = 'sync_video_mixed'
        self.target_audio_key = 'audio_mixed'

        if not mode == 'batch':
            raise ValueError(f"Mode {mode} is not supported for data domain.")
        self.sync_transform = v2.Compose([
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
        """Write the spatially concatenated video of every sample to ``target_key``.

        Returns the lambda that was used; with lambda == 1. the source video is
        passed through unchanged.
        """
        # only batch mode supported
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()

        if lam == 1.:
            # no mixup, just return
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        # The original sample keeps a ``lam`` share of the frame, its partner the rest.
        # The split is currently always horizontal (side by side); the vertical
        # branch is kept for a future random choice between the two.
        orig_size = int(lam * _SYNC_SIZE)
        is_horizontal = True  # torch.rand(1, generator=self.generator).item() < 0.5
        if is_horizontal:
            # Horizontal resize
            resize_shape_orig = (_SYNC_SIZE, orig_size)
            resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
        else:
            # Vertical resize
            resize_shape_orig = (orig_size, _SYNC_SIZE)
            resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
        sync_resize_orig = v2.Compose([
            v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
        ])
        sync_resize_pair = v2.Compose([
            v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
        ])

        # partner of sample i is the sample at the mirrored position batch_size - i - 1
        batch_videos_orig = torch.stack([sample[source_key] for sample in batch], dim=0)
        batch_videos_pair = torch.stack([sample[source_key] for sample in reversed(batch)], dim=0)
        # (B, T, C, H, W)
        # pass through resize, transform and concat
        batch_videos_orig = sync_resize_orig(batch_videos_orig)
        batch_videos_pair = sync_resize_pair(batch_videos_pair)
        batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
        batch_videos_concat = self.sync_transform(batch_videos_concat)

        num_mixup = int(self.mix_prob * batch_size)
        for i in range(num_mixup):
            batch[i][target_key] = batch_videos_concat[i]
        for i in range(num_mixup, batch_size):
            batch[i][target_key] = batch[i][source_key]  # no mixup

        # drop the stacked copies eagerly to keep peak memory down
        del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
        gc.collect()

        return lam

    def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
                           normalize:bool = True) -> float:
        """Write the blended waveform of every sample to ``target_key``.

        With ``normalize`` the mix is rescaled to the peak of the sample's own
        source waveform. Returns the lambda that was used.
        """
        # assume source_key audios are normalized
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()

        if lam == 1.:
            # no mixup, just return
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        num_mixup = int(self.mix_prob * batch_size)
        for i in range(num_mixup):
            own = batch[i][source_key]
            partner = batch[batch_size - i - 1][source_key]
            mixed = own * lam + partner * (1 - lam)
            if normalize:
                source_abs_max = own.abs().max()
                target_abs_max = mixed.abs().max()
                # a silent mix has peak 0; rescaling it would divide by zero and yield NaNs
                if target_abs_max > 0:
                    mixed = mixed * (source_abs_max / target_abs_max)
            batch[i][target_key] = mixed
        for i in range(num_mixup, batch_size):
            batch[i][target_key] = batch[i][source_key]  # no mixup

        return lam

    def __call__(self, batch:list, _=None) -> dict:
        """Mix the configured modalities in place and collate ``batch`` with ``default_collate``."""
        batch_size:int = len(batch)
        assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'

        if self.modality == 'video' or self.modality == 'both':
            self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
        if self.modality == 'audio' or self.modality == 'both':
            self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)

        return default_collate(batch)
|
||||
|
||||
|
||||
class FeatureMixup(MixupBase):
    """ Mixup in the feature domain on an already collated batch.

    Row ``i`` of a source feature tensor is blended with the row at the
    mirrored batch position ``batch_size - i - 1`` and written to the matching
    target tensor; rows that are not mixed receive a copy of the source row.
    The target tensors must already exist in the batch for the 'elem', 'half'
    and 'pair' modes because they are written in place.

    Args:
        generator (Optional[torch.Generator]): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
        mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
        prob (float): Probability of applying mixup per batch or element
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
        eps (float): Small epsilon value to avoid zero lambda
    """
    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        self.source_video_key = 'sync_f_vid_orig'
        self.source_audio_key = 'sync_f_aud_orig'
        self.target_video_key = 'sync_f_vid_mixed'
        self.target_audio_key = 'sync_f_aud_mixed'

    @staticmethod
    def _broadcast_lambdas(lambdas:torch.Tensor, reference:torch.Tensor) -> torch.Tensor:
        """Reshape per-row lambdas to (k, 1, ..., 1) on ``reference``'s device so they scale whole rows."""
        trailing_ones = [1] * (reference.dim() - 1)
        return lambdas.to(reference.device).view(-1, *trailing_ones)

    def _mix_elem_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
                          half:bool=False) -> torch.Tensor:
        """Mix each element with its own lambda; with ``half`` only the first half of the batch is eligible.

        Returns the per-row lambdas with shape (batch_size, 1); 1. marks an unmixed row.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        num_elem:int = batch_size // 2 if half else batch_size
        lam_batch:torch.Tensor = torch.from_numpy(self._params_per_elem(num_elem))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_elem)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]
        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            lam = self._broadcast_lambdas(active_lambdas, source)
            # compute from the source before writing anything, in case target aliases source
            mixed = source[active_indices] * lam + source[active_mix_indices] * (1 - lam)
            target = batch[target_key]
            # Every row starts as a copy of the source; only active rows are then
            # replaced. (Selecting the inactive rows with ``~index`` would be a
            # bitwise NOT, i.e. ``-i-1``, and hit the mirrored rows instead.)
            target[:] = source
            target[active_indices] = mixed
        if half:
            lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
        return lam_batch.unsqueeze(1)

    def _mix_pair_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.Tensor:
        """Mix mirrored pairs symmetrically: both members of a pair share one lambda.

        Returns the per-row lambdas with shape (batch_size, 1); 1. marks an unmixed row.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        num_pairs:int = batch_size // 2
        lam_batch:torch.Tensor = torch.from_numpy(self._params_per_elem(num_pairs))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_pairs)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]
        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            lam = self._broadcast_lambdas(active_lambdas, source)
            # compute both directions from the source before writing anything
            mixed_front = source[active_indices] * lam + source[active_mix_indices] * (1 - lam)
            mixed_back = source[active_mix_indices] * lam + source[active_indices] * (1 - lam)
            target = batch[target_key]
            # rows of pairs that are not mixed keep a copy of the source
            target[:] = source
            target[active_indices] = mixed_front
            target[active_mix_indices] = mixed_back
        # mirrored rows share the lambda of their partner
        lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
        return lam_batch.unsqueeze(1)

    def _mix_batch_collate(self, batch:dict,
                           target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
        """Mix the first ``int(mix_prob * batch_size)`` rows with one shared lambda.

        The target tensors are created (or replaced) rather than written in place.
        Returns the lambda that was used.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        lam:float = self._params_per_batch()

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            num_mixup = int(self.mix_prob * source.shape[0])
            flipped_source = torch.flip(source, dims=[0])
            batch[target_key] = source * lam + flipped_source * (1 - lam)
            batch[target_key][num_mixup:] = source[num_mixup:]  # no mixup
        return lam

    def __call__(self, batch:dict, _=None) -> None:
        """Apply mixup to ``batch`` in place according to ``mode`` and ``modality``.

        Raises:
            ValueError: If ``modality`` is not 'video', 'audio' or 'both'.
        """
        batch_size:int = len(batch['id'])
        assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
        half = 'half' in self.mode

        # Mixup
        if self.mode == 'elem' or self.mode == 'half':
            collate_fn = partial(self._mix_elem_collate, half=half)
        elif self.mode == 'pair':
            collate_fn = self._mix_pair_collate
        else:
            collate_fn = self._mix_batch_collate

        if self.modality == 'both':
            target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
        elif self.modality == 'video':
            target_keys, source_keys = [self.target_video_key], [self.source_video_key]
        elif self.modality == 'audio':
            target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
        else:
            raise ValueError(f"Unknown modality {self.modality}.")

        # the batch is modified in place; the lambdas are currently not handed back
        collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import bisect
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
|
||||
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
||||
class MultiModalDataset(Dataset):
    """Concatenation of video datasets followed by audio datasets.

    A global index is mapped to the dataset that contains it via the running
    totals in ``cumulative_sizes``; video datasets come first.
    """
    datasets: list[Dataset]
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(item)`` over ``sequence``."""
        totals = []
        running = 0
        for item in sequence:
            running += len(item)
            totals.append(running)
        return totals

    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
        super().__init__()
        self.video_datasets = list(video_datasets)
        self.audio_datasets = list(audio_datasets)
        self.datasets = self.video_datasets + self.audio_datasets

        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        # the last running total is the combined size
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return item ``idx`` of the concatenation; negative indices count from the end."""
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = total + idx
        # first dataset whose running total exceeds idx
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[which - 1] if which != 0 else 0
        return self.datasets[which][idx - offset]

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``compute_latent_stats`` of the first video dataset."""
        first_video = self.video_datasets[0]
        return first_video.compute_latent_stats()
|
||||
@@ -0,0 +1,148 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tensordict import MemoryMappedTensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from selva_core.utils.dist_utils import local_rank, world_size
|
||||
|
||||
# Directory for file-backed temporary tensors: taken from the SLURM_SCRATCH
# environment variable when it is set, otherwise /dev/shm.
scratch_path = Path(os.environ.get('SLURM_SCRATCH', '/dev/shm'))
# Directory used when the temporary tensors should live in memory.
shm_path = Path('/dev/shm')

log = logging.getLogger()
|
||||
|
||||
|
||||
def reseed(seed):
    """Seed Python's ``random`` module and torch's global RNG with ``seed``."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
|
||||
|
||||
|
||||
def local_scatter_torch(obj: Optional[Any]):
    """Hand ``obj`` from the rank with ``local_rank == 0`` to every rank.

    With a single worker ``obj`` is returned unchanged. Otherwise the object
    is distributed with ``dist.scatter_object_list`` (source rank 0) and each
    rank returns the copy it received.
    """
    if world_size == 1:
        # Just one worker. Do nothing.
        return obj

    # only the sending side supplies an input list: one copy per rank
    outgoing = [obj] * world_size if local_rank == 0 else None
    received = [None]
    dist.scatter_object_list(received, scatter_object_input_list=outgoing, src=0)
    return received[0]
|
||||
|
||||
|
||||
class ShardDataset(Dataset):
    """Dataset whose items are the files of a directory, each read with ``torch.load``."""

    def __init__(self, root):
        self.root = root
        # sorted file names give a stable index -> shard mapping
        shard_names = os.listdir(root)
        shard_names.sort()
        self.shards = shard_names

    def __len__(self):
        return len(self.shards)

    def __getitem__(self, idx):
        shard_file = os.path.join(self.root, self.shards[idx])
        return torch.load(shard_file, weights_only=True)
|
||||
|
||||
|
||||
def get_tmp_dir(in_memory: bool) -> Path:
    """Return ``shm_path`` when ``in_memory`` is truthy, otherwise ``scratch_path``."""
    if in_memory:
        return shm_path
    return scratch_path
|
||||
|
||||
|
||||
def _barrier_if_distributed() -> None:
    """Synchronise all ranks, but only when a process group has been initialised."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
                          in_memory: bool) -> MemoryMappedTensor:
    """Load the shards on the rank with ``local_rank == 0`` and share the tensor with the others.

    The loading rank writes the tensor to a temporary file inside
    ``get_tmp_dir(in_memory)``; the other ranks receive it through
    ``share_tensor_to_all``. The temporary file is only closed after all ranks
    passed the barrier.

    Args:
        data_path: Directory containing the shard files.
        ids: Ids of the items to load; passed on to ``load_shards``.
        in_memory: Selects the directory of the temporary file (see ``get_tmp_dir``).

    Returns:
        The memory-mapped tensor, available on every rank.
    """
    if local_rank == 0:
        with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
            log.info(f'Loading shards from {data_path} into {f.name}...')
            data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
            data = share_tensor_to_all(data)
            # An unconditional barrier raises in a single-process run without a
            # process group, although share_tensor_to_all supports that case.
            _barrier_if_distributed()
            f.close()  # why does the context manager not close the file for me?
    else:
        log.info('Waiting for the data to be shared with me...')
        data = share_tensor_to_all(None)
        _barrier_if_distributed()

    return data
|
||||
|
||||
|
||||
def load_shards(
    data_path: Union[str, Path],
    ids: list[int],
    *,
    tmp_file_path: str,
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Gather the tensors of the requested ids from all shards into one memory-mapped tensor.

    Every shard is a mapping from id to tensor. Row ``k`` of the result holds
    the tensor stored under ``ids[k]``; ids found in the shards but not listed
    in ``ids`` are ignored. The result is float32, backed by ``tmp_file_path``,
    and its per-row shape is taken from the first item of the first shard.

    Raises:
        ValueError: If a requested id occurs more than once in the shards.
        AssertionError: If not every requested id was found.
    """
    # id -> row in the output tensor
    row_of_id = {}
    for row, item_id in enumerate(ids):
        row_of_id[item_id] = row

    shards = sorted(os.listdir(data_path))
    log.info(f'Found {len(shards)} shards in {data_path}.')
    first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)

    log.info(f'Rank {local_rank} created file {tmp_file_path}')
    first_item = next(iter(first_shard.values()))
    log.info(f'First item shape: {first_item.shape}')
    mm_tensor = MemoryMappedTensor.empty(
        shape=(len(ids), *first_item.shape),
        dtype=torch.float32,
        filename=tmp_file_path,
        existsok=True,
    )

    filled_rows = set()
    total_count = 0
    # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
    loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
    for shard in tqdm(loader, desc='Loading shards'):
        for i, v in shard.items():
            if i not in row_of_id:
                continue

            row = row_of_id[i]
            if row in filled_rows:
                raise ValueError(f'Duplicate id {i} found in {data_path}.')
            filled_rows.add(row)
            # presumably ``v`` carries a leading batch dimension of 1 from the
            # loader, which broadcasts away on assignment -- TODO confirm
            mm_tensor[row] = v
            total_count += 1

    assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
    log.info(f'Loaded {total_count} tensors from {data_path}.')

    return mm_tensor
|
||||
|
||||
|
||||
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
    """
    x: the tensor to be shared; None if local_rank != 0
    return: the shared tensor

    The rank with local_rank == 0 sends (filename, shape, dtype) of its tensor;
    every other rank opens the same file as a memory-mapped tensor.
    """

    # there is no need to share your stuff with anyone if you are alone; must be in memory
    if world_size == 1:
        return x

    is_sender = local_rank == 0
    if is_sender:
        assert x is not None, 'x must not be None if local_rank == 0'
        meta_information = (x.filename, x.shape, x.dtype)
    else:
        assert x is None, 'x must be None if local_rank != 0'
        meta_information = None

    filename, data_shape, data_type = local_scatter_torch(meta_information)
    if is_sender:
        # the sender already holds the tensor
        return x
    return MemoryMappedTensor.from_filename(filename=filename,
                                            dtype=data_type,
                                            shape=data_shape)
|
||||
@@ -0,0 +1,299 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
from tensordict import TensorDict
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
# Module-level logger; logging.getLogger() without a name returns the root logger.
log = logging.getLogger()

# Side length (pixels) the frames of the CLIP video stream are resized to.
_CLIP_SIZE = 384
# Frame rate requested when decoding the CLIP video stream.
_CLIP_FPS = 8.0

# Side length (pixels) the frames of the sync video stream are resized to.
_SYNC_SIZE = 224
# Frame rate requested when decoding the sync video stream.
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VGGSound(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[str, Path],
|
||||
*,
|
||||
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
||||
for_generator: bool = True,
|
||||
audio_required: bool = False,
|
||||
sample_rate: int = 16_000,
|
||||
duration_sec: float = 8.0,
|
||||
audio_samples: Optional[int] = None,
|
||||
normalize_audio: bool = False,
|
||||
clip_video_required: bool = False,
|
||||
mmap_dir: Union[str, Path] = None,
|
||||
tsv_tsynch_path: Union[str, Path] = None,
|
||||
mmap_tsync_dir: Union[str, Path] = None,
|
||||
data_dim: dict[str, int] = None,
|
||||
):
|
||||
self.root = Path(root)
|
||||
self.audio_required = audio_required
|
||||
if audio_required:
|
||||
self.normalize_audio = normalize_audio
|
||||
if audio_samples is None:
|
||||
self.audio_samples = int(sample_rate * duration_sec)
|
||||
else:
|
||||
self.audio_samples = audio_samples
|
||||
effective_duration = audio_samples / sample_rate
|
||||
# make sure the duration is close enough, within 15ms
|
||||
assert abs(effective_duration - duration_sec) < 0.015, \
|
||||
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
||||
self.clip_video_required = clip_video_required
|
||||
self.for_generator = for_generator
|
||||
|
||||
videos = sorted(os.listdir(self.root))
|
||||
videos = set([Path(v).stem for v in videos]) # remove extensions
|
||||
self.labels = {}
|
||||
self.videos = []
|
||||
missing_videos = []
|
||||
|
||||
# read the tsv for subset information
|
||||
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
for record in df_list:
|
||||
id = record['id']
|
||||
label = record['label']
|
||||
if id in videos:
|
||||
self.labels[id] = label
|
||||
self.videos.append(id)
|
||||
else:
|
||||
missing_videos.append(id)
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {root}')
|
||||
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
||||
log.info(f'{len(missing_videos)} videos missing in {root}')
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.duration_sec = duration_sec
|
||||
|
||||
if audio_required:
|
||||
self.expected_audio_length = self.audio_samples
|
||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||
if clip_video_required:
|
||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
# v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
if clip_video_required:
|
||||
self.clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
if audio_required:
|
||||
self.resampler = {}
|
||||
|
||||
# mmap
|
||||
log.info(f'Loading precomputed mmap from {mmap_dir}')
|
||||
mmap_dir = Path(mmap_dir)
|
||||
td = TensorDict.load_memmap(mmap_dir)
|
||||
log.info(f'Loaded precomputed mmap from {mmap_dir}')
|
||||
self.sync_features = td['sync_features']
|
||||
if for_generator:
|
||||
self.mean = td['mean']
|
||||
self.std = td['std']
|
||||
self.text_clip_features = td['text_features']
|
||||
if clip_video_required:
|
||||
self.clip_features = td['clip_features']
|
||||
else:
|
||||
self.clip_features = None
|
||||
self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}
|
||||
|
||||
mmap_tsync_dir = Path(mmap_tsync_dir)
|
||||
td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
|
||||
log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
|
||||
self.text_features = td_tsync['text_features']
|
||||
self.text_masks = td_tsync['text_masks']
|
||||
df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t').to_dict('records')
|
||||
self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'Loaded {len(self)} samples.')
|
||||
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
||||
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
||||
log.info(f'Loaded text_masks: {self.text_masks.shape}.')
|
||||
if for_generator:
|
||||
log.info(f'Loaded mean: {self.mean.shape}.')
|
||||
log.info(f'Loaded std: {self.std.shape}.')
|
||||
log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
|
||||
if clip_video_required:
|
||||
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
||||
|
||||
assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
|
||||
f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
|
||||
assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||
f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||
assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||
f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||
assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
|
||||
f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
|
||||
assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
|
||||
f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
|
||||
if for_generator:
|
||||
assert self.mean.shape[1] == data_dim['latent_seq_len'], \
|
||||
f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||
assert self.std.shape[1] == data_dim['latent_seq_len'], \
|
||||
f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||
assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
|
||||
f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
|
||||
assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
|
||||
f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
|
||||
if clip_video_required:
|
||||
assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
|
||||
f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
|
||||
assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
|
||||
f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'
|
||||
|
||||
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
||||
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
||||
|
||||
|
||||
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:  # mmap
    """Return the mean and standard deviation of ``self.mean``.

    Both statistics are reduced over dimensions 0 and 1 of ``self.mean``,
    so each result keeps only the remaining trailing dimension(s).

    Returns:
        A ``(mean, std)`` pair of tensors.
    """
    reduce_dims = (0, 1)
    latent_mean = self.mean.mean(dim=reduce_dims)
    latent_std = self.mean.std(dim=reduce_dims)
    return latent_mean, latent_std
|
||||
|
||||
def get_memory_mapped_tensor(self) -> TensorDict:
    """Collect the dataset's feature tensors into a single ``TensorDict``.

    ``sync_features``, ``text_features`` and ``text_masks`` are always
    included. ``mean``, ``std`` and ``text_clip_features`` are added when
    ``self.for_generator`` is set, and ``clip_features`` is added when
    ``self.clip_video_required`` is set.
    """
    bundle = TensorDict({
        'sync_features': self.sync_features,
        'text_features': self.text_features,
        'text_masks': self.text_masks,
    })

    # Gather the optional entries first, then attach them in insertion order.
    optional_entries = {}
    if self.for_generator:
        optional_entries['mean'] = self.mean
        optional_entries['std'] = self.std
        optional_entries['text_clip_features'] = self.text_clip_features
    if self.clip_video_required:
        optional_entries['clip_features'] = self.clip_features

    for key, tensor in optional_entries.items():
        bundle[key] = tensor
    return bundle
|
||||
|
||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
    """Decode the video at position ``idx`` and assemble its sample dict.

    The returned dict always holds ``id``, ``caption``, ``sync_video``,
    ``sync_f_vid_orig``, ``text_features``, ``text_masks``, ``video_exist``
    and ``text_exist``. Depending on the dataset flags it also holds
    ``a_mean`` / ``a_std`` / ``text_clip_features`` (``for_generator``),
    ``audio`` (``audio_required``) and ``clip_video`` / ``clip_features``
    (``clip_video_required``).

    Args:
        idx: Index into ``self.videos``.

    Returns:
        The sample dictionary described above.

    Raises:
        RuntimeError: If a decoded stream comes back as ``None``, the audio
            is silent (peak <= 1e-6, only checked when ``normalize_audio``
            is set), or the audio is shorter than
            ``self.expected_audio_length`` after resampling.
    """
    video_id = self.videos[idx]

    # Use the stored caption with probability ``autoacd_sample_prob`` when
    # one exists for this video; otherwise fall back to the label.
    if video_id in self.captions and torch.rand(1).item() < self.autoacd_sample_prob:
        label = self.captions[video_id]
    else:
        label = self.labels[video_id]

    # Streams are registered in the order sync video -> audio (optional) ->
    # CLIP video (optional); the chunk indices used below rely on that order.
    reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
    reader.add_basic_video_stream(
        frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
        frame_rate=_SYNC_FPS,
        format='rgb24',
    )
    if self.audio_required:
        # NOTE(review): the very large frames_per_chunk presumably makes the
        # whole audio track arrive as one chunk -- confirm against the decoder.
        reader.add_basic_audio_stream(frames_per_chunk=2**30, )
    if self.clip_video_required:
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )

    reader.fill_buffer()
    data_chunk = reader.pop_chunks()

    sync_chunk = data_chunk[0]
    if sync_chunk is None:
        raise RuntimeError(f'Sync video returned None {video_id}')
    sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
                                       n_tolerance_frame=3, desc=video_id)
    sync_chunk = self.sync_transform(sync_chunk)

    if self.audio_required:
        audio_chunk = data_chunk[1]
        # Fail with an explicit message, matching the video streams, instead
        # of an AttributeError further down.
        if audio_chunk is None:
            raise RuntimeError(f'Audio returned None {video_id}')

    if self.clip_video_required:
        clip_chunk = data_chunk[2 if self.audio_required else 1]
        if clip_chunk is None:
            raise RuntimeError(f'CLIP video returned None {video_id}')
        clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
                                           n_tolerance_frame=1, desc=video_id)
        clip_chunk = self.clip_transform(clip_chunk)

    # process audio
    if self.audio_required:
        sample_rate = int(reader.get_out_stream_info(1).sample_rate)
        audio_chunk = audio_chunk.transpose(0, 1)
        audio_chunk = audio_chunk.mean(dim=0)  # mono
        if self.normalize_audio:
            abs_max = audio_chunk.abs().max()
            # Reject silent clips before dividing by the (near-)zero peak.
            if abs_max <= 1e-6:
                raise RuntimeError(f'Audio is silent {video_id}')
            # Rescale so the peak absolute amplitude becomes 0.95.
            audio_chunk = audio_chunk * (0.95 / abs_max)

        # resample (one Resample module is cached per source sample rate)
        if sample_rate != self.sample_rate:
            if sample_rate not in self.resampler:
                # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                self.resampler[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate,
                    self.sample_rate,
                    lowpass_filter_width=64,
                    rolloff=0.9475937167399596,
                    resampling_method='sinc_interp_kaiser',
                    beta=14.769656459379492,
                )
            audio_chunk = self.resampler[sample_rate](audio_chunk)

        if audio_chunk.shape[0] < self.expected_audio_length:
            raise RuntimeError(f'Audio too short {video_id}')
        audio_chunk = audio_chunk[:self.expected_audio_length]

    mmap_idx = self.id2idx_mmap[video_id]
    tsync_idx = self.id2idx_mmap_tsync[video_id]

    data = {
        'id': video_id,
        'caption': label,
        'sync_video': sync_chunk,
        'sync_f_vid_orig': self.sync_features[mmap_idx],
        'text_features': self.text_features[tsync_idx],
        'text_masks': self.text_masks[tsync_idx],
        'video_exist': self.video_exist,
        'text_exist': self.text_exist,
    }

    if self.for_generator:
        data['a_mean'] = self.mean[mmap_idx]
        data['a_std'] = self.std[mmap_idx]
        data['text_clip_features'] = self.text_clip_features[mmap_idx]

    if self.audio_required:
        data['audio'] = audio_chunk

    if self.clip_video_required:
        data['clip_video'] = clip_chunk
        # Stored as a plain tensor: a trailing comma here previously wrapped
        # the value in a 1-tuple, unlike every other entry in the dict.
        data['clip_features'] = self.clip_features[mmap_idx]

    return data
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
    """Return ``self.sample(idx)``, or ``None`` if loading fails.

    Any exception raised while building the sample is logged with the
    offending video id and swallowed, so the caller receives ``None`` for
    that index instead of an exception.
    """
    try:
        item = self.sample(idx)
    except Exception as err:
        log.error(f'Error loading video {self.videos[idx]}: {err}')
        item = None
    return item
|
||||
|
||||
def __len__(self):
    """Return the number of entries in ``self.labels``."""
    label_count = len(self.labels)
    return label_count
|
||||
Reference in New Issue
Block a user