chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,277 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colorlog import ColoredFormatter
|
||||
from PIL import Image
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from selva_core.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
from selva_core.model.networks_video_enc import TextSynch
|
||||
from selva_core.model.networks_generator import MMAudio
|
||||
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
|
||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||
from selva_core.utils.download_utils import download_model_if_needed
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelConfig:
|
||||
model_name: str
|
||||
model_video_enc_path: Path
|
||||
model_generator_path: Path
|
||||
mode: str
|
||||
vae_path: Path
|
||||
bigvgan_16k_path: Optional[Path]
|
||||
synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
|
||||
|
||||
@property
|
||||
def seq_cfg(self) -> SequenceConfig:
|
||||
if self.mode == '16k':
|
||||
return CONFIG_16K
|
||||
elif self.mode == '44k':
|
||||
return CONFIG_44K
|
||||
|
||||
def download_if_needed(self):
|
||||
download_model_if_needed(self.model_video_enc_path)
|
||||
download_model_if_needed(self.model_generator_path)
|
||||
download_model_if_needed(self.vae_path)
|
||||
if self.bigvgan_16k_path is not None:
|
||||
download_model_if_needed(self.bigvgan_16k_path)
|
||||
download_model_if_needed(self.synchformer_ckpt)
|
||||
|
||||
def download_video_enc_if_needed(self):
|
||||
download_model_if_needed(self.model_video_enc_path)
|
||||
|
||||
def download_generator_if_needed(self):
|
||||
download_model_if_needed(self.model_generator_path)
|
||||
|
||||
def download_external_modules_if_needed(self):
|
||||
download_model_if_needed(self.synchformer_ckpt)
|
||||
download_model_if_needed(self.vae_path)
|
||||
if self.bigvgan_16k_path is not None:
|
||||
download_model_if_needed(self.bigvgan_16k_path)
|
||||
|
||||
|
||||
small_16k = ModelConfig(model_name='small_16k',
|
||||
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
|
||||
model_generator_path=Path('./weights/generator_small_16k_sup_5.pth'),
|
||||
vae_path=Path('./ext_weights/v1-16.pth'),
|
||||
bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
|
||||
mode='16k')
|
||||
small_44k = ModelConfig(model_name='small_44k',
|
||||
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
|
||||
model_generator_path=Path('./weights/generator_small_44k_sup_5.pth'),
|
||||
vae_path=Path('./ext_weights/v1-44.pth'),
|
||||
bigvgan_16k_path=None,
|
||||
mode='44k')
|
||||
medium_44k = ModelConfig(model_name='medium_44k',
|
||||
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
|
||||
model_generator_path=Path('./weights/generator_medium_44k_sup_5.pth'),
|
||||
vae_path=Path('./ext_weights/v1-44.pth'),
|
||||
bigvgan_16k_path=None,
|
||||
mode='44k')
|
||||
large_44k = ModelConfig(model_name='large_44k',
|
||||
model_video_enc_path=Path('./weights/video_enc_sup_5.pth'),
|
||||
model_generator_path=Path('./weights/generator_large_44k_sup_5.pth'),
|
||||
vae_path=Path('./ext_weights/v1-44.pth'),
|
||||
bigvgan_16k_path=None,
|
||||
mode='44k')
|
||||
all_model_cfg: dict[str, ModelConfig] = {
|
||||
'small_16k': small_16k,
|
||||
'small_44k': small_44k,
|
||||
'medium_44k': medium_44k,
|
||||
'large_44k': large_44k,
|
||||
}
|
||||
|
||||
|
||||
def generate(
|
||||
clip_video: Optional[torch.Tensor],
|
||||
sync_video: Optional[torch.Tensor],
|
||||
text: Optional[list[str]],
|
||||
*,
|
||||
negative_text: Optional[list[str]] = None,
|
||||
feature_utils: FeaturesUtils,
|
||||
net_video_enc: TextSynch,
|
||||
net_generator: MMAudio,
|
||||
fm: FlowMatching,
|
||||
rng: torch.Generator,
|
||||
cfg_strength: float,
|
||||
clip_batch_size_multiplier: int = 40,
|
||||
sync_batch_size_multiplier: int = 40,
|
||||
image_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
device = feature_utils.device
|
||||
dtype = feature_utils.dtype
|
||||
|
||||
bs = len(text)
|
||||
if text is not None:
|
||||
text_features_clip = feature_utils.encode_text_clip(text)
|
||||
text_features_flant5, text_mask_flant5 = feature_utils.encode_text_t5(text)
|
||||
else:
|
||||
text_features_clip = net_generator.get_empty_string_sequence(bs)
|
||||
text_features_flant5 = net_video_enc.get_empty_string_sequence(bs)
|
||||
text_mask_flant5 = torch.zeros_like(text_features_flant5)
|
||||
text_mask_flant5[:, 0] = 1
|
||||
|
||||
if negative_text is not None:
|
||||
assert len(negative_text) == bs
|
||||
negative_text_features_clip = feature_utils.encode_text_clip(negative_text)
|
||||
negative_text_features_flant5, negative_text_mask_flant5 = feature_utils.encode_text_t5(negative_text)
|
||||
else:
|
||||
negative_text_features_clip = None
|
||||
negative_text_features_flant5, negative_text_mask_flant5 = None, None
|
||||
|
||||
if clip_video is not None:
|
||||
clip_video = clip_video.to(device, dtype, non_blocking=True)
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
||||
batch_size=bs *
|
||||
clip_batch_size_multiplier)
|
||||
if image_input:
|
||||
clip_features = clip_features.expand(-1, net_generator.clip_seq_len, -1)
|
||||
else:
|
||||
clip_features = net_generator.get_empty_clip_sequence(bs)
|
||||
|
||||
if sync_video is not None and not image_input:
|
||||
text_features_flant5, text_mask_flant5 = net_video_enc.prepend_sup_text_tokens(text_features_flant5, text_mask_flant5)
|
||||
sync_video = sync_video.to(net_video_enc.device, net_video_enc.dtype, non_blocking=True)
|
||||
sync_features = net_video_enc.encode_video_with_sync(
|
||||
sync_video, text_f=text_features_flant5, text_mask=text_mask_flant5
|
||||
)
|
||||
else:
|
||||
sync_features = net_generator.get_empty_sync_sequence(bs)
|
||||
|
||||
x0 = torch.randn(bs,
|
||||
net_generator.latent_seq_len,
|
||||
net_generator.latent_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=rng)
|
||||
preprocessed_conditions = net_generator.preprocess_conditions(clip_features, sync_features, text_features_clip)
|
||||
empty_conditions = net_generator.get_empty_conditions(
|
||||
bs, negative_text_features=negative_text_features_clip
|
||||
)
|
||||
|
||||
cfg_ode_wrapper = lambda t, x: net_generator.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
|
||||
cfg_strength)
|
||||
x1 = fm.to_data(cfg_ode_wrapper, x0)
|
||||
x1 = net_generator.unnormalize(x1)
|
||||
spec = feature_utils.decode(x1)
|
||||
audio = feature_utils.vocode(spec)
|
||||
return audio
|
||||
|
||||
|
||||
LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
|
||||
|
||||
|
||||
def setup_eval_logging(log_level: int = logging.INFO):
|
||||
logging.root.setLevel(log_level)
|
||||
formatter = ColoredFormatter(LOGFORMAT)
|
||||
stream = logging.StreamHandler()
|
||||
stream.setLevel(log_level)
|
||||
stream.setFormatter(formatter)
|
||||
log = logging.getLogger()
|
||||
log.setLevel(log_level)
|
||||
log.addHandler(stream)
|
||||
|
||||
|
||||
_CLIP_SIZE = 384
|
||||
_CLIP_FPS = 8.0
|
||||
|
||||
_SYNC_SIZE = 224
|
||||
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
|
||||
|
||||
clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
|
||||
sync_transform = v2.Compose([
|
||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
# v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
output_frames, all_frames, orig_fps = read_frames(video_path,
|
||||
list_of_fps=[_CLIP_FPS, _SYNC_FPS],
|
||||
start_sec=0,
|
||||
end_sec=duration_sec,
|
||||
need_all_frames=load_all_frames)
|
||||
|
||||
clip_chunk, sync_chunk = output_frames
|
||||
clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
|
||||
sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
|
||||
|
||||
clip_frames = clip_transform(clip_chunk)
|
||||
sync_frames = sync_transform(sync_chunk)
|
||||
|
||||
clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
|
||||
sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
|
||||
|
||||
if clip_length_sec < duration_sec:
|
||||
log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
|
||||
log.warning(f'Truncating to {clip_length_sec:.2f} sec')
|
||||
duration_sec = clip_length_sec
|
||||
|
||||
if sync_length_sec < duration_sec:
|
||||
log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
|
||||
log.warning(f'Truncating to {sync_length_sec:.2f} sec')
|
||||
duration_sec = sync_length_sec
|
||||
|
||||
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
|
||||
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
|
||||
|
||||
video_info = VideoInfo(
|
||||
duration_sec=duration_sec,
|
||||
fps=orig_fps,
|
||||
clip_frames=clip_frames,
|
||||
sync_frames=sync_frames,
|
||||
all_frames=all_frames if load_all_frames else None,
|
||||
)
|
||||
return video_info
|
||||
|
||||
|
||||
def load_image(image_path: Path) -> VideoInfo:
|
||||
clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
|
||||
sync_transform = v2.Compose([
|
||||
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
frame = np.array(Image.open(image_path))
|
||||
|
||||
clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
||||
sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
||||
|
||||
clip_frames = clip_transform(clip_chunk)
|
||||
sync_frames = sync_transform(sync_chunk)
|
||||
|
||||
video_info = ImageInfo(
|
||||
clip_frames=clip_frames,
|
||||
sync_frames=sync_frames,
|
||||
original_frame=frame,
|
||||
)
|
||||
return video_info
|
||||
|
||||
|
||||
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
||||
reencode_with_audio(video_info, output_path, audio, sampling_rate)
|
||||
Reference in New Issue
Block a user