feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)
Fetch and adapt inference-critical model modules from upstream PrismAudio repo: - dit.py: DiffusionTransformer with debug prints removed - diffusion.py: ConditionedDiffusionModelWrapper, DiTWrapper, MMDiTWrapper - conditioners.py: Cond_MLP, Sync_MLP, MultiConditioner with stubbed training imports - autoencoders.py: AudioAutoencoder, OobleckEncoder/Decoder - transformer.py: ContinuousTransformer, Attention with flash_attn fallback to SDPA - blocks.py, utils.py, bottleneck.py, pretransforms.py, local_attention.py, pqmf.py - adp.py: UNetCFG1d, UNet1d, NumberEmbedder - mmmodules/model/low_level.py: MLP, ChannelLastConv1d, ConvMLP All internal imports rewritten from PrismAudio.* to prismaudio_core.*, training-only imports stubbed, flash_attn made optional with HAS_FLASH_ATTN flag. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,355 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, is_discrete: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.is_discrete = is_discrete
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
class DiscreteBottleneck(Bottleneck):
|
||||
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
||||
super().__init__(is_discrete=True)
|
||||
|
||||
self.num_quantizers = num_quantizers
|
||||
self.codebook_size = codebook_size
|
||||
self.tokens_id = tokens_id
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class TanhBottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x = torch.tanh(x)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def vae_sample(mean, scale):
|
||||
stdev = nn.functional.softplus(scale) + 1e-4
|
||||
var = stdev * stdev
|
||||
logvar = torch.log(var)
|
||||
latents = torch.randn_like(mean) * stdev + mean
|
||||
|
||||
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||
|
||||
return latents, kl
|
||||
|
||||
class VAEBottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def compute_mean_kernel(x, y):
|
||||
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
||||
return torch.exp(-kernel_input).mean()
|
||||
|
||||
def compute_mmd(latents):
|
||||
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
||||
noise = torch.randn_like(latents_reshaped)
|
||||
|
||||
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
|
||||
noise_kernel = compute_mean_kernel(noise, noise)
|
||||
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
|
||||
|
||||
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
|
||||
return mmd.mean()
|
||||
|
||||
class WassersteinBottleneck(Bottleneck):
|
||||
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
self.bypass_mmd = bypass_mmd
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
if self.training and return_info:
|
||||
if self.bypass_mmd:
|
||||
mmd = torch.tensor(0.0)
|
||||
else:
|
||||
mmd = compute_mmd(x)
|
||||
|
||||
info["mmd"] = mmd
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class L2Bottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x = F.normalize(x, dim=1)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return F.normalize(x, dim=1)
|
||||
|
||||
class RVQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices, loss = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
info["quantizer_loss"] = loss.mean()
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class RVQVAEBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x, kl = vae_sample(*x.chunk(2, dim=1))
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices, loss = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
info["quantizer_loss"] = loss.mean()
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
info["pre_quantizer"] = x
|
||||
|
||||
if self.quantize_on_decode:
|
||||
return x, info if return_info else x
|
||||
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
|
||||
|
||||
output = {
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
|
||||
output["vq/commitment_loss"] /= self.num_quantizers
|
||||
output["vq/codebook_loss"] /= self.num_quantizers
|
||||
|
||||
info.update(output)
|
||||
|
||||
if return_info:
|
||||
return output["z"], info
|
||||
|
||||
return output["z"]
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.quantize_on_decode:
|
||||
x = self.quantizer(x)[0]
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents, _, _ = self.quantizer.from_codes(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQVAEBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
|
||||
def encode(self, x, return_info=False, n_quantizers: int = None):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["pre_quantizer"] = x
|
||||
info["kl"] = kl
|
||||
|
||||
if self.quantize_on_decode:
|
||||
return x, info if return_info else x
|
||||
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
|
||||
|
||||
output = {
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
|
||||
output["vq/commitment_loss"] /= self.num_quantizers
|
||||
output["vq/codebook_loss"] /= self.num_quantizers
|
||||
|
||||
info.update(output)
|
||||
|
||||
if return_info:
|
||||
return output["z"], info
|
||||
|
||||
return output["z"]
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.quantize_on_decode:
|
||||
x = self.quantizer(x)[0]
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents, _, _ = self.quantizer.from_codes(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class FSQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, noise_augment_dim=0, **kwargs):
|
||||
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
|
||||
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
|
||||
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
x = x.to(orig_dtype)
|
||||
|
||||
# Reorder indices to match the expected format
|
||||
indices = rearrange(indices, "b n q -> b q n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
|
||||
latents = self.quantizer.indices_to_codes(tokens)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
Reference in New Issue
Block a user