"""Inference helpers: model configurations, audio generation, logging setup,
and video/image loading."""
import dataclasses
import logging
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from colorlog import ColoredFormatter
from PIL import Image
from torchvision.transforms import v2

from selva_core.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.networks_video_enc import TextSynch
from selva_core.model.networks_generator import MMAudio
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.utils.download_utils import download_model_if_needed

log = logging.getLogger()


@dataclasses.dataclass
class ModelConfig:
    """Weight paths and sampling mode for one pretrained model variant."""
    model_name: str
    model_video_enc_path: Path  # video-encoder weights (see download_video_enc_if_needed)
    model_generator_path: Path  # generator weights (see download_generator_if_needed)
    mode: str  # '16k' or '44k'; selects the SequenceConfig returned by seq_cfg
    vae_path: Path
    bigvgan_16k_path: Optional[Path]  # None for the 44k configs below; skipped when downloading
    synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')

    @property
    def seq_cfg(self) -> SequenceConfig:
        """Return the sequence configuration matching ``mode``.

        Raises:
            ValueError: if ``mode`` is neither '16k' nor '44k'.
        """
        if self.mode == '16k':
            return CONFIG_16K
        if self.mode == '44k':
            return CONFIG_44K
        raise ValueError(f"Unknown mode {self.mode!r}; expected '16k' or '44k'")

    def download_if_needed(self):
        """Fetch every weight file this config references (optional BigVGAN only if set)."""
        download_model_if_needed(self.model_video_enc_path)
        download_model_if_needed(self.model_generator_path)
        download_model_if_needed(self.vae_path)
        if self.bigvgan_16k_path is not None:
            download_model_if_needed(self.bigvgan_16k_path)
        download_model_if_needed(self.synchformer_ckpt)

    def download_video_enc_if_needed(self):
        """Fetch only the video-encoder weights."""
        download_model_if_needed(self.model_video_enc_path)

    def download_generator_if_needed(self):
        """Fetch only the generator weights."""
        download_model_if_needed(self.model_generator_path)

    def download_external_modules_if_needed(self):
        """Fetch the Synchformer, VAE and (if configured) BigVGAN weights."""
        download_model_if_needed(self.synchformer_ckpt)
        download_model_if_needed(self.vae_path)
        if self.bigvgan_16k_path is not None:
            download_model_if_needed(self.bigvgan_16k_path)


small_16k = ModelConfig(model_name='small_16k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_small_16k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-16.pth'),
                        bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
                        mode='16k')
small_44k = ModelConfig(model_name='small_44k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_small_44k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
medium_44k = ModelConfig(model_name='medium_44k',
                         model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                         model_generator_path=Path('./weights/generator_medium_44k_sup_5.pth'),
                         vae_path=Path('./ext_weights/v1-44.pth'),
                         bigvgan_16k_path=None,
                         mode='44k')
large_44k = ModelConfig(model_name='large_44k',
                        model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
                        model_generator_path=Path('./weights/generator_large_44k_sup_5.pth'),
                        vae_path=Path('./ext_weights/v1-44.pth'),
                        bigvgan_16k_path=None,
                        mode='44k')
all_model_cfg: dict[str, ModelConfig] = {
    'small_16k': small_16k,
    'small_44k': small_44k,
    'medium_44k': medium_44k,
    'large_44k': large_44k,
}


def generate(
    clip_video: Optional[torch.Tensor],
    sync_video: Optional[torch.Tensor],
    text: Optional[list[str]],
    *,
    negative_text: Optional[list[str]] = None,
    feature_utils: FeaturesUtils,
    net_video_enc: TextSynch,
    net_generator: MMAudio,
    fm: FlowMatching,
    rng: torch.Generator,
    cfg_strength: float,
    clip_batch_size_multiplier: int = 40,
    sync_batch_size_multiplier: int = 40,
    image_input: bool = False,
) -> torch.Tensor:
    """Generate audio conditioned on optional CLIP frames, sync frames and text.

    Any missing condition is replaced by the networks' "empty" sequences.

    Args:
        clip_video: frames for the CLIP encoder, or None.
        sync_video: frames for the sync encoder, or None. Ignored when ``image_input`` is set.
        text: one prompt per batch item, or None.
        negative_text: optional prompts used to build the unconditional (CFG) branch;
            must have one entry per batch item.
        feature_utils: text/video encoders plus decode/vocode; also supplies device and dtype.
        net_video_enc: video encoder producing sync features.
        net_generator: generator network providing conditioning and the ODE wrapper.
        fm: flow-matching sampler that integrates from noise to data.
        rng: generator used to draw the initial noise.
        cfg_strength: classifier-free guidance strength passed to ``ode_wrapper``.
        clip_batch_size_multiplier: CLIP encoding runs with ``batch_size=bs * multiplier``.
        sync_batch_size_multiplier: currently unused; kept for API compatibility.
        image_input: treat ``clip_video`` as a single image. Its CLIP features are expanded
            along the sequence axis, and the sync features are left empty.

    Returns:
        The waveform returned by ``feature_utils.vocode``.

    Raises:
        ValueError: if the batch size cannot be inferred, or ``negative_text`` has the wrong length.
    """
    device = feature_utils.device
    dtype = feature_utils.dtype

    # Infer the batch size without requiring text, so the empty-text branch below is reachable.
    if text is not None:
        bs = len(text)
    elif negative_text is not None:
        bs = len(negative_text)
    elif clip_video is not None:
        bs = clip_video.shape[0]  # assumes batch-first video tensors -- TODO confirm
    elif sync_video is not None:
        bs = sync_video.shape[0]  # assumes batch-first video tensors -- TODO confirm
    else:
        raise ValueError(
            'generate() needs text, negative_text, clip_video or sync_video to infer the batch size')

    if text is not None:
        text_features_clip = feature_utils.encode_text_clip(text)
        text_features_flant5, text_mask_flant5 = feature_utils.encode_text_t5(text)
    else:
        text_features_clip = net_generator.get_empty_string_sequence(bs)
        text_features_flant5 = net_video_enc.get_empty_string_sequence(bs)
        # Mark only the first position as valid for the empty prompt.
        # NOTE(review): zeros_like copies the full feature shape; confirm it matches the mask
        # layout returned by encode_text_t5.
        text_mask_flant5 = torch.zeros_like(text_features_flant5)
        text_mask_flant5[:, 0] = 1

    # Only the CLIP embedding of the negative prompt feeds the unconditional branch.
    if negative_text is not None:
        if len(negative_text) != bs:
            raise ValueError(
                f'negative_text has {len(negative_text)} entries but the batch size is {bs}')
        negative_text_features_clip = feature_utils.encode_text_clip(negative_text)
    else:
        negative_text_features_clip = None

    if clip_video is not None:
        clip_video = clip_video.to(device, dtype, non_blocking=True)
        clip_features = feature_utils.encode_video_with_clip(
            clip_video, batch_size=bs * clip_batch_size_multiplier)
        if image_input:
            # A single image yields one feature step; repeat it over the expected sequence length.
            clip_features = clip_features.expand(-1, net_generator.clip_seq_len, -1)
    else:
        clip_features = net_generator.get_empty_clip_sequence(bs)

    if sync_video is not None and not image_input:
        text_features_flant5, text_mask_flant5 = net_video_enc.prepend_sup_text_tokens(
            text_features_flant5, text_mask_flant5)
        sync_video = sync_video.to(net_video_enc.device, net_video_enc.dtype, non_blocking=True)
        sync_features = net_video_enc.encode_video_with_sync(
            sync_video, text_f=text_features_flant5, text_mask=text_mask_flant5)
    else:
        sync_features = net_generator.get_empty_sync_sequence(bs)

    x0 = torch.randn(bs,
                     net_generator.latent_seq_len,
                     net_generator.latent_dim,
                     device=device,
                     dtype=dtype,
                     generator=rng)
    preprocessed_conditions = net_generator.preprocess_conditions(
        clip_features, sync_features, text_features_clip)
    empty_conditions = net_generator.get_empty_conditions(
        bs, negative_text_features=negative_text_features_clip)

    def cfg_ode_wrapper(t, x):
        return net_generator.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
                                         cfg_strength)

    x1 = fm.to_data(cfg_ode_wrapper, x0)
    x1 = net_generator.unnormalize(x1)
    spec = feature_utils.decode(x1)
    audio = feature_utils.vocode(spec)
    return audio


LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"


def setup_eval_logging(log_level: int = logging.INFO):
    """Set the root logger level and attach a colored stream handler.

    Each call attaches another handler, so call it once per process.
    """
    logging.root.setLevel(log_level)
    formatter = ColoredFormatter(LOGFORMAT)
    stream = logging.StreamHandler()
    stream.setLevel(log_level)
    stream.setFormatter(formatter)
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.addHandler(stream)


_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
    """Read the first ``duration_sec`` seconds of a video as CLIP and sync frame tensors.

    CLIP frames are sampled at 8 fps, resized to 384x384 and scaled to [0, 1]. Sync frames are
    sampled at 25 fps, resized to 224x224 and normalized to [-1, 1]. If either stream is shorter
    than requested, the duration is truncated to the shorter one, with a warning.

    Args:
        video_path: video file to read.
        duration_sec: requested duration from the start of the video.
        load_all_frames: also keep the frames returned by ``read_frames`` for re-encoding.

    Returns:
        VideoInfo with the (possibly truncated) duration, original fps and frame tensors.
    """
    clip_transform = v2.Compose([
        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])

    # Video frames are resized straight to a square (no center crop), unlike load_image.
    sync_transform = v2.Compose([
        v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    output_frames, all_frames, orig_fps = read_frames(video_path,
                                                      list_of_fps=[_CLIP_FPS, _SYNC_FPS],
                                                      start_sec=0,
                                                      end_sec=duration_sec,
                                                      need_all_frames=load_all_frames)

    clip_chunk, sync_chunk = output_frames
    # Move the last axis to position 1 (presumably (T, H, W, C) -> (T, C, H, W)).
    clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
    sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)

    clip_frames = clip_transform(clip_chunk)
    sync_frames = sync_transform(sync_chunk)

    clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
    sync_length_sec = sync_frames.shape[0] / _SYNC_FPS

    if clip_length_sec < duration_sec:
        log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {clip_length_sec:.2f} sec')
        duration_sec = clip_length_sec

    if sync_length_sec < duration_sec:
        log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
        log.warning(f'Truncating to {sync_length_sec:.2f} sec')
        duration_sec = sync_length_sec

    clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
    sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]

    video_info = VideoInfo(
        duration_sec=duration_sec,
        fps=orig_fps,
        clip_frames=clip_frames,
        sync_frames=sync_frames,
        all_frames=all_frames if load_all_frames else None,
    )
    return video_info


def load_image(image_path: Path) -> ImageInfo:
    """Load a still image as single-frame CLIP and sync tensors.

    The image is converted to RGB, so grayscale, palette and RGBA files also produce 3 channels.
    CLIP frames are resized to 384x384 and scaled to [0, 1]. Sync frames are resized on the short
    side to 224, center-cropped to 224x224, and normalized to [-1, 1].

    Returns:
        ImageInfo holding both tensors and the RGB frame as a numpy array.
    """
    clip_transform = v2.Compose([
        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])

    sync_transform = v2.Compose([
        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
        v2.CenterCrop(_SYNC_SIZE),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Close the file handle, and force 3 channels so permute/Normalize below always apply.
    with Image.open(image_path) as img:
        frame = np.array(img.convert('RGB'))

    # (H, W, 3) -> (1, 3, H, W)
    clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
    sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)

    clip_frames = clip_transform(clip_chunk)
    sync_frames = sync_transform(sync_chunk)

    video_info = ImageInfo(
        clip_frames=clip_frames,
        sync_frames=sync_frames,
        original_frame=frame,
    )
    return video_info


def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
    """Write ``audio`` into a re-encoded copy of the video described by ``video_info``."""
    reencode_with_audio(video_info, output_path, audio, sampling_rate)