chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
from logging import Logger
|
||||
|
||||
from selva_core.utils.logger import TensorboardLogger
|
||||
|
||||
# Process rank and process count, read from the LOCAL_RANK / WORLD_SIZE
# environment variables; they fall back to 0 and 1 when the variable is unset.
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
|
||||
|
||||
def info_if_rank_zero(logger: Logger, msg: str):
    """Log `msg` at INFO level through `logger`, but only when `local_rank` is 0.

    On any other rank the call does nothing.
    """
    if local_rank != 0:
        return
    logger.info(msg)
|
||||
|
||||
|
||||
def string_if_rank_zero(logger: TensorboardLogger, tag: str, msg: str):
    """Forward `(tag, msg)` to `logger.log_string`, but only when `local_rank` is 0.

    On any other rank the call does nothing.
    """
    if local_rank != 0:
        return
    logger.log_string(tag, msg)
|
||||
@@ -0,0 +1,109 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
# Root logger (getLogger() called without a name); used below to announce downloads.
log = logging.getLogger()
|
||||
|
||||
# Table of downloadable checkpoints. Each entry carries the file name that
# download_model_if_needed() matches against `model_path.name`, the URL the
# file is fetched from, and the MD5 hex digest an existing local copy is
# compared with.
links = [
    {
        'name': 'video_enc_sup_5.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/video_enc_sup_5.pth',
        'md5': 'ff09a6dc36148536ee4db97eba081d05'
    },
    {
        'name': 'generator_small_16k_sup_5.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_small_16k_sup_5.pth',
        'md5': '1cb0f0deec52de37f67b1fd9965337d0'
    },
    {
        'name': 'generator_small_44k_sup_5.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_small_44k_sup_5.pth',
        'md5': 'd4df8569624093ac80af99b8b7434525'
    },
    {
        'name': 'generator_medium_44k_sup_5.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_medium_44k_sup_5.pth',
        'md5': 'e9157e62b4863ad306e89e8f3a587748'
    },
    {
        'name': 'generator_large_44k_sup_5.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_large_44k_sup_5.pth',
        'md5': 'ab3db08b124d3aaa53eb7a1f52f1fb3f'
    },
    {
        'name': 'v1-16.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/v1-16.pth',
        'md5': '69f56803f59a549a1a507c93859fd4d7'
    },
    {
        'name': 'best_netG.pt',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/best_netG.pt',
        'md5': 'eeaf372a38a9c31c362120aba2dde292'
    },
    {
        'name': 'v1-44.pth',
        'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/v1-44.pth',
        'md5': 'fab020275fa44c6589820ce025191600'
    },
    {
        'name': 'synchformer_state_dict.pth',
        'url':
        'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/synchformer_state_dict.pth',
        'md5': '5b2f5594b0730f70e41e549b7c94390c'
    },
    {
        'name': 'mmaudio_small_16k.pth',
        'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth',
        'md5': 'af93cde404179f58e3919ac085b8033b',
    },
    {
        'name': 'mmaudio_small_44k.pth',
        'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth',
        'md5': 'babd74c884783d13701ea2820a5f5b6d',
    },
    {
        'name': 'mmaudio_medium_44k.pth',
        'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth',
        'md5': '5a56b6665e45a1e65ada534defa903d0',
    },
    {
        'name': 'mmaudio_large_44k.pth',
        'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth',
        'md5': 'fed96c325a6785b85ce75ae1aafd2673'
    },
    {
        'name': 'mmaudio_large_44k_v2.pth',
        'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth',
        'md5': '01ad4464f049b2d7efdaa4c1a59b8dfe'
    },
]
|
||||
|
||||
|
||||
def download_model_if_needed(model_path: Path):
    """Make sure the checkpoint at `model_path` exists and matches its known MD5.

    The file name of `model_path` is looked up in `links`. When the file is
    missing, or its MD5 differs from the recorded digest, it is downloaded
    from the entry's URL and written to `model_path` (parent directories are
    created as needed).

    Args:
        model_path: destination of the checkpoint; only its file name is used
            for the lookup.

    Raises:
        ValueError: no entry in `links` has this file name.
        requests.HTTPError: the server answered with an error status.
        RuntimeError: the number of bytes received differs from the
            Content-Length announced by the server.
    """
    base_name = model_path.name
    target_link = next((link for link in links if link['name'] == base_name), None)
    if target_link is None:
        raise ValueError(f'No link found for {base_name}')

    model_path.parent.mkdir(parents=True, exist_ok=True)

    if model_path.exists():
        # Hash in chunks through a context manager: checkpoints can be several
        # gigabytes, so the file must not be read into memory at once, and the
        # handle must be closed again.
        digest = hashlib.md5()
        with open(model_path, 'rb') as existing:
            for chunk in iter(lambda: existing.read(1 << 20), b''):
                digest.update(chunk)
        if digest.hexdigest() == target_link['md5']:
            return

    log.info(f'Downloading {base_name} to {model_path}...')
    block_size = 1024 * 1024
    with requests.get(target_link['url'], stream=True) as r:
        # Fail before touching the destination; otherwise an HTML error page
        # would be saved under the checkpoint's name.
        r.raise_for_status()
        total_size = int(r.headers.get('content-length', 0))
        with tqdm(total=total_size, unit='iB', unit_scale=True) as t:
            with open(model_path, 'wb') as f:
                for data in r.iter_content(block_size):
                    t.update(len(data))
                    f.write(data)
            received = t.n

    # A total of 0 means the server sent no Content-Length; nothing to compare.
    if total_size != 0 and received != total_size:
        raise RuntimeError('Error while downloading %s' % base_name)
|
||||
@@ -0,0 +1,50 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from pytz import timezone
|
||||
|
||||
from selva_core.utils.timezone import my_timezone
|
||||
|
||||
# `_source` is interpolated into the Mailgun API URL and the sender address,
# `_target` is the recipient address (see EmailSender.send).
# NOTE(review): 'USE YOURS' looks like a placeholder that has to be replaced
# with a real domain / address before emails can be delivered - confirm.
_source = 'USE YOURS'
_target = 'USE YOURS'

# Root logger; used for the missing-key warning and for send failures.
log = logging.getLogger()

# strftime pattern of the timestamp appended to every email body.
_fmt = "%Y-%m-%d %H:%M:%S %Z%z"
|
||||
|
||||
|
||||
class EmailSender:
    """Sends notification emails for one experiment through the Mailgun HTTP API.

    The sender disables itself when constructed with `enable=False` or when
    the MAILGUN_API_KEY environment variable is not set; a disabled sender
    turns `send` into a no-op.
    """

    def __init__(self, exp_id: str, enable: bool):
        """Store the experiment id and, if enabled, load the API key.

        Args:
            exp_id: experiment identifier, prefixed to every subject line.
            enable: whether sending is requested at all.
        """
        self.exp_id = exp_id
        self.enable = enable
        if not enable:
            return

        # Pull variables from a .env file into the environment before the lookup.
        load_dotenv()
        self.MAILGUN_API_KEY = os.getenv('MAILGUN_API_KEY')
        if self.MAILGUN_API_KEY is None:
            log.warning('MAILGUN_API_KEY is not set')
            self.enable = False

    def send(self, subject, content):
        """Send one email; failures are logged instead of raised.

        Args:
            subject: subject text (converted with `str`).
            content: body text (converted with `str`).

        Returns:
            The `requests` response of the POST, or None when the sender is
            disabled or the request failed.
        """
        if not self.enable:
            return None

        subject = str(subject)
        content = str(content)
        try:
            timestamp = datetime.now(timezone(my_timezone)).strftime(_fmt)
            payload = {
                'from': f'<agent name>🤖 <mailgun@{_source}>',
                'to': [f'{_target}'],
                'subject': f'[{self.exp_id}] {subject}',
                'text': '\n\n' + content + '\n\n<your sign off>\n' + timestamp,
            }
            return requests.post(f'https://api.mailgun.net/v3/{_source}/messages',
                                 auth=('api', self.MAILGUN_API_KEY),
                                 data=payload,
                                 timeout=20)
        except Exception as e:
            # Notifications are best effort: never let them break the caller.
            log.error(f'Failed to send email: {e}')
|
||||
@@ -0,0 +1,277 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colorlog import ColoredFormatter
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from selva_core.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
from selva_core.model.networks_video_enc import TextSynch
|
||||
from selva_core.model.networks_generator import MMAudio
|
||||
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
|
||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||
from selva_core.utils.download_utils import download_model_if_needed
|
||||
|
||||
# Root logger; load_video() reports truncated clips through it.
log = logging.getLogger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ModelConfig:
    """Checkpoint locations and audio mode of one model variant.

    `mode` is either '16k' or '44k' and selects the sequence configuration.
    `bigvgan_16k_path` may be None, in which case that checkpoint is never
    downloaded.
    """
    model_name: str
    model_video_enc_path: Path
    model_generator_path: Path
    mode: str
    vae_path: Path
    bigvgan_16k_path: Optional[Path]
    synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')

    @property
    def seq_cfg(self) -> SequenceConfig:
        """Sequence configuration matching `mode`.

        Raises:
            ValueError: `mode` is neither '16k' nor '44k'. (Previously the
                property fell through and returned None, which only surfaced
                later as an unrelated AttributeError.)
        """
        if self.mode == '16k':
            return CONFIG_16K
        if self.mode == '44k':
            return CONFIG_44K
        raise ValueError(f"Unknown mode {self.mode!r}; expected '16k' or '44k'")

    def download_if_needed(self):
        """Fetch every checkpoint this configuration refers to, if needed."""
        download_model_if_needed(self.model_video_enc_path)
        download_model_if_needed(self.model_generator_path)
        download_model_if_needed(self.vae_path)
        if self.bigvgan_16k_path is not None:
            download_model_if_needed(self.bigvgan_16k_path)
        download_model_if_needed(self.synchformer_ckpt)

    def download_video_enc_if_needed(self):
        """Fetch only the video-encoder checkpoint, if needed."""
        download_model_if_needed(self.model_video_enc_path)

    def download_generator_if_needed(self):
        """Fetch only the generator checkpoint, if needed."""
        download_model_if_needed(self.model_generator_path)

    def download_external_modules_if_needed(self):
        """Fetch the synchformer, VAE and (when configured) BigVGAN checkpoints."""
        download_model_if_needed(self.synchformer_ckpt)
        download_model_if_needed(self.vae_path)
        if self.bigvgan_16k_path is not None:
            download_model_if_needed(self.bigvgan_16k_path)
|
||||
|
||||
|
||||
# Predefined model variants. All of them share the same video-encoder
# checkpoint; only the 16k variant configures a BigVGAN checkpoint.
small_16k = ModelConfig(model_name='small_16k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_small_16k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-16.pth'),
                        bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
                        mode='16k')
small_44k = ModelConfig(model_name='small_44k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_small_44k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
medium_44k = ModelConfig(model_name='medium_44k',
                         model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                         model_generator_path=Path('./weights/generator_medium_44k_sup_5.pth'),
                         vae_path=Path('./ext_weights/v1-44.pth'),
                         bigvgan_16k_path=None,
                         mode='44k')
large_44k = ModelConfig(model_name='large_44k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_large_44k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
# Lookup of the variants above by their `model_name`.
all_model_cfg: dict[str, ModelConfig] = {
    'small_16k': small_16k,
    'small_44k': small_44k,
    'medium_44k': medium_44k,
    'large_44k': large_44k,
}
|
||||
|
||||
|
||||
def generate(
    clip_video: Optional[torch.Tensor],
    sync_video: Optional[torch.Tensor],
    text: Optional[list[str]],
    *,
    negative_text: Optional[list[str]] = None,
    feature_utils: FeaturesUtils,
    net_video_enc: TextSynch,
    net_generator: MMAudio,
    fm: FlowMatching,
    rng: torch.Generator,
    cfg_strength: float,
    clip_batch_size_multiplier: int = 40,
    sync_batch_size_multiplier: int = 40,
    image_input: bool = False,
) -> torch.Tensor:
    """Generate audio from optional video and text conditions.

    Each missing condition (`None`) is replaced by the corresponding "empty"
    sequence provided by the networks. The conditions are encoded, a latent is
    sampled from noise via `fm.to_data`, then decoded and vocoded through
    `feature_utils`.

    Args:
        clip_video: frames passed to `feature_utils.encode_video_with_clip`, or None.
        sync_video: frames passed to `net_video_enc.encode_video_with_sync`, or
            None. Ignored when `image_input` is True.
        text: one prompt per batch item, or None.
        negative_text: optional negative prompts; must have the batch length.
        feature_utils: text/video encoders, decoder and vocoder; also supplies
            the device and dtype used for the noise and the clip frames.
        net_video_enc: text-conditioned sync video encoder.
        net_generator: generator network used by the ODE wrapper.
        fm: flow-matching sampler turning noise into data.
        rng: generator for the initial noise.
        cfg_strength: guidance strength forwarded to `net_generator.ode_wrapper`.
        clip_batch_size_multiplier: the CLIP encoder is called with
            `batch_size = bs * clip_batch_size_multiplier`.
        sync_batch_size_multiplier: accepted for interface compatibility; not
            used inside this function.
        image_input: when True the clip features are expanded to
            `net_generator.clip_seq_len` and the sync branch is skipped.

    Returns:
        The output of `feature_utils.vocode`.

    Raises:
        ValueError: `text`, `negative_text`, `clip_video` and `sync_video` are
            all None, so the batch size cannot be determined.
    """
    device = feature_utils.device
    dtype = feature_utils.dtype

    # The batch size used to be `len(text)` unconditionally, which raised a
    # TypeError for text=None and made the empty-text branch unreachable.
    if text is not None:
        bs = len(text)
    elif negative_text is not None:
        bs = len(negative_text)
    elif clip_video is not None:
        # assumes the first dimension of the video tensor is the batch - TODO confirm
        bs = clip_video.shape[0]
    elif sync_video is not None:
        # assumes the first dimension of the video tensor is the batch - TODO confirm
        bs = sync_video.shape[0]
    else:
        raise ValueError('Cannot infer the batch size: text, negative_text, '
                         'clip_video and sync_video are all None')

    if text is not None:
        text_features_clip = feature_utils.encode_text_clip(text)
        text_features_flant5, text_mask_flant5 = feature_utils.encode_text_t5(text)
    else:
        text_features_clip = net_generator.get_empty_string_sequence(bs)
        text_features_flant5 = net_video_enc.get_empty_string_sequence(bs)
        # NOTE(review): the mask is created with the same shape as the
        # features; confirm that this is what the video encoder expects.
        text_mask_flant5 = torch.zeros_like(text_features_flant5)
        text_mask_flant5[:, 0] = 1

    if negative_text is not None:
        assert len(negative_text) == bs
        negative_text_features_clip = feature_utils.encode_text_clip(negative_text)
        # NOTE(review): the T5 features of the negative prompt are computed but
        # not consumed below; kept to leave the call sequence unchanged.
        negative_text_features_flant5, negative_text_mask_flant5 = feature_utils.encode_text_t5(negative_text)
    else:
        negative_text_features_clip = None
        negative_text_features_flant5, negative_text_mask_flant5 = None, None

    if clip_video is not None:
        clip_video = clip_video.to(device, dtype, non_blocking=True)
        clip_features = feature_utils.encode_video_with_clip(clip_video,
                                                             batch_size=bs *
                                                             clip_batch_size_multiplier)
        if image_input:
            clip_features = clip_features.expand(-1, net_generator.clip_seq_len, -1)
    else:
        clip_features = net_generator.get_empty_clip_sequence(bs)

    if sync_video is not None and not image_input:
        text_features_flant5, text_mask_flant5 = net_video_enc.prepend_sup_text_tokens(text_features_flant5, text_mask_flant5)
        sync_video = sync_video.to(net_video_enc.device, net_video_enc.dtype, non_blocking=True)
        sync_features = net_video_enc.encode_video_with_sync(
            sync_video, text_f=text_features_flant5, text_mask=text_mask_flant5
        )
    else:
        sync_features = net_generator.get_empty_sync_sequence(bs)

    x0 = torch.randn(bs,
                     net_generator.latent_seq_len,
                     net_generator.latent_dim,
                     device=device,
                     dtype=dtype,
                     generator=rng)
    preprocessed_conditions = net_generator.preprocess_conditions(clip_features, sync_features, text_features_clip)
    empty_conditions = net_generator.get_empty_conditions(
        bs, negative_text_features=negative_text_features_clip
    )

    def cfg_ode_wrapper(t, x):
        # Closes over both condition sets and the guidance strength.
        return net_generator.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
                                         cfg_strength)

    x1 = fm.to_data(cfg_ode_wrapper, x0)
    x1 = net_generator.unnormalize(x1)
    spec = feature_utils.decode(x1)
    audio = feature_utils.vocode(spec)
    return audio
|
||||
|
||||
|
||||
# Format string handed to ColoredFormatter in setup_eval_logging(): level name
# padded to 8 characters, then the message, both wrapped in color codes.
LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
|
||||
|
||||
|
||||
def setup_eval_logging(log_level: int = logging.INFO):
    """Attach a colored stream handler to the root logger.

    The root logger and the new handler are both set to `log_level`. Every
    call adds one more handler.
    """
    logging.root.setLevel(log_level)

    console = logging.StreamHandler()
    console.setLevel(log_level)
    console.setFormatter(ColoredFormatter(LOGFORMAT))

    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.addHandler(console)
|
||||
|
||||
|
||||
# Target side length (pixels) for resized "clip" frames, and the frame rate
# requested from read_frames() for that stream.
_CLIP_SIZE = 384
_CLIP_FPS = 8.0

# Same two settings for the "sync" stream.
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
    """Read a video and prepare its clip and sync frame stacks.

    Frames are requested from `read_frames` at `_CLIP_FPS` and `_SYNC_FPS` for
    the first `duration_sec` seconds. Both stacks are resized (bicubic) and
    scaled to float32; the sync stack is additionally normalized with mean and
    std 0.5. If either stack covers less than `duration_sec`, a warning is
    logged and the duration is shortened; finally both stacks are cut to the
    resulting duration.

    Args:
        video_path: video file to read.
        duration_sec: requested length in seconds.
        load_all_frames: forwarded as `need_all_frames`; when False the
            returned `all_frames` is None.

    Returns:
        A VideoInfo with the (possibly shortened) duration, the original fps
        and both frame stacks.
    """
    bicubic = v2.InterpolationMode.BICUBIC

    clip_transform = v2.Compose([
        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=bicubic),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    sync_transform = v2.Compose([
        v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=bicubic),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    output_frames, all_frames, orig_fps = read_frames(
        video_path,
        list_of_fps=[_CLIP_FPS, _SYNC_FPS],
        start_sec=0,
        end_sec=duration_sec,
        need_all_frames=load_all_frames,
    )

    # The permute moves the last axis to position 1 before the transforms run.
    raw_clip, raw_sync = output_frames
    clip_frames = clip_transform(torch.from_numpy(raw_clip).permute(0, 3, 1, 2))
    sync_frames = sync_transform(torch.from_numpy(raw_sync).permute(0, 3, 1, 2))

    # Seconds actually covered by each stack at its own frame rate.
    clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
    sync_length_sec = sync_frames.shape[0] / _SYNC_FPS

    if clip_length_sec < duration_sec:
        log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {clip_length_sec:.2f} sec')
        duration_sec = clip_length_sec

    if sync_length_sec < duration_sec:
        log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {sync_length_sec:.2f} sec')
        duration_sec = sync_length_sec

    num_clip_frames = int(_CLIP_FPS * duration_sec)
    num_sync_frames = int(_SYNC_FPS * duration_sec)

    return VideoInfo(
        duration_sec=duration_sec,
        fps=orig_fps,
        clip_frames=clip_frames[:num_clip_frames],
        sync_frames=sync_frames[:num_sync_frames],
        all_frames=all_frames if load_all_frames else None,
    )
|
||||
|
||||
|
||||
def load_image(image_path: Path) -> VideoInfo:
    """Load one image and prepare single-frame clip and sync stacks.

    The image is converted to RGB, turned into a one-frame stack and pushed
    through the clip transform (resize to `_CLIP_SIZE` square) and the sync
    transform (resize, center crop to `_SYNC_SIZE`, normalize with mean and
    std 0.5).

    Args:
        image_path: image file to read.

    Returns:
        An ImageInfo holding both stacks and the RGB frame as a numpy array.
        NOTE(review): the return annotation says VideoInfo; it is left
        untouched to keep the signature stable - confirm which is intended.
    """
    clip_transform = v2.Compose([
        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])

    sync_transform = v2.Compose([
        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
        v2.CenterCrop(_SYNC_SIZE),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Force three channels: grayscale/palette images decode to 2-D arrays and
    # RGBA images to four channels, which breaks the permute below or the
    # 3-channel Normalize. The context manager closes the file handle.
    with Image.open(image_path) as img:
        frame = np.array(img.convert('RGB'))

    clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
    sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)

    clip_frames = clip_transform(clip_chunk)
    sync_frames = sync_transform(sync_chunk)

    video_info = ImageInfo(
        clip_frames=clip_frames,
        sync_frames=sync_frames,
        original_frame=frame,
    )
    return video_info
|
||||
|
||||
|
||||
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
    """Thin wrapper: forwards all four arguments unchanged to `reencode_with_audio`."""
    reencode_with_audio(video_info, output_path, audio, sampling_rate)
|
||||
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Integrate numerical values for some iterations
|
||||
Typically used for loss computation / logging to tensorboard
|
||||
Call finalize and create a new Integrator when you want to display/log
|
||||
"""
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
from selva_core.utils.logger import TensorboardLogger
|
||||
from selva_core.utils.tensor_utils import distribute_into_histogram
|
||||
|
||||
|
||||
class Integrator:
    """Accumulates scalars and binned tensors and reports their averages.

    Values are summed per key together with a count; `finalize` divides, in
    distributed mode reduces the result onto rank 0, and hands the outcome to
    the logger.
    """

    def __init__(self, logger: 'TensorboardLogger', distributed: bool = True):
        """
        Args:
            logger: receives `log_metrics` / `log_histogram` calls from `finalize`.
            distributed: whether `finalize` reduces values across processes.
        """
        self.values = {}
        self.counts = {}
        self.hooks = []  # List is used here to maintain insertion order

        # for binned tensors
        self.binned_tensors = {}
        self.binned_tensor_indices = {}

        self.logger = logger

        self.distributed = distributed
        dist = torch.distributed
        if distributed or (dist.is_available() and dist.is_initialized()):
            self.local_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            # No process group exists: querying it would raise, which made a
            # non-distributed Integrator impossible to construct.
            self.local_rank = 0
            self.world_size = 1

    def add_scalar(self, key: str, x: Union[torch.Tensor, int, float]):
        """Add `x` to the running sum of `key` and count it once."""
        self.add_scalar_with_count(key, x, 1)

    def add_scalar_with_count(self, key: str, x: Union[torch.Tensor, int, float], count: int):
        """Add `x` to the running sum of `key` and raise its count by `count`.

        Tensors are detached; integer and boolean tensors are converted to float.
        """
        if isinstance(x, torch.Tensor):
            x = x.detach()
            if x.dtype in [torch.long, torch.int, torch.bool]:
                x = x.float()

        if key not in self.values:
            self.counts[key] = count
            # Keep a private copy: detach() shares storage with the caller's
            # tensor, so the in-place `+=` below would otherwise overwrite it.
            self.values[key] = x.clone() if isinstance(x, torch.Tensor) else x
        else:
            self.counts[key] += count
            self.values[key] += x

    def add_dict(self, tensor_dict: dict[str, torch.Tensor]):
        """Call `add_scalar` for every item of `tensor_dict`."""
        for k, v in tensor_dict.items():
            self.add_scalar(k, v)

    def add_dict_with_count(self, tensor_dict: dict[str, torch.Tensor], count: int):
        """Call `add_scalar_with_count` for every item, all with the same `count`."""
        for k, v in tensor_dict.items():
            self.add_scalar_with_count(k, v, count)

    def add_binned_tensor(self, key: str, x: torch.Tensor, indices: torch.Tensor):
        """Remember flattened `x` and `indices` for the histogram built in `finalize`."""
        if key not in self.binned_tensors:
            self.binned_tensors[key] = [x.detach().flatten()]
            self.binned_tensor_indices[key] = [indices.detach().flatten()]
        else:
            self.binned_tensors[key].append(x.detach().flatten())
            self.binned_tensor_indices[key].append(indices.detach().flatten())

    def add_hook(self, hook: Callable[[dict], tuple[str, torch.Tensor]]):
        """
        Adds a custom hook, i.e. compute new metrics using values in the dict
        The hook takes the dict as argument, and returns a (k, v) tuple
        e.g. for computing IoU
        """
        self.hooks.append(hook)

    def reset_except_hooks(self):
        """Drop all accumulated values; only the registered hooks survive."""
        self.values = {}
        self.counts = {}
        # The binned tensors are accumulated state as well; leaving them in
        # place made the lists grow with every step for the object's lifetime.
        self.binned_tensors = {}
        self.binned_tensor_indices = {}

    # Average and output the metrics
    def finalize(self, prefix: str, it: int, ignore_timer: bool = False) -> None:
        """Average everything accumulated so far and send it to the logger.

        Args:
            prefix: tag prefix passed to the logger.
            it: iteration number passed to the logger.
            ignore_timer: forwarded to `logger.log_metrics`.
        """
        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_scalar(k, v)

        # for the metrics
        outputs = {}
        for k, v in self.values.items():
            avg = v / self.counts[k]
            if self.distributed:
                # reduce() works in place on a CUDA tensor
                if isinstance(avg, torch.Tensor):
                    avg = avg.cuda()
                else:
                    avg = torch.tensor(avg).cuda()
                torch.distributed.reduce(avg, dst=0)

                # Only the destination rank keeps the reduced value.
                if self.local_rank == 0:
                    avg = (avg / self.world_size).cpu().item()
                    outputs[k] = avg
            else:
                # Simple does it
                outputs[k] = avg

        if (not self.distributed) or (self.local_rank == 0):
            self.logger.log_metrics(prefix, outputs, it, ignore_timer=ignore_timer)

        # for the binned tensors
        for k, v in self.binned_tensors.items():
            x = torch.cat(v, dim=0)
            indices = torch.cat(self.binned_tensor_indices[k], dim=0)
            hist, count = distribute_into_histogram(x, indices)

            if self.distributed:
                torch.distributed.reduce(hist, dst=0)
                torch.distributed.reduce(count, dst=0)
                if self.local_rank == 0:
                    hist = hist / count
            else:
                hist = hist / count

            if (not self.distributed) or (self.local_rank == 0):
                self.logger.log_histogram(f'{prefix}/{k}', hist, it)
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Dumps things to tensorboard and console
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image
|
||||
from pytz import timezone
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from selva_core.utils.email_utils import EmailSender
|
||||
from selva_core.utils.time_estimator import PartialTimeEstimator, TimeEstimator
|
||||
from selva_core.utils.timezone import my_timezone
|
||||
|
||||
|
||||
def tensor_to_numpy(image: torch.Tensor):
    """Convert `image` to a numpy array, multiply by 255 and cast to uint8."""
    scaled = image.numpy() * 255
    return scaled.astype('uint8')
|
||||
|
||||
|
||||
def detach_to_cpu(x: torch.Tensor):
    """Return `x` detached from the autograd graph and placed on the CPU."""
    detached = x.detach()
    return detached.cpu()
|
||||
|
||||
|
||||
def fix_width_trunc(x: float):
    """Format `x` with nine decimals and keep only the first nine characters."""
    nine_decimals = '{:0.9f}'.format(x)
    return '{:.9s}'.format(nine_decimals)
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram: np.ndarray, title=None, ylabel="freq_bin", ax=None):
    """Draw `spectrogram` as an image on `ax`.

    Args:
        spectrogram: data handed to `imshow`.
        title: optional axes title; skipped when None.
        ylabel: label of the y axis.
        ax: axes to draw on; a new single-axes figure is created when None.
    """
    axes = ax if ax is not None else plt.subplots(1, 1)[1]
    if title is not None:
        axes.set_title(title)
    axes.set_ylabel(ylabel)
    axes.imshow(spectrogram, origin="lower", aspect="auto", interpolation="nearest")
|
||||
|
||||
|
||||
class TensorboardLogger:
    """Writes metrics to tensorboard, the console logger and files under `run_dir`.

    A SummaryWriter is created only when `is_rank0` is True; otherwise every
    tensorboard write is skipped. Images, audio and spectrograms are written
    to disk regardless of the rank.
    """

    def __init__(self,
                 exp_id: str,
                 run_dir: Union[Path, str],
                 py_logger: logging.Logger,
                 *,
                 is_rank0: bool = False,
                 enable_email: bool = False):
        """
        Args:
            exp_id: experiment identifier used in log lines and email subjects.
            run_dir: directory for tensorboard events and saved media.
            py_logger: logger that receives the textual messages.
            is_rank0: whether this process owns the tensorboard writer.
            enable_email: emails are sent only if this and `is_rank0` are True.
        """
        self.exp_id = exp_id
        self.run_dir = Path(run_dir)
        self.py_log = py_logger
        self.email_sender = EmailSender(exp_id, enable=(is_rank0 and enable_email))
        if is_rank0:
            self.tb_log = SummaryWriter(run_dir)
        else:
            self.tb_log = None

        # Get current git info for logging
        try:
            import git
            repo = git.Repo(".")
            git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
        except Exception:
            # Purely informational, so it must never abort construction. The
            # former (ImportError, RuntimeError, TypeError) tuple missed
            # GitPython's own exceptions, e.g. when "." is not a repository.
            print('Failed to fetch git info. Defaulting to None')
            git_info = 'None'

        self.log_string('git', git_info)

        # log the SLURM job id if available
        job_id = os.environ.get('SLURM_JOB_ID', None)
        if job_id is not None:
            self.log_string('slurm_job_id', job_id)
            self.email_sender.send(f'Job {job_id} started', f'Job started {run_dir}')

        # used when logging metrics
        self.batch_timer: TimeEstimator = None
        self.data_timer: PartialTimeEstimator = None

        self.nan_count = defaultdict(int)

    def log_scalar(self, tag: str, x: float, it: int):
        """Write one scalar to tensorboard and track consecutive NaNs per tag.

        The tenth consecutive NaN of a tag (tags containing 'grad_norm' are
        exempt) triggers an email. Does nothing without a tensorboard writer.
        """
        if self.tb_log is None:
            return
        if math.isnan(x) and 'grad_norm' not in tag:
            self.nan_count[tag] += 1
            if self.nan_count[tag] == 10:
                self.email_sender.send(
                    f'Nan detected in {tag} @ {self.run_dir}',
                    f'Nan detected in {tag} at iteration {it}; run_dir: {self.run_dir}')
        else:
            self.nan_count[tag] = 0
        self.tb_log.add_scalar(tag, x, it)

    def log_metrics(self,
                    prefix: str,
                    metrics: dict[str, float],
                    it: int,
                    ignore_timer: bool = False):
        """Log all `metrics` as scalars and print one summary line.

        When a batch timer is set and `ignore_timer` is False, the timer is
        updated and average time, data time, remaining time and ETA are
        logged too.
        """
        msg = f'{self.exp_id}-{prefix} - it {it:6d}: '
        metrics_msg = ''
        for k, v in sorted(metrics.items()):
            self.log_scalar(f'{prefix}/{k}', v, it)
            metrics_msg += f'{k: >10}:{v:.7f},\t'

        if self.batch_timer is not None and not ignore_timer:
            self.batch_timer.update()
            avg_time = self.batch_timer.get_and_reset_avg_time()
            data_time = self.data_timer.get_and_reset_avg_time()

            # add time to tensorboard
            self.log_scalar(f'{prefix}/avg_time', avg_time, it)
            self.log_scalar(f'{prefix}/data_time', data_time, it)

            est = self.batch_timer.get_est_remaining(it)
            est = datetime.timedelta(seconds=est)
            if est.days > 0:
                remaining_str = f'{est.days}d {est.seconds // 3600}h'
            else:
                remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m'
            eta = datetime.datetime.now(timezone(my_timezone)) + est
            eta_str = eta.strftime('%Y-%m-%d %H:%M:%S %Z%z')
            time_msg = f'avg_time:{avg_time:.3f},data:{data_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
            msg = f'{msg} {time_msg}'

        msg = f'{msg} {metrics_msg}'
        self.py_log.info(msg)

    def log_histogram(self, tag: str, hist: torch.Tensor, it: int):
        """Render `hist` as a bar chart over [0, 1] and add it to tensorboard."""
        if self.tb_log is None:
            return
        # hist should be a 1D tensor
        hist = hist.cpu().numpy()
        fig, ax = plt.subplots()
        x_range = np.linspace(0, 1, len(hist))
        # max(..., 1) keeps a one-bin histogram from dividing by zero.
        ax.bar(x_range, hist, width=1 / max(len(hist) - 1, 1))
        ax.set_xticks(x_range)
        ax.set_xticklabels(x_range)
        plt.tight_layout()
        self.tb_log.add_figure(tag, fig, it)
        plt.close(fig)

    def log_image(self, prefix: str, tag: str, image: np.ndarray, it: int):
        """Save `image` as '<run_dir>/<prefix>_images/<it>_<tag>.png'."""
        image_dir = self.run_dir / f'{prefix}_images'
        image_dir.mkdir(exist_ok=True, parents=True)

        image = Image.fromarray(image)
        image.save(image_dir / f'{it:09d}_{tag}.png')

    def log_audio(self,
                  prefix: str,
                  tag: str,
                  waveform: torch.Tensor,
                  it: Optional[int] = None,
                  *,
                  subdir: Optional[Path] = None,
                  sample_rate: int = 16000) -> Path:
        """Save `waveform` as a FLAC file and return the directory it was written to.

        The file lands in '<run_dir>/[<subdir>/]<prefix>/'; its name is
        '<tag>.flac', prefixed with the zero-padded iteration when `it` is given.
        """
        if subdir is None:
            audio_dir = self.run_dir / prefix
        else:
            audio_dir = self.run_dir / subdir / prefix
        audio_dir.mkdir(exist_ok=True, parents=True)

        if it is None:
            name = f'{tag}.flac'
        else:
            name = f'{it:09d}_{tag}.flac'

        torchaudio.save(audio_dir / name,
                        waveform.cpu().float(),
                        sample_rate=sample_rate,
                        channels_first=True)
        return Path(audio_dir)

    def log_spectrogram(
        self,
        prefix: str,
        tag: str,
        spec: torch.Tensor,
        it: Optional[int],
        *,
        subdir: Optional[Path] = None,
    ):
        """Plot `spec` and save it as a PNG next to the audio files.

        Directory and file naming follow `log_audio`, with a '.png' suffix.
        """
        if subdir is None:
            spec_dir = self.run_dir / prefix
        else:
            spec_dir = self.run_dir / subdir / prefix
        spec_dir.mkdir(exist_ok=True, parents=True)

        if it is None:
            name = f'{tag}.png'
        else:
            name = f'{it:09d}_{tag}.png'

        plot_spectrogram(spec.cpu().float())
        plt.tight_layout()
        plt.savefig(spec_dir / name)
        plt.close()

    def log_string(self, tag: str, x: str):
        """Print '<tag> - <x>' and, if a writer exists, add it as tensorboard text."""
        self.py_log.info(f'{tag} - {x}')
        if self.tb_log is None:
            return
        self.tb_log.add_text(tag, x)

    def debug(self, x):
        """Forward `x` to the python logger at DEBUG level."""
        self.py_log.debug(x)

    def info(self, x):
        """Forward `x` to the python logger at INFO level."""
        self.py_log.info(x)

    def warning(self, x):
        """Forward `x` to the python logger at WARNING level."""
        self.py_log.warning(x)

    def error(self, x):
        """Forward `x` to the python logger at ERROR level."""
        self.py_log.error(x)

    def critical(self, x):
        """Log `x` at CRITICAL level and send it by email as well."""
        self.py_log.critical(x)

        self.email_sender.send(f'Error occurred in {self.run_dir}', x)

    def complete(self):
        """Send the job-completed email."""
        self.email_sender.send(f'Job completed in {self.run_dir}', 'Job completed')
|
||||
@@ -0,0 +1,41 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
||||
|
||||
def instantiate_from_config(target: str, params: Optional[dict] = None):
    """Import the object named by `target` and call it with `params` as keywords.

    Args:
        target: dotted path such as 'pkg.mod.Class'; everything before the
            last dot is imported as a module, the rest is looked up on it.
        params: keyword arguments for the call; None means no arguments. An
            OmegaConf DictConfig is converted to a plain container first.

    Returns:
        Whatever the resolved callable returns.

    Raises:
        ValueError: `target` is empty, not a string, or contains no dot.
        ModuleNotFoundError: the module part cannot be imported.
        AttributeError: the module has no such attribute.
        TypeError: the resolved attribute is not callable.
    """
    if not (target and isinstance(target, str)):
        raise ValueError(f"Invalid target: {target!r}")
    if params is None:
        params = {}

    # Convert OmegaConf DictConfig to plain dict if needed
    try:
        if isinstance(params, DictConfig):
            params = OmegaConf.to_container(params, resolve=True)
    except Exception:
        # Best effort: on any failure `params` is used as it is.
        pass

    try:
        module_path, attr_name = target.rsplit('.', 1)
    except ValueError as e:
        raise ValueError(f"Target must be like 'pkg.mod.Class', got {target!r}") from e

    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(f"Could not import module '{module_path}' for target '{target}'.") from e

    try:
        factory = getattr(module, attr_name)
    except AttributeError as e:
        raise AttributeError(f"Module '{module_path}' has no attribute '{attr_name}' (from '{target}').") from e

    if not callable(factory):
        raise TypeError(f"Resolved target '{target}' is not callable.")

    kwargs = dict(params)
    return factory(**kwargs)
|
||||
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from nitrous_ema import PostHocEMA
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from selva_core.model.utils.factory import create_model_from_factory
|
||||
|
||||
|
||||
def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]):
    """Build the model described by `cfg` and synthesize an EMA state dict for it.

    Args:
        cfg: configuration with a `model` section (`factory_path`, `name`,
            optional `params`) and an `ema` section (`sigma_rels`,
            `update_every`, `checkpoint_every`, `checkpoint_folder`).
        sigma: forwarded as `sigma_rel` to `synthesize_ema_model`.
        step: forwarded as `step` to `synthesize_ema_model`.

    Returns:
        The `state_dict()` of the synthesized EMA model (synthesized on CPU).
    """
    factory_path = cfg.model.factory_path
    model_name = cfg.model.name
    model_kwargs = cfg.model.get('params', {})
    vae = create_model_from_factory(factory_path, model_name, **model_kwargs)

    ema_cfg = cfg.ema
    emas = PostHocEMA(
        vae,
        sigma_rels=ema_cfg.sigma_rels,
        update_every=ema_cfg.update_every,
        checkpoint_every_num_steps=ema_cfg.checkpoint_every,
        checkpoint_folder=ema_cfg.checkpoint_folder,
    )

    synthesized = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device='cpu')
    return synthesized.ema_model.state_dict()
|
||||
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
|
||||
def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Sum `loss` values into bins selected by `t`.

    Each element goes into bin `int(t * num_bins)`, clamped to the valid
    range, so `t == 1.0` falls into the last bin instead of indexing one past
    the end.

    Args:
        loss: values to accumulate; flattened and converted to float32.
        t: one bin coordinate per loss element; flattened.
        num_bins: number of bins.

    Returns:
        `(hist, count)`, two float32 tensors of length `num_bins` on the
        device of `loss`: the per-bin sum and the per-bin number of elements.
    """
    # The accumulators below are float32 and scatter_add_ needs matching dtypes.
    loss = loss.detach().flatten().float()
    t = t.detach().flatten()
    bin_index = (t * num_bins).long().clamp_(0, num_bins - 1)

    hist = torch.zeros(num_bins, device=loss.device)
    count = torch.zeros(num_bins, device=loss.device)
    hist.scatter_add_(0, bin_index, loss)
    count.scatter_add_(0, bin_index, torch.ones_like(loss))
    return hist, count
|
||||
@@ -0,0 +1,72 @@
|
||||
import time
|
||||
|
||||
|
||||
class TimeEstimator:
    """Track wall-clock time between calls and estimate remaining run time.

    Two statistics are kept: a plain list of recent durations (reset by
    ``get_and_reset_avg_time``) and an exponential moving average used by
    ``get_est_remaining``.

    Args:
        total_iter: Total number of iterations of the run.
        step_size: Number of iterations covered by one recorded duration;
            durations are divided by this to get a per-iteration time.
        ema_alpha: Weight given to the previous value of the moving average.
    """

    def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
        self.avg_time_window = []  # window-based average
        self.exp_avg_time = None  # exponential moving average
        self.alpha = ema_alpha  # for exponential moving average

        self.last_time = time.time()  # would not be accurate for the first iteration but well
        self.total_iter = total_iter
        self.step_size = step_size

        self._buffering_exp = True

    def _record(self, elapsed: float) -> None:
        """Add one measured duration to the window and the moving average."""
        self.avg_time_window.append(elapsed)

        if self._buffering_exp:
            if self.exp_avg_time is not None:
                # discard the first iteration call to not pollute the ema:
                # the second measurement overwrites the first one below
                self._buffering_exp = False
            self.exp_avg_time = elapsed
        else:
            self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * elapsed

    # call this at a fixed interval
    # does not have to be every step
    def update(self):
        """Record the time elapsed since the previous call (or construction)."""
        curr_time = time.time()
        elapsed = curr_time - self.last_time
        self.last_time = curr_time
        self._record(elapsed)

    def get_est_remaining(self, it: int):
        """Return the estimated remaining time at iteration ``it``.

        Returns 0 while no duration has been recorded yet.
        """
        if self.exp_avg_time is None:
            return 0

        remaining_iter = self.total_iter - it
        return remaining_iter * self.exp_avg_time / self.step_size

    def get_and_reset_avg_time(self):
        """Return the mean per-iteration time of the window and clear it.

        Returns 0.0 when nothing has been recorded since the last reset
        instead of dividing by zero.
        """
        if not self.avg_time_window:
            return 0.0

        avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size
        self.avg_time_window = []
        return avg


class PartialTimeEstimator(TimeEstimator):
    """
    Used where the start_time and the end_time do not align
    """

    def update(self):
        """Not supported; measurements are delimited by start() and end()."""
        raise RuntimeError('Please use start() and end() for PartialTimeEstimator')

    def start(self):
        """Mark the beginning of a measured interval."""
        self.last_time = time.time()

    def end(self):
        """Close the interval opened by start() and record its duration."""
        assert self.last_time is not None, 'Please call start() before calling end()'
        elapsed = time.time() - self.last_time
        # Cleared so a second end() without a new start() trips the assertion.
        self.last_time = None
        self._record(elapsed)
|
||||
@@ -0,0 +1 @@
|
||||
# Time zone name configured for the project; the trailing comment records an
# alternative value. NOTE(review): where this is consumed is not visible here.
my_timezone = 'Asia/Seoul' # 'US/Central'
|
||||
@@ -0,0 +1,140 @@
|
||||
import gc
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from selva_core.model.utils.features_utils import FeatureUtils
|
||||
from selva_core.model.networks_video_enc import TextSynch
|
||||
from selva_core.data.mixup import FeatureMixup
|
||||
|
||||
|
||||
@torch.no_grad()
def preprocess_batch_with_tsynch(
    batch: dict,
    mixup_config: DictConfig,
    feature_extractor: FeatureUtils,
    net_video_enc: TextSynch,
    feature_mixup: Optional[FeatureMixup] = None,
    training: bool = False,
) -> None:
    """Prepare ``batch`` in place, computing ``sync_features`` with the video encoder.

    Existence flags are added, the audio statistics and CLIP text features are
    moved to the feature extractor's device/dtype, ``clip_features`` is set to
    ``None``, and, when ``batch['sync_video']`` is present, ``sync_features``
    is produced according to the mixup configuration:

    * mixup enabled, domain ``'data'``: encode ``batch['sync_video_mixed']``;
    * mixup enabled, domain ``'embedding'``: delegate to ``feature_mixup``;
    * otherwise: encode ``batch['sync_video']``.

    Args:
        batch: Batch dictionary; modified in place.
        mixup_config: Config read for ``domain``, ``enabled`` and
            ``params.modality``.
        feature_extractor: Only its ``device`` and ``dtype`` are used here.
        net_video_enc: Provides ``prepend_silence_text_tokens`` and
            ``encode_video_with_sync``.
        feature_mixup: Required when ``mixup_config.domain == 'embedding'``.
        training: When true, selected tensors in ``batch`` are replaced by clones.

    Raises:
        ValueError: If the domain is ``'embedding'`` and ``feature_mixup`` is ``None``.
    """
    if mixup_config.domain == 'embedding' and feature_mixup is None:
        raise ValueError('Mixup function is required for embedding domain.')

    bs: int = len(batch['id'])
    device, dtype = feature_extractor.device, feature_extractor.dtype

    has_video = batch.get('sync_video', None) is not None
    has_text = batch.get('caption', None) is not None
    for flag_key, flag_value in (('video_exist', has_video), ('text_exist', has_text)):
        batch[flag_key] = torch.tensor(flag_value, device=device, dtype=torch.bool, requires_grad=False)

    batch['a_mean'] = batch.get('a_mean').to(device, dtype)
    batch['a_std'] = batch['a_std'].to(device, dtype)

    batch['clip_features'] = None
    batch['text_clip_features'] = batch['text_clip_features'].to(device, dtype)

    if has_video:
        text_f = batch['text_features'].to(device, dtype)
        text_mask = batch['text_masks'].to(device, dtype)

        # The domain only matters when mixup is switched on.
        active_domain = mixup_config.domain if mixup_config.enabled else None

        if active_domain == 'data':
            text_f, text_mask = net_video_enc.prepend_silence_text_tokens(text_f, text_mask)
            batch['sync_video_mixed'] = batch['sync_video_mixed'].to(device, dtype, non_blocking=True)
            batch['sync_features'] = net_video_enc.encode_video_with_sync(
                batch['sync_video_mixed'], text_f=text_f, text_mask=text_mask
            )
        elif active_domain == 'embedding':
            assert feature_mixup.modality == mixup_config.params.modality, \
                f"Mixup class modality {feature_mixup.modality} should be same as config {mixup_config.params.modality}."
            feature_mixup.target_video_key = 'sync_features'
            feature_mixup(batch)
        else:
            batch['sync_video'] = batch['sync_video'].to(device, dtype)
            text_f, text_mask = net_video_enc.prepend_silence_text_tokens(text_f, text_mask)
            batch['sync_features'] = net_video_enc.encode_video_with_sync(
                batch['sync_video'], text_f=text_f, text_mask=text_mask
            )

    if training:
        # NOTE(review): the reason for cloning is not visible in this file.
        for key in ('video_exist', 'text_exist', 'sync_features'):
            if key in batch:
                batch[key] = batch[key].clone()

    # Run garbage collection and release cached CUDA memory after each batch.
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@torch.no_grad()
def preprocess_batch_with_mixup(
    batch: dict,
    mixup_config: DictConfig,
    feature_extractor: FeatureUtils = None,
    sync_batch_size_multiplier: int = 40,
    feature_mixup: Optional[FeatureMixup] = None,
    training: bool = False,
) -> None:
    """Encode text and sync features for ``batch`` in place, with mixup support.

    Text features are always encoded from ``batch['caption']``. Depending on
    ``mixup_config.params.modality``, sync features are extracted from
    ``batch['sync_video']`` and/or ``batch['audio']``. For the ``'data'``
    domain the ``*_mixed`` inputs are encoded as well; for the ``'embedding'``
    domain ``feature_mixup`` is applied to the batch.

    Args:
        batch: Batch dictionary; modified in place.
        mixup_config: Config read for ``domain`` and ``params.modality``.
        feature_extractor: Provides ``encode_text``, ``encode_video_with_sync``
            and ``encode_audio_with_sync``; used unconditionally despite the
            ``None`` default.
        sync_batch_size_multiplier: The sync encoders receive
            ``len(batch['id']) * sync_batch_size_multiplier`` as ``batch_size``.
        feature_mixup: Required when ``mixup_config.domain == 'embedding'``.
        training: When true, the produced features are replaced by clones.

    Raises:
        ValueError: If the domain is ``'embedding'`` and ``feature_mixup`` is ``None``.
    """
    if mixup_config.domain == 'embedding' and feature_mixup is None:
        raise ValueError('Mixup function is required for embedding domain.')

    bs: int = len(batch['id'])
    batch['text_features'], batch['text_mask'] = feature_extractor.encode_text(batch['caption'])

    modality = mixup_config.params.modality
    use_video = modality in ['video', 'both']
    use_audio = modality in ['audio', 'both']
    sync_batch_size = bs * sync_batch_size_multiplier

    if use_video:
        batch['sync_f_vid_orig'] = feature_extractor.encode_video_with_sync(
            batch['sync_video'], batch_size=sync_batch_size)
    if use_audio:
        batch['sync_f_aud_orig'] = feature_extractor.encode_audio_with_sync(
            batch['audio'], batch_size=sync_batch_size)

    domain = mixup_config.domain
    if domain == 'data':
        # Data-domain mixup: the mixed inputs are encoded like the originals.
        if use_video:
            batch['sync_f_vid_mixed'] = feature_extractor.encode_video_with_sync(
                batch['sync_video_mixed'], batch_size=sync_batch_size)
        if use_audio:
            batch['sync_f_aud_mixed'] = feature_extractor.encode_audio_with_sync(
                batch['audio_mixed'], batch_size=sync_batch_size)
    elif domain == 'embedding':
        assert feature_mixup.modality == mixup_config.params.modality, \
            f"Mixup class modality {feature_mixup.modality} should be same as config {mixup_config.params.modality}."
        feature_mixup(batch)

    if training:
        # NOTE(review): the reason for cloning is not visible in this file.
        cloned_keys = ('text_features', 'text_mask',
                       'sync_f_vid_orig', 'sync_f_aud_orig', 'sync_f_vid_mixed', 'sync_f_aud_mixed')
        for key in cloned_keys:
            if key in batch:
                batch[key] = batch[key].clone()

    # Run garbage collection and release cached CUDA memory after each batch.
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@torch.no_grad()
def preprocess_batch(
    batch: dict,
    mixup_config: DictConfig,
    feature_extractor: FeatureUtils = None,
    sync_batch_size_multiplier: Union[int, float] = 40,
    training: bool = False,
) -> None:
    """Encode text and sync features for ``batch`` in place (no mixup applied).

    Text features are encoded from ``batch['caption']``; sync features are
    extracted from ``batch['sync_video']`` and/or ``batch['audio']`` depending
    on ``mixup_config.params.modality``.

    Args:
        batch: Batch dictionary; modified in place.
        mixup_config: Only ``params.modality`` is read.
        feature_extractor: Provides ``encode_text``, ``encode_video_with_sync``
            and ``encode_audio_with_sync``; used unconditionally despite the
            ``None`` default.
        sync_batch_size_multiplier: When positive, the sync encoders receive
            ``int(len(batch['id']) * multiplier)`` as ``batch_size``; otherwise
            they receive ``len(batch['id'])``.
        training: When true, the produced features are replaced by clones.
    """
    bs: int = len(batch['id'])
    batch['text_features'], batch['text_mask'] = feature_extractor.encode_text(batch['caption'])

    if sync_batch_size_multiplier > 0:
        sync_batch_size = int(bs * sync_batch_size_multiplier)
    else:
        # A non-positive multiplier falls back to the plain batch size.
        sync_batch_size = bs

    modality = mixup_config.params.modality
    if modality in ['video', 'both']:
        batch['sync_f_vid_orig'] = feature_extractor.encode_video_with_sync(
            batch['sync_video'], batch_size=sync_batch_size)
    if modality in ['audio', 'both']:
        batch['sync_f_aud_orig'] = feature_extractor.encode_audio_with_sync(
            batch['audio'], batch_size=sync_batch_size)

    if training:
        # NOTE(review): the reason for cloning is not visible in this file.
        for key in ('text_features', 'text_mask', 'sync_f_vid_orig', 'sync_f_aud_orig'):
            if key in batch:
                batch[key] = batch[key].clone()

    # Run garbage collection and release cached CUDA memory after each batch.
    gc.collect()
    torch.cuda.empty_cache()
|
||||
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
|
||||
|
||||
def generate_multiple_segments(
    x: torch.Tensor,
    segment_size: int,
    step_size: int,
) -> torch.Tensor:
    """Slice ``x`` along its second dimension into overlapping segments.

    Segment ``i`` covers ``x[:, i * step_size : i * step_size + segment_size]``;
    only full-length segments are produced, so trailing elements that do not
    fill a whole segment are dropped.

    Args:
        x: Input of shape ``(B, T, ...)``.
        segment_size: Length of each segment along dimension 1.
        step_size: Offset between the starts of consecutive segments; must be
            positive and smaller than ``segment_size``.

    Returns:
        Tensor of shape ``(B, S, segment_size, ...)`` with
        ``S = (T - segment_size) // step_size + 1``.

    Raises:
        AssertionError: If ``T < segment_size``, ``step_size <= 0`` or
            ``segment_size <= step_size``.
    """
    # x: (B, T, ...)
    t = x.shape[1]
    assert t >= segment_size, f'The length of the input tensor {t} is less than the segment size {segment_size}.'
    # A zero step would divide by zero below and a negative one would yield no
    # segments (torch.stack on an empty list); fail with a clear message instead.
    assert step_size > 0, f'The step size {step_size} should be positive.'
    assert segment_size > step_size, f'The segment size {segment_size} should be greater than the step size {step_size}.'

    # partition the tensor into segments
    num_segments = (t - segment_size) // step_size + 1
    segments = [x[:, start:start + segment_size]
                for start in range(0, num_segments * step_size, step_size)]

    return torch.stack(segments, dim=1)  # (B, S, T, ...)
|
||||
@@ -0,0 +1,66 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
|
||||
|
||||
|
||||
class VideoJoiner:
    """Attach audio to source videos and write the results under an output root.

    The output directory (including parents) is created on construction.

    Args:
        src_root: Directory containing the source ``<video_id>.mp4`` files.
        output_root: Directory that receives the ``<output_name>.mp4`` files.
        sample_rate: Sample rate forwarded to ``merge_audio_into_video``.
        duration_seconds: Duration forwarded to ``merge_audio_into_video``.
    """

    def __init__(self, src_root: Union[str, Path], output_root: Union[str, Path], sample_rate: int,
                 duration_seconds: float):
        self.src_root, self.output_root = Path(src_root), Path(output_root)
        self.sample_rate, self.duration_seconds = sample_rate, duration_seconds

        self.output_root.mkdir(parents=True, exist_ok=True)

    def join(self, video_id: str, output_name: str, audio: torch.Tensor):
        """Merge ``audio`` into ``<src_root>/<video_id>.mp4`` and save it as
        ``<output_root>/<output_name>.mp4``."""
        merge_audio_into_video(
            self.src_root / f'{video_id}.mp4',
            self.output_root / f'{output_name}.mp4',
            audio,
            self.sample_rate,
            self.duration_seconds,
        )
|
||||
|
||||
|
||||
def merge_audio_into_video(video_path: Union[str, Path], output_path: Union[str, Path],
                           audio: torch.Tensor, sample_rate: int, duration_seconds: float):
    """Write a new video file combining frames from ``video_path`` with ``audio``.

    The first ``int(24 * duration_seconds)`` frames are decoded at 24 fps as
    RGB and re-encoded with libx264 (yuv420p) next to an AAC audio stream.

    Args:
        video_path: Source video to read frames from.
        output_path: Destination file.
        audio: Waveform of shape ``(num_samples, num_channels)`` with 1 or 2
            channels; converted to float before writing.
        sample_rate: Sample rate of ``audio``.
        duration_seconds: Length of video to decode, in seconds.
    """
    # audio: (num_samples, num_channels=1/2)

    fps = 24
    pixel_format = "rgb24"

    # read the video
    # NOTE: passing buffer_chunk_size=1 to the decoder does not work here --
    # the extracted audio would be too short.
    decoder = StreamingMediaDecoder(video_path)
    decoder.add_basic_video_stream(
        frames_per_chunk=int(fps * duration_seconds),
        format=pixel_format,
        frame_rate=fps,
    )
    decoder.fill_buffer()
    frames = decoder.pop_chunks()[0]
    _, _, height, width = frames.shape

    encoder = StreamingMediaEncoder(output_path)
    # Streams are indexed in the order they are added: audio is 0, video is 1.
    encoder.add_audio_stream(
        sample_rate=sample_rate,
        num_channels=audio.shape[-1],
        encoder='aac',  # "libmp3lame",
    )
    encoder.add_video_stream(
        frame_rate=fps,
        width=width,
        height=height,
        format=pixel_format,
        encoder="libx264",
        encoder_format="yuv420p",
    )

    with encoder.open():
        encoder.write_video_chunk(1, frames)
        encoder.write_audio_chunk(0, audio.float())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Usage example: <script> <input video path> <output video path>
    # Merges four seconds of random mono noise at 16 kHz into the input video.
    import sys

    audio = torch.randn(16000 * 4, 1)
    src_video, dst_video = sys.argv[1], sys.argv[2]
    merge_audio_into_video(src_video, dst_video, audio, 16000, 4)
|
||||
Reference in New Issue
Block a user