chore: vendor selva_core from jnwnlee/selva@d7d40a9

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced
librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy
version incompatibility in some ComfyUI environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 15:18:09 +02:00
parent 762b19fd3a
commit 6bc3fd6443
106 changed files with 11323 additions and 0 deletions
View File
+17
View File
@@ -0,0 +1,17 @@
import os
from logging import Logger
from selva_core.utils.logger import TensorboardLogger
local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
def info_if_rank_zero(logger: Logger, msg: str):
if local_rank == 0:
logger.info(msg)
def string_if_rank_zero(logger: TensorboardLogger, tag: str, msg: str):
if local_rank == 0:
logger.log_string(tag, msg)
+109
View File
@@ -0,0 +1,109 @@
import hashlib
import logging
from pathlib import Path
import requests
from tqdm import tqdm
log = logging.getLogger()
links = [
{
'name': 'video_enc_sup_5.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/video_enc_sup_5.pth',
'md5': 'ff09a6dc36148536ee4db97eba081d05'
},
{
'name': 'generator_small_16k_sup_5.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_small_16k_sup_5.pth',
'md5': '1cb0f0deec52de37f67b1fd9965337d0'
},
{
'name': 'generator_small_44k_sup_5.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_small_44k_sup_5.pth',
'md5': 'd4df8569624093ac80af99b8b7434525'
},
{
'name': 'generator_medium_44k_sup_5.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_medium_44k_sup_5.pth',
'md5': 'e9157e62b4863ad306e89e8f3a587748'
},
{
'name': 'generator_large_44k_sup_5.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/weights/generator_large_44k_sup_5.pth',
'md5': 'ab3db08b124d3aaa53eb7a1f52f1fb3f'
},
{
'name': 'v1-16.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/v1-16.pth',
'md5': '69f56803f59a549a1a507c93859fd4d7'
},
{
'name': 'best_netG.pt',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/best_netG.pt',
'md5': 'eeaf372a38a9c31c362120aba2dde292'
},
{
'name': 'v1-44.pth',
'url': 'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/v1-44.pth',
'md5': 'fab020275fa44c6589820ce025191600'
},
{
'name': 'synchformer_state_dict.pth',
'url':
'https://huggingface.co/jnwnlee/SelVA/resolve/main/ext_weights/synchformer_state_dict.pth',
'md5': '5b2f5594b0730f70e41e549b7c94390c'
},
{
'name': 'mmaudio_small_16k.pth',
'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth',
'md5': 'af93cde404179f58e3919ac085b8033b',
},
{
'name': 'mmaudio_small_44k.pth',
'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth',
'md5': 'babd74c884783d13701ea2820a5f5b6d',
},
{
'name': 'mmaudio_medium_44k.pth',
'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth',
'md5': '5a56b6665e45a1e65ada534defa903d0',
},
{
'name': 'mmaudio_large_44k.pth',
'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth',
'md5': 'fed96c325a6785b85ce75ae1aafd2673'
},
{
'name': 'mmaudio_large_44k_v2.pth',
'url': 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth',
'md5': '01ad4464f049b2d7efdaa4c1a59b8dfe'
},
]
def download_model_if_needed(model_path: Path):
base_name = model_path.name
for link in links:
if link['name'] == base_name:
target_link = link
break
else:
raise ValueError(f'No link found for {base_name}')
model_path.parent.mkdir(parents=True, exist_ok=True)
if not model_path.exists() or hashlib.md5(open(model_path,
'rb').read()).hexdigest() != target_link['md5']:
log.info(f'Downloading {base_name} to {model_path}...')
r = requests.get(target_link['url'], stream=True)
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
t = tqdm(total=total_size, unit='iB', unit_scale=True)
with open(model_path, 'wb') as f:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
t.close()
if total_size != 0 and t.n != total_size:
raise RuntimeError('Error while downloading %s' % base_name)
+50
View File
@@ -0,0 +1,50 @@
import logging
import os
from datetime import datetime
import requests
from dotenv import load_dotenv
from pytz import timezone
from selva_core.utils.timezone import my_timezone
_source = 'USE YOURS'
_target = 'USE YOURS'
log = logging.getLogger()
_fmt = "%Y-%m-%d %H:%M:%S %Z%z"
class EmailSender:
def __init__(self, exp_id: str, enable: bool):
self.exp_id = exp_id
self.enable = enable
if enable:
load_dotenv()
self.MAILGUN_API_KEY = os.getenv('MAILGUN_API_KEY')
if self.MAILGUN_API_KEY is None:
log.warning('MAILGUN_API_KEY is not set')
self.enable = False
def send(self, subject, content):
if self.enable:
subject = str(subject)
content = str(content)
try:
return requests.post(f'https://api.mailgun.net/v3/{_source}/messages',
auth=('api', self.MAILGUN_API_KEY),
data={
'from':
f'<agent name>🤖 <mailgun@{_source}>',
'to': [f'{_target}'],
'subject':
f'[{self.exp_id}] {subject}',
'text':
('\n\n' + content + '\n\n<your sign off>\n' +
datetime.now(timezone(my_timezone)).strftime(_fmt)),
},
timeout=20)
except Exception as e:
log.error(f'Failed to send email: {e}')
+277
View File
@@ -0,0 +1,277 @@
import dataclasses
import logging
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from colorlog import ColoredFormatter
from PIL import Image
from torchvision.transforms import v2
from selva_core.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.networks_video_enc import TextSynch
from selva_core.model.networks_generator import MMAudio
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.utils.download_utils import download_model_if_needed
log = logging.getLogger()
@dataclasses.dataclass
class ModelConfig:
model_name: str
model_video_enc_path: Path
model_generator_path: Path
mode: str
vae_path: Path
bigvgan_16k_path: Optional[Path]
synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
@property
def seq_cfg(self) -> SequenceConfig:
if self.mode == '16k':
return CONFIG_16K
elif self.mode == '44k':
return CONFIG_44K
def download_if_needed(self):
download_model_if_needed(self.model_video_enc_path)
download_model_if_needed(self.model_generator_path)
download_model_if_needed(self.vae_path)
if self.bigvgan_16k_path is not None:
download_model_if_needed(self.bigvgan_16k_path)
download_model_if_needed(self.synchformer_ckpt)
def download_video_enc_if_needed(self):
download_model_if_needed(self.model_video_enc_path)
def download_generator_if_needed(self):
download_model_if_needed(self.model_generator_path)
def download_external_modules_if_needed(self):
download_model_if_needed(self.synchformer_ckpt)
download_model_if_needed(self.vae_path)
if self.bigvgan_16k_path is not None:
download_model_if_needed(self.bigvgan_16k_path)
small_16k = ModelConfig(model_name='small_16k',
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
model_generator_path=Path('./weights/generator_small_16k_sup_5.pth'),
vae_path=Path('./ext_weights/v1-16.pth'),
bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
mode='16k')
small_44k = ModelConfig(model_name='small_44k',
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
model_generator_path=Path('./weights/generator_small_44k_sup_5.pth'),
vae_path=Path('./ext_weights/v1-44.pth'),
bigvgan_16k_path=None,
mode='44k')
medium_44k = ModelConfig(model_name='medium_44k',
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
model_generator_path=Path('./weights/generator_medium_44k_sup_5.pth'),
vae_path=Path('./ext_weights/v1-44.pth'),
bigvgan_16k_path=None,
mode='44k')
large_44k = ModelConfig(model_name='large_44k',
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
model_generator_path=Path('./weights/generator_large_44k_sup_5.pth'),
vae_path=Path('./ext_weights/v1-44.pth'),
bigvgan_16k_path=None,
mode='44k')
all_model_cfg: dict[str, ModelConfig] = {
'small_16k': small_16k,
'small_44k': small_44k,
'medium_44k': medium_44k,
'large_44k': large_44k,
}
def generate(
clip_video: Optional[torch.Tensor],
sync_video: Optional[torch.Tensor],
text: Optional[list[str]],
*,
negative_text: Optional[list[str]] = None,
feature_utils: FeaturesUtils,
net_video_enc: TextSynch,
net_generator: MMAudio,
fm: FlowMatching,
rng: torch.Generator,
cfg_strength: float,
clip_batch_size_multiplier: int = 40,
sync_batch_size_multiplier: int = 40,
image_input: bool = False,
) -> torch.Tensor:
device = feature_utils.device
dtype = feature_utils.dtype
bs = len(text)
if text is not None:
text_features_clip = feature_utils.encode_text_clip(text)
text_features_flant5, text_mask_flant5 = feature_utils.encode_text_t5(text)
else:
text_features_clip = net_generator.get_empty_string_sequence(bs)
text_features_flant5 = net_video_enc.get_empty_string_sequence(bs)
text_mask_flant5 = torch.zeros_like(text_features_flant5)
text_mask_flant5[:, 0] = 1
if negative_text is not None:
assert len(negative_text) == bs
negative_text_features_clip = feature_utils.encode_text_clip(negative_text)
negative_text_features_flant5, negative_text_mask_flant5 = feature_utils.encode_text_t5(negative_text)
else:
negative_text_features_clip = None
negative_text_features_flant5, negative_text_mask_flant5 = None, None
if clip_video is not None:
clip_video = clip_video.to(device, dtype, non_blocking=True)
clip_features = feature_utils.encode_video_with_clip(clip_video,
batch_size=bs *
clip_batch_size_multiplier)
if image_input:
clip_features = clip_features.expand(-1, net_generator.clip_seq_len, -1)
else:
clip_features = net_generator.get_empty_clip_sequence(bs)
if sync_video is not None and not image_input:
text_features_flant5, text_mask_flant5 = net_video_enc.prepend_sup_text_tokens(text_features_flant5, text_mask_flant5)
sync_video = sync_video.to(net_video_enc.device, net_video_enc.dtype, non_blocking=True)
sync_features = net_video_enc.encode_video_with_sync(
sync_video, text_f=text_features_flant5, text_mask=text_mask_flant5
)
else:
sync_features = net_generator.get_empty_sync_sequence(bs)
x0 = torch.randn(bs,
net_generator.latent_seq_len,
net_generator.latent_dim,
device=device,
dtype=dtype,
generator=rng)
preprocessed_conditions = net_generator.preprocess_conditions(clip_features, sync_features, text_features_clip)
empty_conditions = net_generator.get_empty_conditions(
bs, negative_text_features=negative_text_features_clip
)
cfg_ode_wrapper = lambda t, x: net_generator.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
cfg_strength)
x1 = fm.to_data(cfg_ode_wrapper, x0)
x1 = net_generator.unnormalize(x1)
spec = feature_utils.decode(x1)
audio = feature_utils.vocode(spec)
return audio
LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
def setup_eval_logging(log_level: int = logging.INFO):
logging.root.setLevel(log_level)
formatter = ColoredFormatter(LOGFORMAT)
stream = logging.StreamHandler()
stream.setLevel(log_level)
stream.setFormatter(formatter)
log = logging.getLogger()
log.setLevel(log_level)
log.addHandler(stream)
_CLIP_SIZE = 384
_CLIP_FPS = 8.0
_SYNC_SIZE = 224
_SYNC_FPS = 25.0
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
])
sync_transform = v2.Compose([
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
# v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
output_frames, all_frames, orig_fps = read_frames(video_path,
list_of_fps=[_CLIP_FPS, _SYNC_FPS],
start_sec=0,
end_sec=duration_sec,
need_all_frames=load_all_frames)
clip_chunk, sync_chunk = output_frames
clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
clip_frames = clip_transform(clip_chunk)
sync_frames = sync_transform(sync_chunk)
clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
if clip_length_sec < duration_sec:
log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
log.warning(f'Truncating to {clip_length_sec:.2f} sec')
duration_sec = clip_length_sec
if sync_length_sec < duration_sec:
log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
log.warning(f'Truncating to {sync_length_sec:.2f} sec')
duration_sec = sync_length_sec
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
video_info = VideoInfo(
duration_sec=duration_sec,
fps=orig_fps,
clip_frames=clip_frames,
sync_frames=sync_frames,
all_frames=all_frames if load_all_frames else None,
)
return video_info
def load_image(image_path: Path) -> VideoInfo:
clip_transform = v2.Compose([
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
])
sync_transform = v2.Compose([
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
v2.CenterCrop(_SYNC_SIZE),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
frame = np.array(Image.open(image_path))
clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
clip_frames = clip_transform(clip_chunk)
sync_frames = sync_transform(sync_chunk)
video_info = ImageInfo(
clip_frames=clip_frames,
sync_frames=sync_frames,
original_frame=frame,
)
return video_info
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
reencode_with_audio(video_info, output_path, audio, sampling_rate)
+129
View File
@@ -0,0 +1,129 @@
"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""
from typing import Callable, Union
import torch
from selva_core.utils.logger import TensorboardLogger
from selva_core.utils.tensor_utils import distribute_into_histogram
class Integrator:
def __init__(self, logger: TensorboardLogger, distributed: bool = True):
self.values = {}
self.counts = {}
self.hooks = [] # List is used here to maintain insertion order
# for binned tensors
self.binned_tensors = {}
self.binned_tensor_indices = {}
self.logger = logger
self.distributed = distributed
self.local_rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
def add_scalar(self, key: str, x: Union[torch.Tensor, int, float]):
if isinstance(x, torch.Tensor):
x = x.detach()
if x.dtype in [torch.long, torch.int, torch.bool]:
x = x.float()
if key not in self.values:
self.counts[key] = 1
self.values[key] = x
else:
self.counts[key] += 1
self.values[key] += x
def add_scalar_with_count(self, key: str, x: Union[torch.Tensor, int, float], count: int):
if isinstance(x, torch.Tensor):
x = x.detach()
if x.dtype in [torch.long, torch.int, torch.bool]:
x = x.float()
if key not in self.values:
self.counts[key] = count
self.values[key] = x
else:
self.counts[key] += count
self.values[key] += x
def add_dict(self, tensor_dict: dict[str, torch.Tensor]):
for k, v in tensor_dict.items():
self.add_scalar(k, v)
def add_dict_with_count(self, tensor_dict: dict[str, torch.Tensor], count: int):
for k, v in tensor_dict.items():
self.add_scalar_with_count(k, v, count)
def add_binned_tensor(self, key: str, x: torch.Tensor, indices: torch.Tensor):
if key not in self.binned_tensors:
self.binned_tensors[key] = [x.detach().flatten()]
self.binned_tensor_indices[key] = [indices.detach().flatten()]
else:
self.binned_tensors[key].append(x.detach().flatten())
self.binned_tensor_indices[key].append(indices.detach().flatten())
def add_hook(self, hook: Callable[[torch.Tensor], tuple[str, torch.Tensor]]):
"""
Adds a custom hook, i.e. compute new metrics using values in the dict
The hook takes the dict as argument, and returns a (k, v) tuple
e.g. for computing IoU
"""
self.hooks.append(hook)
def reset_except_hooks(self):
self.values = {}
self.counts = {}
# Average and output the metrics
def finalize(self, prefix: str, it: int, ignore_timer: bool = False) -> None:
for hook in self.hooks:
k, v = hook(self.values)
self.add_scalar(k, v)
# for the metrics
outputs = {}
for k, v in self.values.items():
avg = v / self.counts[k]
if self.distributed:
# Inplace operation
if isinstance(avg, torch.Tensor):
avg = avg.cuda()
else:
avg = torch.tensor(avg).cuda()
torch.distributed.reduce(avg, dst=0)
if self.local_rank == 0:
avg = (avg / self.world_size).cpu().item()
outputs[k] = avg
else:
# Simple does it
outputs[k] = avg
if (not self.distributed) or (self.local_rank == 0):
self.logger.log_metrics(prefix, outputs, it, ignore_timer=ignore_timer)
# for the binned tensors
for k, v in self.binned_tensors.items():
x = torch.cat(v, dim=0)
indices = torch.cat(self.binned_tensor_indices[k], dim=0)
hist, count = distribute_into_histogram(x, indices)
if self.distributed:
torch.distributed.reduce(hist, dst=0)
torch.distributed.reduce(count, dst=0)
if self.local_rank == 0:
hist = hist / count
else:
hist = hist / count
if (not self.distributed) or (self.local_rank == 0):
self.logger.log_histogram(f'{prefix}/{k}', hist, it)
+231
View File
@@ -0,0 +1,231 @@
"""
Dumps things to tensorboard and console
"""
import datetime
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from PIL import Image
from pytz import timezone
from torch.utils.tensorboard import SummaryWriter
from selva_core.utils.email_utils import EmailSender
from selva_core.utils.time_estimator import PartialTimeEstimator, TimeEstimator
from selva_core.utils.timezone import my_timezone
def tensor_to_numpy(image: torch.Tensor):
image_np = (image.numpy() * 255).astype('uint8')
return image_np
def detach_to_cpu(x: torch.Tensor):
return x.detach().cpu()
def fix_width_trunc(x: float):
return ('{:.9s}'.format('{:0.9f}'.format(x)))
def plot_spectrogram(spectrogram: np.ndarray, title=None, ylabel="freq_bin", ax=None):
if ax is None:
_, ax = plt.subplots(1, 1)
if title is not None:
ax.set_title(title)
ax.set_ylabel(ylabel)
ax.imshow(spectrogram, origin="lower", aspect="auto", interpolation="nearest")
class TensorboardLogger:
def __init__(self,
exp_id: str,
run_dir: Union[Path, str],
py_logger: logging.Logger,
*,
is_rank0: bool = False,
enable_email: bool = False):
self.exp_id = exp_id
self.run_dir = Path(run_dir)
self.py_log = py_logger
self.email_sender = EmailSender(exp_id, enable=(is_rank0 and enable_email))
if is_rank0:
self.tb_log = SummaryWriter(run_dir)
else:
self.tb_log = None
# Get current git info for logging
try:
import git
repo = git.Repo(".")
git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
except (ImportError, RuntimeError, TypeError):
print('Failed to fetch git info. Defaulting to None')
git_info = 'None'
self.log_string('git', git_info)
# log the SLURM job id if available
job_id = os.environ.get('SLURM_JOB_ID', None)
if job_id is not None:
self.log_string('slurm_job_id', job_id)
self.email_sender.send(f'Job {job_id} started', f'Job started {run_dir}')
# used when logging metrics
self.batch_timer: TimeEstimator = None
self.data_timer: PartialTimeEstimator = None
self.nan_count = defaultdict(int)
def log_scalar(self, tag: str, x: float, it: int):
if self.tb_log is None:
return
if math.isnan(x) and 'grad_norm' not in tag:
self.nan_count[tag] += 1
if self.nan_count[tag] == 10:
self.email_sender.send(
f'Nan detected in {tag} @ {self.run_dir}',
f'Nan detected in {tag} at iteration {it}; run_dir: {self.run_dir}')
else:
self.nan_count[tag] = 0
self.tb_log.add_scalar(tag, x, it)
def log_metrics(self,
prefix: str,
metrics: dict[str, float],
it: int,
ignore_timer: bool = False):
msg = f'{self.exp_id}-{prefix} - it {it:6d}: '
metrics_msg = ''
for k, v in sorted(metrics.items()):
self.log_scalar(f'{prefix}/{k}', v, it)
metrics_msg += f'{k: >10}:{v:.7f},\t'
if self.batch_timer is not None and not ignore_timer:
self.batch_timer.update()
avg_time = self.batch_timer.get_and_reset_avg_time()
data_time = self.data_timer.get_and_reset_avg_time()
# add time to tensorboard
self.log_scalar(f'{prefix}/avg_time', avg_time, it)
self.log_scalar(f'{prefix}/data_time', data_time, it)
est = self.batch_timer.get_est_remaining(it)
est = datetime.timedelta(seconds=est)
if est.days > 0:
remaining_str = f'{est.days}d {est.seconds // 3600}h'
else:
remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m'
eta = datetime.datetime.now(timezone(my_timezone)) + est
eta_str = eta.strftime('%Y-%m-%d %H:%M:%S %Z%z')
time_msg = f'avg_time:{avg_time:.3f},data:{data_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
msg = f'{msg} {time_msg}'
msg = f'{msg} {metrics_msg}'
self.py_log.info(msg)
def log_histogram(self, tag: str, hist: torch.Tensor, it: int):
if self.tb_log is None:
return
# hist should be a 1D tensor
hist = hist.cpu().numpy()
fig, ax = plt.subplots()
x_range = np.linspace(0, 1, len(hist))
ax.bar(x_range, hist, width=1 / (len(hist) - 1))
ax.set_xticks(x_range)
ax.set_xticklabels(x_range)
plt.tight_layout()
self.tb_log.add_figure(tag, fig, it)
plt.close()
def log_image(self, prefix: str, tag: str, image: np.ndarray, it: int):
image_dir = self.run_dir / f'{prefix}_images'
image_dir.mkdir(exist_ok=True, parents=True)
image = Image.fromarray(image)
image.save(image_dir / f'{it:09d}_{tag}.png')
def log_audio(self,
prefix: str,
tag: str,
waveform: torch.Tensor,
it: Optional[int] = None,
*,
subdir: Optional[Path] = None,
sample_rate: int = 16000) -> Path:
if subdir is None:
audio_dir = self.run_dir / prefix
else:
audio_dir = self.run_dir / subdir / prefix
audio_dir.mkdir(exist_ok=True, parents=True)
if it is None:
name = f'{tag}.flac'
else:
name = f'{it:09d}_{tag}.flac'
torchaudio.save(audio_dir / name,
waveform.cpu().float(),
sample_rate=sample_rate,
channels_first=True)
return Path(audio_dir)
def log_spectrogram(
self,
prefix: str,
tag: str,
spec: torch.Tensor,
it: Optional[int],
*,
subdir: Optional[Path] = None,
):
if subdir is None:
spec_dir = self.run_dir / prefix
else:
spec_dir = self.run_dir / subdir / prefix
spec_dir.mkdir(exist_ok=True, parents=True)
if it is None:
name = f'{tag}.png'
else:
name = f'{it:09d}_{tag}.png'
plot_spectrogram(spec.cpu().float())
plt.tight_layout()
plt.savefig(spec_dir / name)
plt.close()
def log_string(self, tag: str, x: str):
self.py_log.info(f'{tag} - {x}')
if self.tb_log is None:
return
self.tb_log.add_text(tag, x)
def debug(self, x):
self.py_log.debug(x)
def info(self, x):
self.py_log.info(x)
def warning(self, x):
self.py_log.warning(x)
def error(self, x):
self.py_log.error(x)
def critical(self, x):
self.py_log.critical(x)
self.email_sender.send(f'Error occurred in {self.run_dir}', x)
def complete(self):
self.email_sender.send(f'Job completed in {self.run_dir}', 'Job completed')
+41
View File
@@ -0,0 +1,41 @@
import importlib
from typing import Optional
from omegaconf import DictConfig, OmegaConf
def instantiate_from_config(target: str, params: Optional[dict] = None):
"""
Instantiate an object from a dotted path `target` and keyword `params`.
Common name: instantiate_from_config
"""
if not target or not isinstance(target, str):
raise ValueError(f"Invalid target: {target!r}")
params = {} if params is None else params
# Convert OmegaConf DictConfig to plain dict if needed
try:
if isinstance(params, DictConfig):
params = OmegaConf.to_container(params, resolve=True)
except Exception:
pass
try:
module_path, attr_name = target.rsplit('.', 1)
except ValueError as e:
raise ValueError(f"Target must be like 'pkg.mod.Class', got {target!r}") from e
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(f"Could not import module '{module_path}' for target '{target}'.") from e
try:
obj = getattr(module, attr_name)
except AttributeError as e:
raise AttributeError(f"Module '{module_path}' has no attribute '{attr_name}' (from '{target}').") from e
if not callable(obj):
raise TypeError(f"Resolved target '{target}' is not callable.")
return obj(**dict(params))
+22
View File
@@ -0,0 +1,22 @@
from typing import Optional
from nitrous_ema import PostHocEMA
from omegaconf import DictConfig
from selva_core.model.utils.factory import create_model_from_factory
def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]):
vae = create_model_from_factory(cfg.model.factory_path,
cfg.model.name,
**cfg.model.get('params', {})
)
emas = PostHocEMA(vae,
sigma_rels=cfg.ema.sigma_rels,
update_every=cfg.ema.update_every,
checkpoint_every_num_steps=cfg.ema.checkpoint_every,
checkpoint_folder=cfg.ema.checkpoint_folder)
synthesized_ema = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device='cpu')
state_dict = synthesized_ema.ema_model.state_dict()
return state_dict
+14
View File
@@ -0,0 +1,14 @@
import torch
def distribute_into_histogram(loss: torch.Tensor,
t: torch.Tensor,
num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
loss = loss.detach().flatten()
t = t.detach().flatten()
t = (t * num_bins).long()
hist = torch.zeros(num_bins, device=loss.device)
count = torch.zeros(num_bins, device=loss.device)
hist.scatter_add_(0, t, loss)
count.scatter_add_(0, t, torch.ones_like(loss))
return hist, count
+72
View File
@@ -0,0 +1,72 @@
import time
class TimeEstimator:
def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
self.avg_time_window = [] # window-based average
self.exp_avg_time = None # exponential moving average
self.alpha = ema_alpha # for exponential moving average
self.last_time = time.time() # would not be accurate for the first iteration but well
self.total_iter = total_iter
self.step_size = step_size
self._buffering_exp = True
# call this at a fixed interval
# does not have to be every step
def update(self):
curr_time = time.time()
time_per_iter = curr_time - self.last_time
self.last_time = curr_time
self.avg_time_window.append(time_per_iter)
if self._buffering_exp:
if self.exp_avg_time is not None:
# discard the first iteration call to not pollute the ema
self._buffering_exp = False
self.exp_avg_time = time_per_iter
else:
self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
def get_est_remaining(self, it: int):
if self.exp_avg_time is None:
return 0
remaining_iter = self.total_iter - it
return remaining_iter * self.exp_avg_time / self.step_size
def get_and_reset_avg_time(self):
avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size
self.avg_time_window = []
return avg
class PartialTimeEstimator(TimeEstimator):
"""
Used where the start_time and the end_time do not align
"""
def update(self):
raise RuntimeError('Please use start() and end() for PartialTimeEstimator')
def start(self):
self.last_time = time.time()
def end(self):
assert self.last_time is not None, 'Please call start() before calling end()'
curr_time = time.time()
time_per_iter = curr_time - self.last_time
self.last_time = None
self.avg_time_window.append(time_per_iter)
if self._buffering_exp:
if self.exp_avg_time is not None:
# discard the first iteration call to not pollute the ema
self._buffering_exp = False
self.exp_avg_time = time_per_iter
else:
self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
+1
View File
@@ -0,0 +1 @@
my_timezone = 'Asia/Seoul' # 'US/Central'
+140
View File
@@ -0,0 +1,140 @@
import gc
from typing import Optional, Union
import torch
from omegaconf import DictConfig
from selva_core.model.utils.features_utils import FeatureUtils
from selva_core.model.networks_video_enc import TextSynch
from selva_core.data.mixup import FeatureMixup
@torch.no_grad()
def preprocess_batch_with_tsynch(
batch: dict,
mixup_config: DictConfig,
feature_extractor: FeatureUtils,
net_video_enc: TextSynch,
feature_mixup: Optional[FeatureMixup] = None,
training: bool = False,
) -> None:
if mixup_config.domain == 'embedding' and feature_mixup is None:
raise ValueError('Mixup function is required for embedding domain.')
bs: int = len(batch['id'])
device = feature_extractor.device
dtype = feature_extractor.dtype
video_exist = batch.get('sync_video', None) is not None
text_exist = batch.get('caption', None) is not None
batch['video_exist'] = torch.tensor(video_exist, device=device, dtype=torch.bool, requires_grad=False)
batch['text_exist'] = torch.tensor(text_exist, device=device, dtype=torch.bool, requires_grad=False)
batch['a_mean'] = batch.get('a_mean').to(device, dtype)
batch['a_std'] = batch['a_std'].to(device, dtype)
batch['clip_features'] = None
batch['text_clip_features'] = batch['text_clip_features'].to(device, dtype)
if video_exist:
tsynch_text_features = batch['text_features'].to(device, dtype)
tsynch_text_mask = batch['text_masks'].to(device, dtype)
if mixup_config.enabled and mixup_config.domain == 'data':
tsynch_text_features, tsynch_text_mask = net_video_enc.prepend_silence_text_tokens(tsynch_text_features, tsynch_text_mask)
batch['sync_video_mixed'] = batch['sync_video_mixed'].to(device, dtype, non_blocking=True)
batch['sync_features'] = net_video_enc.encode_video_with_sync(
batch['sync_video_mixed'], text_f=tsynch_text_features, text_mask=tsynch_text_mask
)
elif mixup_config.enabled and mixup_config.domain == 'embedding':
assert feature_mixup.modality == mixup_config.params.modality, \
f"Mixup class modality {feature_mixup.modality} should be same as config {mixup_config.params.modality}."
feature_mixup.target_video_key = 'sync_features'
feature_mixup(batch)
else:
batch['sync_video'] = batch['sync_video'].to(device, dtype)
tsynch_text_features, tsynch_text_mask = net_video_enc.prepend_silence_text_tokens(tsynch_text_features, tsynch_text_mask)
batch['sync_features'] = net_video_enc.encode_video_with_sync(
batch['sync_video'], text_f=tsynch_text_features, text_mask=tsynch_text_mask
)
if training:
for k, v in batch.items():
if k in ['video_exist', 'text_exist', 'sync_features']:
batch[k] = v.clone()
gc.collect()
torch.cuda.empty_cache()
@torch.no_grad()
def preprocess_batch_with_mixup(
batch: dict,
mixup_config: DictConfig,
feature_extractor: FeatureUtils = None,
sync_batch_size_multiplier: int = 40,
feature_mixup: Optional[FeatureMixup] = None,
training: bool = False,
) -> None:
if mixup_config.domain == 'embedding' and feature_mixup is None:
raise ValueError('Mixup function is required for embedding domain.')
bs: int = len(batch['id'])
batch['text_features'], batch['text_mask'] = feature_extractor.encode_text(batch['caption'])
if mixup_config.params.modality in ['video', 'both']:
batch['sync_f_vid_orig'] = feature_extractor.encode_video_with_sync(batch['sync_video'],
batch_size=bs *
sync_batch_size_multiplier)
if mixup_config.params.modality in ['audio', 'both']:
batch['sync_f_aud_orig'] = feature_extractor.encode_audio_with_sync(batch['audio'],
batch_size=bs *
sync_batch_size_multiplier)
if mixup_config.domain == 'data':
if mixup_config.params.modality in ['video', 'both']:
batch['sync_f_vid_mixed'] = feature_extractor.encode_video_with_sync(batch['sync_video_mixed'],
batch_size=bs *
sync_batch_size_multiplier)
if mixup_config.params.modality in ['audio', 'both']:
batch['sync_f_aud_mixed'] = feature_extractor.encode_audio_with_sync(batch['audio_mixed'],
batch_size=bs *
sync_batch_size_multiplier)
if mixup_config.domain == 'embedding':
assert feature_mixup.modality == mixup_config.params.modality, \
f"Mixup class modality {feature_mixup.modality} should be same as config {mixup_config.params.modality}."
feature_mixup(batch)
if training:
for k, v in batch.items():
if k in ['text_features', 'text_mask',
'sync_f_vid_orig', 'sync_f_aud_orig', 'sync_f_vid_mixed', 'sync_f_aud_mixed']:
batch[k] = v.clone()
gc.collect()
torch.cuda.empty_cache()
@torch.no_grad()
def preprocess_batch(
batch: dict,
mixup_config: DictConfig,
feature_extractor: FeatureUtils = None,
sync_batch_size_multiplier: Union[int, float] = 40,
training: bool = False,
) -> None:
bs: int = len(batch['id'])
batch['text_features'], batch['text_mask'] = feature_extractor.encode_text(batch['caption'])
sync_batch_size = int(bs * sync_batch_size_multiplier) if sync_batch_size_multiplier > 0 else bs
if mixup_config.params.modality in ['video', 'both']:
batch['sync_f_vid_orig'] = feature_extractor.encode_video_with_sync(batch['sync_video'],
batch_size=sync_batch_size)
if mixup_config.params.modality in ['audio', 'both']:
batch['sync_f_aud_orig'] = feature_extractor.encode_audio_with_sync(batch['audio'],
batch_size=sync_batch_size)
if training:
for k, v in batch.items():
if k in ['text_features', 'text_mask',
'sync_f_vid_orig', 'sync_f_aud_orig']:
batch[k] = v.clone()
gc.collect()
torch.cuda.empty_cache()
+20
View File
@@ -0,0 +1,20 @@
import torch
def generate_multiple_segments(
x: torch.Tensor,
segment_size: int,
step_size: int,
) -> torch.Tensor:
# x: (B, T, ...)
b, t, *rest = x.shape
assert t >= segment_size, f'The length of the input tensor {t} is less than the segment size {segment_size}.'
assert segment_size > step_size, f'The segment size {segment_size} should be greater than the step size {step_size}.'
# partition the tensor into segments
num_segments = (t - segment_size) // step_size + 1
segments = []
for i in range(num_segments):
segments.append(x[:, i * step_size:i * step_size + segment_size])
x = torch.stack(segments, dim=1)
return x # (B, S, T, ...)
+66
View File
@@ -0,0 +1,66 @@
from pathlib import Path
from typing import Union
import torch
from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
class VideoJoiner:
def __init__(self, src_root: Union[str, Path], output_root: Union[str, Path], sample_rate: int,
duration_seconds: float):
self.src_root = Path(src_root)
self.output_root = Path(output_root)
self.sample_rate = sample_rate
self.duration_seconds = duration_seconds
self.output_root.mkdir(parents=True, exist_ok=True)
def join(self, video_id: str, output_name: str, audio: torch.Tensor):
video_path = self.src_root / f'{video_id}.mp4'
output_path = self.output_root / f'{output_name}.mp4'
merge_audio_into_video(video_path, output_path, audio, self.sample_rate,
self.duration_seconds)
def merge_audio_into_video(video_path: Union[str, Path], output_path: Union[str, Path],
audio: torch.Tensor, sample_rate: int, duration_seconds: float):
# audio: (num_samples, num_channels=1/2)
frame_rate = 24
# read the video
reader = StreamingMediaDecoder(video_path)
reader.add_basic_video_stream(
frames_per_chunk=int(frame_rate * duration_seconds),
# buffer_chunk_size=1, # does not work with this -- extracted audio would be too short
format="rgb24",
frame_rate=frame_rate,
)
reader.fill_buffer()
video_chunk = reader.pop_chunks()[0]
t, _, h, w = video_chunk.shape
writer = StreamingMediaEncoder(output_path)
writer.add_audio_stream(
sample_rate=sample_rate,
num_channels=audio.shape[-1],
encoder='aac', # "libmp3lame",
)
writer.add_video_stream(frame_rate=frame_rate,
width=w,
height=h,
format="rgb24",
encoder="libx264",
encoder_format="yuv420p")
with writer.open():
writer.write_video_chunk(1, video_chunk)
writer.write_audio_chunk(0, audio.float())
if __name__ == '__main__':
# Usage example
import sys
audio = torch.randn(16000 * 4, 1)
merge_audio_into_video(sys.argv[1], sys.argv[2], audio, 16000, 4)