chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,227 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from selva_core.data.vgg_sound import VGGSound
|
||||
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
|
||||
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
|
||||
from selva_core.data.eval.audiocaps import AudioCapsData
|
||||
from selva_core.data.mm_dataset import MultiModalDataset
|
||||
from selva_core.data.mixup import DataMixupCollate
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
# Re-seed randomness every time we start a worker
|
||||
def worker_init_fn(worker_id: int):
|
||||
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
||||
|
||||
|
||||
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
|
||||
) -> Dataset:
|
||||
dataset = VGGSound(root=data_cfg.root,
|
||||
tsv_path=data_cfg.subset_name,
|
||||
sample_rate=16_000,
|
||||
duration_sec=8.0,
|
||||
normalize_audio=normalize_audio,
|
||||
mmap_dir=data_cfg.memmap_dir,
|
||||
tsv_tsynch_path=data_cfg.tsv_tsynch,
|
||||
mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
|
||||
data_dim=cfg.data_dim
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
||||
raise NotImplementedError('Audio data loading is not implemented yet')
|
||||
|
||||
|
||||
def setup_training_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||
if cfg.mini_train:
|
||||
vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
|
||||
dataset = MultiModalDataset([vgg], [])
|
||||
if cfg.example_train:
|
||||
video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
|
||||
dataset = MultiModalDataset([video], [])
|
||||
else:
|
||||
vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
|
||||
# load the largest one first
|
||||
# you can add more video/audio data upon demand, such as
|
||||
# clotho = load_audio_data(cfg, cfg.data.Clotho)
|
||||
dataset = MultiModalDataset([vgg], [])
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.num_workers
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_params = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_params)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
sampler, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_test_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||
if cfg.example_train:
|
||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False, split='test')
|
||||
elif cfg.dataset.startswith('vggsound'):
|
||||
dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False, split='test')
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_config = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_config)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
sampler, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_val_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DataLoader, DataLoader]:
|
||||
if cfg.example_train:
|
||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
|
||||
else:
|
||||
dataset = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=False)
|
||||
|
||||
val_batch_size = cfg.batch_size
|
||||
val_eval_batch_size = cfg.eval_batch_size
|
||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_config = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_config)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
_, val_loader = construct_loader(dataset,
|
||||
val_batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
_, eval_loader = construct_loader(dataset,
|
||||
val_eval_batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, val_loader, eval_loader
|
||||
|
||||
|
||||
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
||||
if dataset_name.startswith('audiocaps_full'):
|
||||
dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
|
||||
cfg.eval_data.audiocaps_full.csv_path)
|
||||
elif dataset_name.startswith('audiocaps'):
|
||||
dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
|
||||
cfg.eval_data.audiocaps.csv_path)
|
||||
elif dataset_name.startswith('vggsound'):
|
||||
dataset = VGGSound(cfg.eval_data.vggsound.video_path,
|
||||
cfg.eval_data.vggsound.csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
elif dataset_name.startswith('infer_video'):
|
||||
dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
|
||||
cfg.eval_data.infer_video.jsonl_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
cfg.batch_size = 1
|
||||
elif dataset_name.startswith('example_video'):
|
||||
dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
|
||||
cfg.eval_data.Example_video.csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
|
||||
dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
|
||||
cfg.eval_data[dataset_name].csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
|
||||
else:
|
||||
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.num_workers
|
||||
pin_memory = cfg.pin_memory
|
||||
_, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
error_avoidance=True)
|
||||
return dataset, loader
|
||||
|
||||
|
||||
def error_avoidance_collate(batch):
|
||||
# Filter our None values
|
||||
batch = [item for item in batch if item is not None]
|
||||
if len(batch) == 0:
|
||||
return None
|
||||
return default_collate(batch)
|
||||
|
||||
|
||||
def construct_loader(dataset: Dataset,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
drop_last: bool = True,
|
||||
pin_memory: bool = False,
|
||||
error_avoidance: bool = False,
|
||||
collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
|
||||
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
|
||||
train_loader = DataLoader(dataset,
|
||||
batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=num_workers,
|
||||
worker_init_fn=worker_init_fn,
|
||||
drop_last=drop_last,
|
||||
persistent_workers=num_workers > 0,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=error_avoidance_collate if error_avoidance else collate_fn)
|
||||
return train_sampler, train_loader
|
||||
Reference in New Issue
Block a user