chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from selva_core.ext.autoencoder.vae import VAE, get_my_vae
|
||||
from selva_core.ext.bigvgan import BigVGAN
|
||||
from selva_core.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class AutoEncoderModule(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vae_ckpt_path,
|
||||
vocoder_ckpt_path: Optional[str] = None,
|
||||
mode: Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True):
|
||||
super().__init__()
|
||||
self.vae: VAE = get_my_vae(mode).eval()
|
||||
vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
|
||||
self.vae.load_state_dict(vae_state_dict)
|
||||
self.vae.remove_weight_norm()
|
||||
|
||||
if mode == '16k':
|
||||
assert vocoder_ckpt_path is not None
|
||||
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
|
||||
elif mode == '44k':
|
||||
self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
|
||||
use_cuda_kernel=False)
|
||||
self.vocoder.remove_weight_norm()
|
||||
else:
|
||||
raise ValueError(f'Unknown mode: {mode}')
|
||||
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if not need_vae_encoder:
|
||||
del self.vae.encoder
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
||||
return self.vae.encode(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.vae.decode(z)
|
||||
|
||||
@torch.inference_mode()
|
||||
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
return self.vocoder(spec)
|
||||
Reference in New Issue
Block a user