6bc3fd6443
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
339 lines
16 KiB
Python
339 lines
16 KiB
Python
""" Embedding Mixup
|
|
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
|
"""
|
|
from typing import Literal, Tuple, Union, List, Optional
|
|
from functools import partial
|
|
import gc
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data.dataloader import default_collate
|
|
from torchvision.transforms import v2
|
|
from einops import rearrange
|
|
from omegaconf import DictConfig
|
|
|
|
from selva_core.data.vgg_sound import _SYNC_SIZE
|
|
|
|
|
|
class MixupBase:
    """ Base class for mixup on either data or feature domain.

    Holds the mixup hyper-parameters and draws the mixing coefficient
    (lambda) either once per batch or once per element.

    Exactly one of the two lambda sources must be selected:
      * constant lambda: ``mixup_lambda < 1.`` together with ``mixup_alpha == 0.``
      * sampled lambda:  ``mixup_lambda == 1.`` together with ``mixup_alpha > 0.``
        (lambda ~ Beta(mixup_alpha, mixup_alpha))
    Any other combination raises ``ValueError``. Note that the signature
    defaults (``mixup_lambda=0.5``, ``mixup_alpha=1.``) select both sources at
    once and are therefore rejected; callers have to override one of them.

    Args:
        generator (Optional[torch.Generator]): Random number generator used for the
            per-element "apply mixup?" draw. A CUDA generator is kept as
            ``generator_cuda`` and replaced by a CPU generator seeded with its
            initial seed. ``None`` falls back to torch's global RNG.
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda, must be in [0., 1.]; 1. means "not constant".
        mixup_alpha (float): Beta distribution parameter, must be >= 0.; 0. means "not sampled".
        prob (float): Probability of applying mixup per batch or element
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
        eps (float): Lower bound applied to lambda to avoid (near-)zero values

    Raises:
        ValueError: if the lambda/alpha configuration is out of range or ambiguous.
    """
    def __init__(self, generator:Optional[torch.Generator],
                 *,
                 modality:Literal['video', 'audio', 'both'],
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        self.modality = modality
        self.mixup_lambda:float = mixup_lambda
        self.mixup_alpha:float = mixup_alpha
        self.mix_prob:float = prob
        self.mode:str = mode
        self.eps:float = eps
        self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)

        # Always define the attribute so that reading it never raises AttributeError.
        self.generator_cuda:Optional[torch.Generator] = None
        if generator is not None and generator.device.type == 'cuda':
            # torch.rand below produces CPU tensors (they are converted to numpy),
            # so mirror the CUDA generator's seed into a CPU generator.
            self.generator_cuda = generator
            self.generator = torch.Generator(device='cpu')
            self.generator.manual_seed(generator.initial_seed())
        else:
            # CPU generator, or None (torch.rand then uses the global RNG)
            self.generator = generator

        if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
            raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
        if not self.mixup_alpha >= 0.:
            raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
        samples_lambda = self.mixup_alpha > 0.
        constant_lambda = self.mixup_lambda < 1.
        if samples_lambda == constant_lambda:
            # both or neither selected -> ambiguous / inactive configuration
            raise ValueError(f"Exactly one of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")

    def _params_per_elem(self, batch_size:int) -> np.ndarray:
        """ Draw one lambda per element.

        Returns:
            np.ndarray: float32 array of shape (batch_size,). Entries are 1. where
            mixup is not applied (mixup disabled, or the per-element draw is not
            below ``mix_prob``); otherwise the constant/sampled lambda floored at ``eps``.
        """
        lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
        if not self.mixup_enabled:
            return lam

        if self.mixup_lambda < 1.: # constant lambda
            lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
        elif self.mixup_alpha > 0.: # sampled lambda
            # NOTE: no generator is passed to the Beta draw, so it follows
            # torch's global RNG state, not self.generator.
            alpha = torch.tensor([self.mixup_alpha])
            beta_dist = torch.distributions.Beta(alpha, alpha)
            lam_mix = beta_dist.sample([batch_size]).numpy().astype(np.float32).reshape(-1)
        else:
            # excluded by the checks in __init__; explicit raise survives `python -O`
            raise AssertionError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
        lam_mix[lam_mix < self.eps] = self.eps

        # Per-element decision whether mixup is applied, driven by self.generator
        rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
        return np.where(rand_vals < self.mix_prob, lam_mix, lam)

    def _params_per_batch(self) -> float:
        """ Draw a single lambda shared by the whole batch.

        ``mix_prob`` is not consulted here; callers decide how many elements
        of the batch are actually mixed.

        Returns:
            float: 1. when mixup is disabled, otherwise the constant/sampled
            lambda floored at ``eps``.
        """
        if not self.mixup_enabled:
            return 1.

        if self.mixup_lambda < 1.: # constant lambda
            lam = self.mixup_lambda
        elif self.mixup_alpha > 0.: # sampled lambda
            # NOTE: drawn from torch's global RNG (no generator is passed).
            alpha = torch.tensor([self.mixup_alpha])
            lam = torch.distributions.Beta(alpha, alpha).sample().item()
        else:
            # excluded by the checks in __init__; explicit raise survives `python -O`
            raise AssertionError(f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
        if lam < self.eps:
            lam = self.eps
        return float(lam)
|
|
|
|
|
|
class DataMixupCollate(MixupBase):
    """ Collate function applying mixup in the data domain.

    Video: each clip is resized to a vertical strip and concatenated
    side-by-side with the strip of its mirror-indexed partner
    (element ``i`` is paired with element ``batch_size - i - 1``).
    Audio: waveforms of the same pairs are linearly interpolated.
    Only ``mode='batch'`` (one lambda per batch) is supported.

    Args:
        generator (torch.Generator): Random number generator, forwarded to MixupBase
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda, forwarded to MixupBase.
        mixup_alpha (float): Beta distribution parameter, forwarded to MixupBase.
        prob (float): Fraction of the batch (leading elements) that gets mixed
        mode (Literal['elem','pair','batch', 'half']): Must be 'batch' for the data domain
        eps (float): Small epsilon value to avoid zero lambda

    Raises:
        ValueError: if ``mode`` is not 'batch' (or MixupBase rejects the parameters).
    """
    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)

        # keys read from / written to every sample dict of the batch
        self.source_video_key = 'sync_video'
        self.source_audio_key = 'audio'
        self.target_video_key = 'sync_video_mixed'
        self.target_audio_key = 'audio_mixed'

        if not mode == 'batch':
            raise ValueError(f"Mode {mode} is not supported for data domain.")
        # applied to the concatenated (mixed) frames only
        self.sync_transform = v2.Compose([
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
        """ Write spatially concatenated videos to ``target_key`` of every sample.

        The first ``int(mix_prob * batch_size)`` samples receive the mixed and
        transformed video; the remaining ones (and all samples when lambda == 1.)
        receive their untouched source video.

        Returns:
            float: the lambda drawn for this batch.
        """
        # only batch mode supported
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()

        if lam == 1.:
            # no mixup, just return
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        # Share of the frame (along the split axis) kept for the original clip.
        # Clamped so neither strip collapses to zero pixels (possible when eps is 0).
        orig_size = min(max(int(lam * _SYNC_SIZE), 1), _SYNC_SIZE - 1)
        # The split direction is currently fixed to horizontal; the vertical
        # branch is kept for when a random choice gets enabled.
        is_horizontal = True # torch.rand(1, generator=self.generator).item() < 0.5
        if is_horizontal:
            # Horizontal resize: full height, widths add up to _SYNC_SIZE
            resize_shape_orig = (_SYNC_SIZE, orig_size)
            resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE - orig_size)
            concat_dim = -1
        else:
            # Vertical resize: full width, heights add up to _SYNC_SIZE
            resize_shape_orig = (orig_size, _SYNC_SIZE)
            resize_shape_pair = (_SYNC_SIZE - orig_size, _SYNC_SIZE)
            concat_dim = -2
        sync_resize_orig = v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC)
        sync_resize_pair = v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC)

        batch_videos_orig = torch.stack([sample[source_key] for sample in batch], dim=0)
        # partner of element i is element batch_size - i - 1
        batch_videos_pair = batch_videos_orig.flip(0)
        # (B, T, C, H, W) per the original author's note -- TODO confirm against the dataset
        # pass through resize, transform and concat
        batch_videos_orig = sync_resize_orig(batch_videos_orig)
        batch_videos_pair = sync_resize_pair(batch_videos_pair)
        batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=concat_dim)
        batch_videos_concat = self.sync_transform(batch_videos_concat)

        num_mixup = int(self.mix_prob * batch_size)
        for i, sample in enumerate(batch):
            if i < num_mixup:
                sample[target_key] = batch_videos_concat[i]
            else:
                sample[target_key] = sample[source_key] # no mixup

        # drop the large intermediates eagerly
        del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
        gc.collect()

        return lam

    def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
                           normalize:bool = True) -> float:
        """ Write linearly mixed waveforms to ``target_key`` of every sample.

        The first ``int(mix_prob * batch_size)`` samples are mixed with their
        mirror-indexed partner; the rest keep their source audio.

        Args:
            normalize (bool): rescale the mix so that its peak amplitude equals
                the peak of the sample's own source audio. Skipped when the mix
                is all zeros, which would otherwise divide by zero and yield NaNs.

        Returns:
            float: the lambda drawn for this batch.
        """
        # assume source_key audios are normalized
        batch_size:int = len(batch)
        lam:float = self._params_per_batch()

        if lam == 1.:
            # no mixup, just return
            for sample in batch:
                sample[target_key] = sample[source_key]
            return lam

        num_mixup = int(self.mix_prob * batch_size)
        for i, sample in enumerate(batch):
            own = sample[source_key]
            if i >= num_mixup:
                sample[target_key] = own # no mixup
                continue
            partner = batch[batch_size - i - 1][source_key]
            mixed = own * lam + partner * (1 - lam)
            if normalize:
                mixed_peak = mixed.abs().max()
                if mixed_peak > 0:
                    mixed = mixed * (own.abs().max() / mixed_peak)
            sample[target_key] = mixed

        return lam

    def __call__(self, batch:list, _=None) -> dict:
        """ Mix the selected modalities in place, then collate the samples.

        With ``modality='both'`` video and audio each draw their own lambda.

        Args:
            batch (list): list of sample dicts; its length must be even.
            _: ignored.

        Returns:
            The result of ``default_collate(batch)``.
        """
        batch_size:int = len(batch)
        assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'

        if self.modality in ('video', 'both'):
            self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
        if self.modality in ('audio', 'both'):
            self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)

        return default_collate(batch)
|
|
|
|
|
|
class FeatureMixup(MixupBase):
    """ Mixup in the feature domain, operating in place on a collated batch dict.

    Element ``i`` is always paired with its mirror element
    ``batch_size - i - 1``. Depending on ``mode``:
      * 'batch': one lambda for the whole batch
      * 'elem':  an independent lambda for every element
      * 'half':  like 'elem', but only the first half of the batch is mixed
      * 'pair':  one lambda per (i, batch_size - i - 1) pair, applied to both

    Args:
        generator (torch.Generator): Random number generator, forwarded to MixupBase
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda, forwarded to MixupBase.
        mixup_alpha (float): Beta distribution parameter, forwarded to MixupBase.
        prob (float): Probability of applying mixup per batch or element
        mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (see above)
        eps (float): Small epsilon value to avoid zero lambda
    """
    def __init__(self, generator:torch.Generator,
                 *,
                 modality:Literal['video', 'audio', 'both']='video',
                 mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
                 mode:Literal['elem','pair','batch', 'half']='batch',
                 eps:float=0.05
                 ):
        super().__init__(generator, modality=modality,
                         mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
                         mode=mode, eps=eps)
        # keys read from / written to the batch dict
        self.source_video_key = 'sync_f_vid_orig'
        self.source_audio_key = 'sync_f_aud_orig'
        self.target_video_key = 'sync_f_vid_mixed'
        self.target_audio_key = 'sync_f_aud_mixed'

    @staticmethod
    def _broadcast_lambdas(lambdas:torch.Tensor, features:torch.Tensor) -> torch.Tensor:
        """ Reshape (k,) lambdas to (k, 1, ..., 1) so they broadcast over ``features`` rows. """
        trailing = [1] * max(features.dim() - 1, 0)
        return lambdas.reshape(-1, *trailing).to(device=features.device)

    @staticmethod
    def _reset_target(batch:dict, target_key:str, source:torch.Tensor) -> torch.Tensor:
        """ Make ``batch[target_key]`` a copy of ``source`` (reusing an existing tensor) and return it. """
        if target_key in batch:
            target = batch[target_key]
            if target is not source:
                target.copy_(source)
        else:
            target = source.clone()
            batch[target_key] = target
        return target

    def _mix_elem_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
                          half:bool=False) -> torch.tensor:
        """ Mix with an independent lambda per element.

        Rows whose lambda is < 1 become ``lam * own + (1 - lam) * partner``;
        every other row of the target equals the source row.

        Args:
            half (bool): only draw lambdas for (and mix) the first half of the batch.

        Returns:
            torch.Tensor: lambdas of shape (batch_size, 1); 1. for un-mixed rows.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        num_elem:int = batch_size // 2 if half else batch_size
        lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_elem)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            lam = self._broadcast_lambdas(active_lambdas, source)
            # compute from the pristine source before touching the target
            mixed = source[active_indices] * lam + source[active_mix_indices] * (1 - lam)
            # Un-mixed rows must equal the source. (The previous `~indices[mask]`
            # was a bitwise NOT on integer indices, i.e. it addressed the mix
            # partners and could overwrite freshly mixed rows.)
            target = self._reset_target(batch, target_key, source)
            target[active_indices] = mixed
        if half:
            # the second half is never mixed
            lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
        return lam_batch.unsqueeze(1)

    def _mix_pair_collate(self, batch:dict,
                          target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
        """ Mix with one lambda per (i, batch_size - i - 1) pair.

        For an active pair both members are mixed symmetrically with the same
        lambda (each keeps ``lam`` of itself); inactive pairs copy the source.

        Returns:
            torch.Tensor: lambdas of shape (batch_size, 1), mirrored so that both
            members of a pair carry the same value.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        batch_size:int = len(batch['id'])
        num_pairs:int = batch_size // 2
        lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_pairs))

        mix_mask = lam_batch < 1
        active_indices = torch.arange(num_pairs)[mix_mask]
        active_mix_indices = batch_size - active_indices - 1
        active_lambdas = lam_batch[mix_mask]

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            lam = self._broadcast_lambdas(active_lambdas, source)
            first = source[active_indices]
            second = source[active_mix_indices]
            mixed_first = first * lam + second * (1 - lam)
            mixed_second = second * lam + first * (1 - lam)
            # Un-mixed rows must equal the source. (The previous `~indices[mask]` /
            # `~mix_indices[mask]` were bitwise NOTs on integer indices and
            # overwrote exactly the rows that had just been mixed.)
            target = self._reset_target(batch, target_key, source)
            target[active_indices] = mixed_first
            target[active_mix_indices] = mixed_second
        lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
        return lam_batch.unsqueeze(1)

    def _mix_batch_collate(self, batch:dict,
                           target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
        """ Mix the leading ``int(mix_prob * batch_size)`` rows with one shared lambda.

        ``batch[target_key]`` is replaced by a new tensor; rows beyond the mixed
        prefix equal the source rows.

        Returns:
            float: the lambda drawn for this batch.
        """
        assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
        lam:float = self._params_per_batch()

        for target_key, source_key in zip(target_keys, source_keys):
            source = batch[source_key]
            num_mixup = int(self.mix_prob * source.shape[0])
            # flipping along dim 0 pairs row i with row batch_size - i - 1
            mixed = source * lam + torch.flip(source, dims=[0]) * (1 - lam)
            mixed[num_mixup:] = source[num_mixup:] # no mixup
            batch[target_key] = mixed
        return lam

    def __call__(self, batch:dict, _=None) -> None:
        """ Apply feature mixup to ``batch`` in place.

        Args:
            batch (dict): collated batch; ``batch['id']`` defines the batch size
                (must be even) and the configured source keys must be present.
            _: ignored.

        Raises:
            ValueError: if ``self.modality`` is not 'video', 'audio' or 'both'.
        """
        batch_size:int = len(batch['id'])
        assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'

        # Mixup strategy
        if self.mode == 'elem' or self.mode == 'half':
            collate_fn = partial(self._mix_elem_collate, half='half' in self.mode)
        elif self.mode == 'pair':
            collate_fn = self._mix_pair_collate
        else:
            collate_fn = self._mix_batch_collate

        if self.modality == 'both':
            target_keys = [self.target_video_key, self.target_audio_key]
            source_keys = [self.source_video_key, self.source_audio_key]
        elif self.modality == 'video':
            target_keys, source_keys = [self.target_video_key], [self.source_video_key]
        elif self.modality == 'audio':
            target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
        else:
            raise ValueError(f"Unknown modality {self.modality!r}; expected 'video', 'audio' or 'both'.")

        # The batch is modified in place; the drawn lambda is not returned.
        collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
|
|