chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,338 @@
|
||||
""" Embedding Mixup
|
||||
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
||||
"""
|
||||
from typing import Literal, Tuple, Union, List, Optional
|
||||
from functools import partial
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from selva_core.data.vgg_sound import _SYNC_SIZE
|
||||
|
||||
|
||||
class MixupBase:
|
||||
""" Base class for mixup on either data or feature domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both'],
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
self.modality = modality
|
||||
self.mixup_lambda:float = mixup_lambda
|
||||
self.mixup_alpha:float = mixup_alpha
|
||||
self.mix_prob:float = prob
|
||||
self.mode:str = mode
|
||||
self.eps:float = eps
|
||||
self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)
|
||||
if generator.device.type == 'cuda':
|
||||
self.generator_cuda = generator
|
||||
generator_seed = generator.initial_seed()
|
||||
self.generator = torch.Generator(device='cpu')
|
||||
self.generator.manual_seed(generator_seed)
|
||||
else:
|
||||
self.generator = generator
|
||||
|
||||
if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
|
||||
raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
|
||||
if not self.mixup_alpha >= 0.:
|
||||
raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
|
||||
if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
|
||||
raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
|
||||
|
||||
def _params_per_elem(self, batch_size:int) -> np.ndarray:
|
||||
lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
|
||||
if self.mixup_enabled:
|
||||
if self.mixup_lambda < 1.: # constant lambda
|
||||
lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
|
||||
elif self.mixup_alpha > 0.: # sampled lambda
|
||||
# Use torch's beta distribution with generator
|
||||
lam_mix = torch.distributions.Beta(
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
|
||||
else:
|
||||
assert False, f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||
lam_mix[lam_mix < self.eps] = self.eps
|
||||
|
||||
# Use torch's random with generator for the random comparison
|
||||
rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
|
||||
lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
|
||||
return lam
|
||||
|
||||
def _params_per_batch(self) -> float:
|
||||
lam:float = 1.
|
||||
if self.mixup_enabled:
|
||||
if self.mixup_lambda < 1.: # constant lambda
|
||||
lam = self.mixup_lambda
|
||||
elif self.mixup_alpha > 0.: # sampled lambda
|
||||
lam = torch.distributions.Beta(
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
).sample().item()
|
||||
else:
|
||||
assert False, f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||
if lam < self.eps: lam = self.eps
|
||||
lam = float(lam)
|
||||
return lam
|
||||
|
||||
|
||||
class DataMixupCollate(MixupBase):
|
||||
""" Mixup video in data domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both']='video',
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
super().__init__(generator, modality=modality,
|
||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||
mode=mode, eps=eps)
|
||||
|
||||
self.source_video_key= 'sync_video'
|
||||
self.source_audio_key = 'audio'
|
||||
self.target_video_key = 'sync_video_mixed'
|
||||
self.target_audio_key = 'audio_mixed'
|
||||
|
||||
if not mode == 'batch':
|
||||
raise ValueError(f"Mode {mode} is not supported for data domain.")
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
|
||||
# only batch mode supported
|
||||
batch_size:int = len(batch)
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
if lam == 1.:
|
||||
# no mixup, just return
|
||||
for i in range(batch_size):
|
||||
batch[i][target_key] = batch[i][source_key]
|
||||
return lam
|
||||
|
||||
# Randomly choose between horizontal and vertical resizing using
|
||||
orig_size = int(lam * _SYNC_SIZE)
|
||||
is_horizontal = True # torch.rand(1, generator=self.generator).item() < 0.5
|
||||
if is_horizontal:
|
||||
# Horizontal resize
|
||||
resize_shape_orig = (_SYNC_SIZE, orig_size)
|
||||
resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
|
||||
else:
|
||||
# Vertical resize
|
||||
resize_shape_orig = (orig_size, _SYNC_SIZE)
|
||||
resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
|
||||
sync_resize_orig = v2.Compose([
|
||||
v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
])
|
||||
sync_resize_pair = v2.Compose([
|
||||
v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
])
|
||||
|
||||
batch_videos_orig = torch.stack([batch[i][source_key] for i in range(batch_size)], dim=0)
|
||||
batch_videos_pair = torch.stack([batch[batch_size - i - 1][source_key] for i in range(batch_size)], dim=0)
|
||||
# (B, T, C, H, W)
|
||||
# pass through resize, transform and concat
|
||||
batch_videos_orig = sync_resize_orig(batch_videos_orig)
|
||||
batch_videos_pair = sync_resize_pair(batch_videos_pair)
|
||||
batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
|
||||
batch_videos_concat = self.sync_transform(batch_videos_concat)
|
||||
|
||||
num_mixup = int(self.mix_prob * batch_size)
|
||||
for i in range(num_mixup):
|
||||
batch[i][target_key] = batch_videos_concat[i]
|
||||
for i in range(num_mixup, batch_size):
|
||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||
|
||||
del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
|
||||
gc.collect()
|
||||
|
||||
return lam
|
||||
|
||||
def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
|
||||
normalize:bool = True) -> float:
|
||||
# assume source_key audios are normalized
|
||||
batch_size:int = len(batch)
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
if lam == 1.:
|
||||
# no mixup, just return
|
||||
for i in range(batch_size):
|
||||
batch[i][target_key] = batch[i][source_key]
|
||||
return lam
|
||||
|
||||
num_mixup = int(self.mix_prob * batch_size)
|
||||
for i in range(num_mixup):
|
||||
batch[i][target_key] = batch[i][source_key] * lam + batch[batch_size - i - 1][source_key] * (1 - lam)
|
||||
if normalize:
|
||||
source_abs_max = batch[i][source_key].abs().max()
|
||||
target_abs_max = batch[i][target_key].abs().max()
|
||||
batch[i][target_key] = batch[i][target_key] * (source_abs_max / target_abs_max)
|
||||
for i in range(num_mixup, batch_size):
|
||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||
|
||||
return lam
|
||||
|
||||
def __call__(self, batch:list, _=None) -> torch.tensor:
|
||||
batch_size:int = len(batch)
|
||||
assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
|
||||
half = 'half' in self.mode
|
||||
if half:
|
||||
batch_size //= 2
|
||||
|
||||
if self.modality == 'video' or self.modality == 'both':
|
||||
lam = self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
|
||||
if self.modality == 'audio' or self.modality == 'both':
|
||||
# raise NotImplementedError('Audio mixup is not implemented yet.')
|
||||
lam = self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)
|
||||
|
||||
return default_collate(batch)
|
||||
|
||||
|
||||
class FeatureMixup(MixupBase):
|
||||
""" Mixup video in feature domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both']='video',
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
super().__init__(generator, modality=modality,
|
||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||
mode=mode, eps=eps)
|
||||
self.source_video_key= 'sync_f_vid_orig'
|
||||
self.source_audio_key = 'sync_f_aud_orig'
|
||||
self.target_video_key = 'sync_f_vid_mixed'
|
||||
self.target_audio_key = 'sync_f_aud_mixed'
|
||||
|
||||
def _mix_elem_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
|
||||
half:bool=False) -> torch.tensor:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
batch_size:int = len(batch['id'])
|
||||
num_elem:int = batch_size // 2 if half else batch_size
|
||||
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))
|
||||
|
||||
indices = torch.arange(num_elem)
|
||||
mix_indices = batch_size - indices - 1
|
||||
mix_mask = lam_batch < 1
|
||||
active_indices = indices[mix_mask]
|
||||
active_mix_indices = mix_indices[mix_mask]
|
||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
batch[target_key][active_indices] = (
|
||||
batch[source_key][active_indices] * active_lambdas +
|
||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
||||
if half:
|
||||
lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
|
||||
return lam_batch.unsqueeze(1)
|
||||
|
||||
def _mix_pair_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
batch_size:int = len(batch['id'])
|
||||
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))
|
||||
|
||||
indices = torch.arange(batch_size // 2)
|
||||
mix_indices = batch_size - indices - 1
|
||||
mix_mask = lam_batch < 1
|
||||
active_indices = indices[mix_mask]
|
||||
active_mix_indices = mix_indices[mix_mask]
|
||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
batch[target_key][active_indices] = (
|
||||
batch[source_key][active_indices] * active_lambdas +
|
||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][active_mix_indices] = (
|
||||
batch[source_key][active_mix_indices] * active_lambdas +
|
||||
batch[source_key][active_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
||||
batch[target_key][~mix_indices[mix_mask]] = batch[source_key][~mix_indices[mix_mask]]
|
||||
lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
|
||||
return lam_batch.unsqueeze(1)
|
||||
|
||||
def _mix_batch_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
num_mixup = int(self.mix_prob * batch[source_key].shape[0])
|
||||
flipped_source = torch.flip(batch[source_key], dims=[0])
|
||||
batch[target_key] = batch[source_key] * lam + flipped_source * (1 - lam)
|
||||
batch[target_key][num_mixup:] = batch[source_key][num_mixup:] # no mixup
|
||||
return lam
|
||||
|
||||
def __call__(self, batch:dict, _=None) -> None:
|
||||
batch_size:int = len(batch['id'])
|
||||
assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
|
||||
half = 'half' in self.mode
|
||||
if half:
|
||||
batch_size //= 2
|
||||
|
||||
# Mixup
|
||||
if self.mode == 'elem' or self.mode == 'half':
|
||||
collate_fn = partial(self._mix_elem_collate, half=half)
|
||||
elif self.mode == 'pair':
|
||||
collate_fn = self._mix_pair_collate
|
||||
else:
|
||||
collate_fn = self._mix_batch_collate
|
||||
|
||||
if self.modality == 'both':
|
||||
target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
|
||||
elif self.modality == 'video':
|
||||
target_keys, source_keys = [self.target_video_key], [self.source_video_key]
|
||||
elif self.modality == 'audio':
|
||||
target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
|
||||
lam = collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
||||
|
||||
# return batch
|
||||
|
||||
Reference in New Issue
Block a user