chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self, rng: Optional[torch.Generator] = None):
|
||||
# x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
|
||||
r = torch.empty_like(self.mean).normal_(generator=rng)
|
||||
x = self.mean + self.std * r
|
||||
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
|
||||
return 0.5 * torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
|
||||
else:
|
||||
return 0.5 * (torch.pow(self.mean - other.mean, 2) / other.var +
|
||||
self.var / other.var - 1.0 - self.logvar + other.logvar)
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
from selva_core.utils.misc import instantiate_from_config
|
||||
from selva_core.model.networks_video_enc import TextSynch as TextSynchVideoEnc
|
||||
from selva_core.model.networks_generator import MMAudio
|
||||
|
||||
|
||||
_MODEL_ZOO = (TextSynchVideoEnc, MMAudio)
|
||||
|
||||
|
||||
def create_model_from_factory(factory_path: str, name: str, **kwargs) -> torch.nn.Module:
|
||||
"""
|
||||
Dynamically imports and calls a model factory function.
|
||||
"""
|
||||
params = {'name': name, **kwargs}
|
||||
model = instantiate_from_config(factory_path, params)
|
||||
assert isinstance(model, _MODEL_ZOO), f"Model {type(model)} is not a valid model type."
|
||||
return model
|
||||
@@ -0,0 +1,192 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from open_clip import create_model_from_pretrained
|
||||
from torchvision.transforms import Normalize
|
||||
from transformers import T5TokenizerFast, T5EncoderModel
|
||||
|
||||
from selva_core.ext.autoencoder import AutoEncoderModule
|
||||
from selva_core.ext.mel_converter import get_mel_converter
|
||||
from selva_core.ext.synchformer import Synchformer
|
||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||
from selva_core.utils.transforms import generate_multiple_segments
|
||||
|
||||
|
||||
def patch_clip(clip_model):
|
||||
# a hack to make it output last hidden states
|
||||
# https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
|
||||
def new_encode_text(self, text, normalize: bool = False):
|
||||
cast_dtype = self.transformer.get_cast_dtype()
|
||||
|
||||
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.to(cast_dtype)
|
||||
x = self.transformer(x, attn_mask=self.attn_mask)
|
||||
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
||||
return F.normalize(x, dim=-1) if normalize else x
|
||||
|
||||
clip_model.encode_text = new_encode_text.__get__(clip_model)
|
||||
return clip_model
|
||||
|
||||
|
||||
class FeaturesUtils(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tod_vae_ckpt: Optional[str] = None,
|
||||
bigvgan_vocoder_ckpt: Optional[str] = None,
|
||||
synchformer_ckpt: Optional[str] = None,
|
||||
enable_conditions: bool = True,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if enable_conditions:
|
||||
self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
|
||||
return_transform=False)
|
||||
self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711])
|
||||
self.clip_model = patch_clip(self.clip_model)
|
||||
self.tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
|
||||
|
||||
self.synchformer = Synchformer(video=True, audio=False)
|
||||
self.synchformer.load_state_dict(
|
||||
torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
|
||||
|
||||
self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base')
|
||||
self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base')
|
||||
else:
|
||||
self.clip_model = None
|
||||
self.synchformer = None
|
||||
self.tokenizer_clip = None
|
||||
self.text_encoder_t5 = None
|
||||
self.tokenizer_t5 = None
|
||||
|
||||
if tod_vae_ckpt is not None:
|
||||
self.mel_converter = get_mel_converter(mode)
|
||||
self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
|
||||
vocoder_ckpt_path=bigvgan_vocoder_ckpt,
|
||||
mode=mode,
|
||||
need_vae_encoder=need_vae_encoder)
|
||||
else:
|
||||
self.tod = None
|
||||
|
||||
def compile(self):
|
||||
if self.clip_model is not None:
|
||||
self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
|
||||
self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
|
||||
if self.synchformer is not None:
|
||||
self.synchformer = torch.compile(self.synchformer)
|
||||
self.synchformer.forward_vfeat = torch.compile(self.synchformer.forward_vfeat)
|
||||
if self.text_encoder_t5 is not None:
|
||||
self.text_encoder_t5.forward = torch.compile(self.text_encoder_t5.forward)
|
||||
self.decode = torch.compile(self.decode)
|
||||
self.vocode = torch.compile(self.vocode)
|
||||
|
||||
def train(self, mode: bool) -> None:
|
||||
return super().train(False)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
||||
assert self.clip_model is not None, 'CLIP is not loaded'
|
||||
# x: (B, T, C, H, W) H/W: 384
|
||||
b, t, c, h, w = x.shape
|
||||
assert c == 3 and h == 384 and w == 384
|
||||
x = self.clip_preprocess(x)
|
||||
x = rearrange(x, 'b t c h w -> (b t) c h w')
|
||||
outputs = []
|
||||
if batch_size < 0:
|
||||
batch_size = b * t
|
||||
for i in range(0, b * t, batch_size):
|
||||
outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
|
||||
x = torch.cat(outputs, dim=0)
|
||||
# x = self.clip_model.encode_image(x, normalize=True)
|
||||
x = rearrange(x, '(b t) d -> b t d', b=b)
|
||||
return x
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
|
||||
assert self.synchformer is not None, 'Synchformer is not loaded'
|
||||
# x: (B, T, C, H, W) H/W: 384
|
||||
|
||||
b, t, c, h, w = x.shape
|
||||
assert c == 3 and h == 224 and w == 224
|
||||
|
||||
# partition the video
|
||||
segment_size = 16
|
||||
step_size = 8
|
||||
x = generate_multiple_segments(x, segment_size, step_size) # (B, S, T, C, H, W)
|
||||
num_segments = x.shape[1]
|
||||
|
||||
outputs = []
|
||||
if batch_size < 0:
|
||||
batch_size = b
|
||||
x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
|
||||
for i in range(0, b * num_segments, batch_size):
|
||||
outputs.append(self.synchformer.forward_vfeat(x[i:i + batch_size]))
|
||||
x = torch.cat(outputs, dim=0)
|
||||
x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
|
||||
return x
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_text_clip(self, text: list[str]) -> torch.Tensor:
|
||||
assert self.clip_model is not None, 'CLIP is not loaded'
|
||||
assert self.tokenizer_clip is not None, 'Tokenizer is not loaded'
|
||||
# x: (B, L)
|
||||
tokens = self.tokenizer_clip(text).to(self.device)
|
||||
return self.clip_model.encode_text(tokens, normalize=True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_text_t5(self, text: list[str]) -> torch.Tensor:
|
||||
device = self.text_encoder_t5.device
|
||||
batch = self.tokenizer_t5(
|
||||
text,
|
||||
max_length=self.tokenizer_t5.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
|
||||
device
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.text_encoder_t5(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state # (B, L, D)
|
||||
|
||||
boolean_encoder_mask = (attention_mask == 1).to(device) # (B, L)
|
||||
|
||||
return encoder_hidden_states, boolean_encoder_mask
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
# x: (B * L)
|
||||
mel = self.mel_converter(x)
|
||||
dist = self.tod.encode(mel)
|
||||
|
||||
return dist
|
||||
|
||||
@torch.inference_mode()
|
||||
def vocode(self, mel: torch.Tensor) -> torch.Tensor:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
return self.tod.vocode(mel)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
assert self.tod is not None, 'VAE is not loaded'
|
||||
return self.tod.decode(z.transpose(1, 2))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
def get_parameter_groups(model, cfg, print_log=False):
|
||||
"""
|
||||
Assign different weight decays and learning rates to different parameters.
|
||||
Returns a parameter group which can be passed to the optimizer.
|
||||
"""
|
||||
weight_decay = cfg.weight_decay
|
||||
base_lr = cfg.learning_rate
|
||||
|
||||
params = []
|
||||
|
||||
# inspired by detectron2
|
||||
memo = set()
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
# Avoid duplicating parameters
|
||||
if param in memo:
|
||||
continue
|
||||
memo.add(param)
|
||||
|
||||
if name.startswith('module'):
|
||||
name = name[7:]
|
||||
|
||||
params.append(param)
|
||||
|
||||
parameter_groups = [
|
||||
{
|
||||
'params': params,
|
||||
'lr': base_lr,
|
||||
'weight_decay': weight_decay
|
||||
},
|
||||
]
|
||||
|
||||
return parameter_groups
|
||||
@@ -0,0 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def log_normal_sample(x: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
m: float = 0.0,
|
||||
s: float = 1.0) -> torch.Tensor:
|
||||
bs = x.shape[0]
|
||||
s = torch.randn(bs, device=x.device, generator=generator) * s + m
|
||||
return torch.sigmoid(s)
|
||||
Reference in New Issue
Block a user