chore: vendor selva_core from jnwnlee/selva@d7d40a9

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 15:18:09 +02:00
parent 762b19fd3a
commit 6bc3fd6443
106 changed files with 11323 additions and 0 deletions
+338
View File
@@ -0,0 +1,338 @@
""" Embedding Mixup
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
"""
from typing import Literal, Tuple, Union, List, Optional
from functools import partial
import gc
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import v2
from einops import rearrange
from omegaconf import DictConfig
from selva_core.data.vgg_sound import _SYNC_SIZE
class MixupBase:
    """Base class for mixup in either the data or the feature domain.

    Stores the mixing hyper-parameters and provides helpers that draw the
    mixing coefficient ``lam`` once per batch or once per element.
    ``lam == 1`` means "keep the original sample".

    Args:
        generator (Optional[torch.Generator]): Random number generator used for
            the per-element "apply mixup?" draw. A CUDA generator is kept as
            ``generator_cuda`` and mirrored by a CPU generator seeded with its
            initial seed. ``None`` falls back to torch's global RNG.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant lambda in [0., 1.]; used when < 1.
        mixup_alpha (float): Beta(alpha, alpha) concentration (>= 0.); lambda is
            sampled when ``mixup_lambda == 1.`` and ``mixup_alpha > 0.``.
        prob (float): Stored as ``mix_prob``. ``_params_per_elem`` uses it as
            the probability that an element receives a mixing lambda.
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params
            (per 'batch', 'pair' (pair of elements), 'elem' (element),
            'half' (half batch)). Only stored here; interpreted by subclasses.
        eps (float): Lower bound applied to lambda to avoid (near-)zero values.

    Raises:
        ValueError: If ``mixup_lambda`` is outside [0., 1.], ``mixup_alpha`` is
            negative, or the two do not select exactly one lambda source.

    Note:
        Exactly one of ``mixup_lambda < 1.`` and ``mixup_alpha > 0.`` must hold.
        The signature defaults (0.5 and 1.) satisfy both and are therefore
        rejected, so callers have to override at least one of them.
    """

    def __init__(self, generator: Optional[torch.Generator],
                 *,
                 modality: Literal['video', 'audio', 'both'],
                 mixup_lambda: float = 0.5, mixup_alpha: float = 1., prob: float = 1.0,
                 mode: Literal['elem', 'pair', 'batch', 'half'] = 'batch',
                 eps: float = 0.05
                 ):
        self.modality = modality
        self.mixup_lambda: float = mixup_lambda
        self.mixup_alpha: float = mixup_alpha
        self.mix_prob: float = prob
        self.mode: str = mode
        self.eps: float = eps
        # Set to False to disable mixing (intended to be set by the train loop).
        self.mixup_enabled: bool = True

        # Always define the attribute so it can be inspected on every path.
        self.generator_cuda: Optional[torch.Generator] = None
        if generator is not None and generator.device.type == 'cuda':
            # The random draws below produce CPU tensors, so mirror the CUDA
            # generator with a CPU one that shares its initial seed.
            self.generator_cuda = generator
            self.generator = torch.Generator(device='cpu')
            self.generator.manual_seed(generator.initial_seed())
        else:
            # A CPU generator is used directly; None means "global RNG".
            self.generator = generator

        if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
            raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
        if not self.mixup_alpha >= 0.:
            raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
        # Reject both "constant AND sampled" and "neither constant nor sampled".
        if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
            raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

    def _params_per_elem(self, batch_size: int) -> np.ndarray:
        """Draw one lambda per element.

        Args:
            batch_size (int): Number of lambdas to produce.

        Returns:
            np.ndarray: float32 array of shape ``(batch_size,)``. Entries are 1
            where mixup is not applied (always, when mixup is disabled).

        Raises:
            ValueError: If neither a constant nor a sampled lambda is configured.
        """
        lam: np.ndarray = np.ones(batch_size, dtype=np.float32)
        if not self.mixup_enabled:
            return lam

        if self.mixup_lambda < 1.:  # constant lambda
            lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
        elif self.mixup_alpha > 0.:  # sampled lambda
            # NOTE: no generator is passed to the Beta draw, so it follows
            # torch's global RNG rather than self.generator.
            concentration = torch.tensor([self.mixup_alpha])
            lam_mix = torch.distributions.Beta(
                concentration,
                concentration,
            ).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
        else:
            # Only reachable if the attributes were changed after __init__.
            raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

        lam_mix[lam_mix < self.eps] = self.eps
        # Per-element coin flip (driven by self.generator) deciding whether the
        # element receives the mixing lambda or keeps lambda == 1.
        rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
        return np.where(rand_vals < self.mix_prob, lam_mix, lam)

    def _params_per_batch(self) -> float:
        """Draw a single lambda shared by the whole batch.

        Returns:
            float: The lambda, clamped from below by ``eps``; 1. when mixup is
            disabled. ``mix_prob`` is not consulted here.

        Raises:
            ValueError: If neither a constant nor a sampled lambda is configured.
        """
        lam: float = 1.
        if not self.mixup_enabled:
            return lam

        if self.mixup_lambda < 1.:  # constant lambda
            lam = self.mixup_lambda
        elif self.mixup_alpha > 0.:  # sampled lambda (global RNG, see above)
            concentration = torch.tensor([self.mixup_alpha])
            lam = torch.distributions.Beta(
                concentration,
                concentration,
            ).sample().item()
        else:
            # Only reachable if the attributes were changed after __init__.
            raise ValueError(f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

        if lam < self.eps:
            lam = self.eps
        return float(lam)
class DataMixupCollate(MixupBase):
    """Collate callable that applies mixup in the data domain.

    Video mixup resizes each clip and its mirrored partner
    (``batch[batch_size - 1 - i]``) and places them side by side in one frame;
    audio mixup blends the two waveforms linearly. The results are written to
    the ``*_mixed`` keys of every sample before the batch is collated.

    Args:
        generator (Optional[torch.Generator]): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
        mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
        prob (float): Fraction of the batch that is mixed: the first
            ``int(prob * batch_size)`` samples receive the mixed data.
        mode (Literal['elem','pair','batch', 'half']): Only 'batch' is accepted here.
        eps (float): Small epsilon value to avoid zero lambda

    Raises:
        ValueError: If ``mode`` is not 'batch' (or the base-class checks fail).
    """

    def __init__(self, generator: torch.Generator,
                 *,
                 modality: Literal['video', 'audio', 'both'] = 'video',
                 mixup_lambda: float = 0.5, mixup_alpha: float = 1., prob: float = 1.0,
                 mode: Literal['elem', 'pair', 'batch', 'half'] = 'batch',
                 eps: float = 0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        # Per-sample keys read from / written to by the mixing helpers.
        self.source_video_key = 'sync_video'
        self.source_audio_key = 'audio'
        self.target_video_key = 'sync_video_mixed'
        self.target_audio_key = 'audio_mixed'
        if not mode == 'batch':
            raise ValueError(f"Mode {mode} is not supported for data domain.")
        # Applied to the concatenated frames only (not to unmixed samples).
        self.sync_transform = v2.Compose([
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _concat_video_frames(self, batch: list, target_key: str = 'sync_video_mixed',
                             source_key: str = 'sync_video') -> float:
        """Write spatially concatenated clips into ``sample[target_key]``.

        Args:
            batch (list): Samples indexable by ``source_key`` / ``target_key``.
            target_key (str): Key that receives the (possibly mixed) video.
            source_key (str): Key holding the original video.

        Returns:
            float: The lambda drawn for this batch (1. means no mixup).
        """
        # only batch mode supported
        batch_size: int = len(batch)
        lam: float = self._params_per_batch()
        if lam == 1.:
            # no mixup, just pass the source through
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        # lam decides how much of the frame the original clip occupies. Clamp so
        # that neither part collapses to zero pixels (possible with a tiny eps).
        orig_size = min(max(int(lam * _SYNC_SIZE), 1), _SYNC_SIZE - 1)
        # The split direction is currently fixed to horizontal; the random
        # choice is disabled: torch.rand(1, generator=self.generator).item() < 0.5
        is_horizontal = True
        if is_horizontal:
            # Horizontal split: full height, widths sum to _SYNC_SIZE
            resize_shape_orig = (_SYNC_SIZE, orig_size)
            resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE - orig_size)
        else:
            # Vertical split: full width, heights sum to _SYNC_SIZE
            resize_shape_orig = (orig_size, _SYNC_SIZE)
            resize_shape_pair = (_SYNC_SIZE - orig_size, _SYNC_SIZE)
        sync_resize_orig = v2.Compose([
            v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
        ])
        sync_resize_pair = v2.Compose([
            v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
        ])

        # Each sample is paired with its mirror position in the batch.
        # presumably (B, T, C, H, W) after stacking -- TODO confirm against the dataset
        batch_videos_orig = torch.stack([sample[source_key] for sample in batch], dim=0)
        batch_videos_pair = torch.stack([sample[source_key] for sample in reversed(batch)], dim=0)

        # resize both parts, glue them together, then apply the sync transform
        batch_videos_orig = sync_resize_orig(batch_videos_orig)
        batch_videos_pair = sync_resize_pair(batch_videos_pair)
        batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair),
                                        dim=-1 if is_horizontal else -2)
        batch_videos_concat = self.sync_transform(batch_videos_concat)

        # Only the first num_mixup samples get the mixed clip.
        num_mixup = int(self.mix_prob * batch_size)
        for i in range(num_mixup):
            batch[i][target_key] = batch_videos_concat[i]
        for i in range(num_mixup, batch_size):
            batch[i][target_key] = batch[i][source_key]  # no mixup

        # Drop the large intermediates eagerly before collation.
        del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
        gc.collect()
        return lam

    def _mix_audio_samples(self, batch: list, target_key: str = 'audio_mixed', source_key: str = 'audio',
                           normalize: bool = True) -> float:
        """Write linearly blended waveforms into ``sample[target_key]``.

        Args:
            batch (list): Samples indexable by ``source_key`` / ``target_key``.
            target_key (str): Key that receives the (possibly mixed) audio.
            source_key (str): Key holding the original audio (assumed normalized).
            normalize (bool): Rescale the mix to the peak amplitude of the
                original sample.

        Returns:
            float: The lambda drawn for this batch (1. means no mixup).
        """
        batch_size: int = len(batch)
        lam: float = self._params_per_batch()
        if lam == 1.:
            # no mixup, just pass the source through
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        num_mixup = int(self.mix_prob * batch_size)
        for i in range(num_mixup):
            source = batch[i][source_key]
            partner = batch[batch_size - i - 1][source_key]
            mixed = source * lam + partner * (1 - lam)
            if normalize:
                source_abs_max = source.abs().max()
                target_abs_max = mixed.abs().max()
                # A silent mix has peak 0; rescaling would divide by zero and
                # fill the sample with NaN, so it is left as is.
                if target_abs_max > 0:
                    mixed = mixed * (source_abs_max / target_abs_max)
            batch[i][target_key] = mixed
        for i in range(num_mixup, batch_size):
            batch[i][target_key] = batch[i][source_key]  # no mixup
        return lam

    def __call__(self, batch: list, _=None) -> dict:
        """Mix the configured modalities in place and collate the batch.

        Args:
            batch (list): Samples to collate; the batch size must be even.
            _: Unused.

        Returns:
            The output of ``default_collate`` on the updated samples
            (a dict when the samples are dicts).

        Note:
            With ``modality == 'both'`` a separate lambda is drawn for video
            and for audio.
        """
        batch_size: int = len(batch)
        assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
        if self.modality in ('video', 'both'):
            self._concat_video_frames(batch, target_key=self.target_video_key,
                                      source_key=self.source_video_key)
        if self.modality in ('audio', 'both'):
            self._mix_audio_samples(batch, target_key=self.target_audio_key,
                                    source_key=self.source_audio_key)
        return default_collate(batch)
class FeatureMixup(MixupBase):
    """Mixup in the feature domain, applied in place to a collated batch.

    Every element ``i`` is paired with its mirror ``batch_size - 1 - i`` and
    the features stored under the source keys are blended into the target keys.

    Args:
        generator (Optional[torch.Generator]): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
        mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
        prob (float): Probability of applying mixup per element ('elem', 'pair',
            'half') or fraction of the batch that is mixed ('batch').
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params
            (per 'batch', 'pair' (pair of elements), 'elem' (element),
            'half' (per element, first half of the batch only)).
        eps (float): Small epsilon value to avoid zero lambda

    Note:
        The 'elem', 'pair' and 'half' modes assign into ``batch[target_key]``
        by index, so the target tensors must already exist in the batch.
        'batch' mode creates the target tensors itself.
    """

    def __init__(self, generator: torch.Generator,
                 *,
                 modality: Literal['video', 'audio', 'both'] = 'video',
                 mixup_lambda: float = 0.5, mixup_alpha: float = 1., prob: float = 1.0,
                 mode: Literal['elem', 'pair', 'batch', 'half'] = 'batch',
                 eps: float = 0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        # Batch keys read from / written to by __call__.
        self.source_video_key = 'sync_f_vid_orig'
        self.source_audio_key = 'sync_f_aud_orig'
        self.target_video_key = 'sync_f_vid_mixed'
        self.target_audio_key = 'sync_f_aud_mixed'

    @staticmethod
    def _broadcast_lambda(lambdas: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Reshape 1-D ``lambdas`` to ``(n, 1, ..., 1)`` on ``reference``'s device.

        The result broadcasts over every non-batch dimension of ``reference``;
        for 2-D features it is the usual ``(n, 1)`` column.
        """
        trailing = [1] * (reference.dim() - 1)
        return lambdas.to(reference.device).view(-1, *trailing)

    def _mix_elem_collate(self, batch: dict,
                          target_keys: List[str] = ['sync_features_mixed'], source_keys: List[str] = ['sync_features_orig'],
                          half: bool = False) -> torch.tensor:
        """Mix with an independent lambda per element.

        Args:
            batch (dict): Collated batch; ``batch['id']`` gives the batch size.
            target_keys (List[str]): Keys receiving the mixed features (not mutated).
            source_keys (List[str]): Keys holding the original features (not mutated).
            half (bool): If True only the first half of the batch is eligible
                for mixing; the second half keeps the source features.

        Returns:
            torch.Tensor: Lambdas of shape ``(batch_size, 1)``; 1 marks unmixed elements.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size: int = len(batch['id'])
        num_elem: int = batch_size // 2 if half else batch_size
        lam_batch: torch.Tensor = torch.from_numpy(self._params_per_elem(num_elem))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_elem)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1  # mirrored partners
        active_lambdas = lam_batch[mix_mask]

        # Everything that is not actively mixed keeps the source features.
        # (A bitwise `~` on the index tensor would select the mirrored partners
        # instead and undo the mixing, so the complement is built explicitly.)
        keep_mask = torch.ones(batch_size, dtype=torch.bool)
        keep_mask[active_indices] = False
        inactive_indices = torch.nonzero(keep_mask, as_tuple=True)[0]

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            target = batch[target_key]
            lam = self._broadcast_lambda(active_lambdas, source)
            target[active_indices] = (
                source[active_indices] * lam +
                source[active_mix_indices] * (1 - lam)
            )
            target[inactive_indices] = source[inactive_indices]

        if half:
            # The second half is never mixed.
            lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
        return lam_batch.unsqueeze(1)

    def _mix_pair_collate(self, batch: dict,
                          target_keys: List[str] = ['sync_features_mixed'], source_keys: List[str] = ['sync_features_orig']) -> torch.tensor:
        """Mix with one lambda per mirrored pair, applied symmetrically.

        Args:
            batch (dict): Collated batch; ``batch['id']`` gives the batch size.
            target_keys (List[str]): Keys receiving the mixed features (not mutated).
            source_keys (List[str]): Keys holding the original features (not mutated).

        Returns:
            torch.Tensor: Lambdas of shape ``(batch_size, 1)``; both members of
            a pair share the same value, 1 marks unmixed pairs.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size: int = len(batch['id'])
        lam_half: torch.Tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))

        indices = torch.arange(batch_size // 2)
        mix_mask = lam_half < 1
        active_indices = indices[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        # Pairs that were not selected: copy the source through on both sides.
        inactive_indices = indices[~mix_mask]
        inactive_mix_indices = batch_size - inactive_indices - 1
        active_lambdas = lam_half[mix_mask]

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            target = batch[target_key]
            lam = self._broadcast_lambda(active_lambdas, source)
            first = source[active_indices]
            second = source[active_mix_indices]
            target[active_indices] = first * lam + second * (1 - lam)
            target[active_mix_indices] = second * lam + first * (1 - lam)
            target[inactive_indices] = source[inactive_indices]
            target[inactive_mix_indices] = source[inactive_mix_indices]

        # Element batch_size - 1 - i uses the lambda of element i.
        lam_batch = torch.cat((lam_half, lam_half.flip(0)))
        return lam_batch.unsqueeze(1)

    def _mix_batch_collate(self, batch: dict,
                           target_keys: List[str] = ['sync_features_mixed'], source_keys: List[str] = ['sync_features_orig']) -> float:
        """Mix with a single lambda shared by the whole batch.

        The first ``int(prob * batch_size)`` elements are blended with the
        batch-reversed features; the rest keep the source features.

        Args:
            batch (dict): Collated batch.
            target_keys (List[str]): Keys receiving the mixed features (not mutated).
            source_keys (List[str]): Keys holding the original features (not mutated).

        Returns:
            float: The lambda used for this batch.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        lam: float = self._params_per_batch()
        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            num_mixup = int(self.mix_prob * source.shape[0])
            flipped_source = torch.flip(source, dims=[0])
            batch[target_key] = source * lam + flipped_source * (1 - lam)
            batch[target_key][num_mixup:] = source[num_mixup:]  # no mixup
        return lam

    def __call__(self, batch: dict, _=None) -> None:
        """Apply the configured mixup to ``batch`` in place.

        Args:
            batch (dict): Collated batch; ``batch['id']`` gives the (even) batch size.
            _: Unused.

        Raises:
            ValueError: If ``self.modality`` is not 'video', 'audio' or 'both'.
        """
        batch_size: int = len(batch['id'])
        assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
        half = 'half' in self.mode

        # Select the mixing strategy.
        if self.mode == 'elem' or self.mode == 'half':
            collate_fn = partial(self._mix_elem_collate, half=half)
        elif self.mode == 'pair':
            collate_fn = self._mix_pair_collate
        else:
            collate_fn = self._mix_batch_collate

        # Select the feature keys to mix.
        if self.modality == 'both':
            target_keys = [self.target_video_key, self.target_audio_key]
            source_keys = [self.source_video_key, self.source_audio_key]
        elif self.modality == 'video':
            target_keys, source_keys = [self.target_video_key], [self.source_video_key]
        elif self.modality == 'audio':
            target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
        else:
            raise ValueError(f"Unknown modality {self.modality}.")

        # The returned lambda is not propagated; results live in `batch`.
        collate_fn(batch, target_keys=target_keys, source_keys=source_keys)