feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)
Fetch and adapt inference-critical model modules from upstream PrismAudio repo: - dit.py: DiffusionTransformer with debug prints removed - diffusion.py: ConditionedDiffusionModelWrapper, DiTWrapper, MMDiTWrapper - conditioners.py: Cond_MLP, Sync_MLP, MultiConditioner with stubbed training imports - autoencoders.py: AudioAutoencoder, OobleckEncoder/Decoder - transformer.py: ContinuousTransformer, Attention with flash_attn fallback to SDPA - blocks.py, utils.py, bottleneck.py, pretransforms.py, local_attention.py, pqmf.py - adp.py: UNetCFG1d, UNet1d, NumberEmbedder - mmmodules/model/low_level.py: MLP, ChannelLastConv1d, ConvMLP All internal imports rewritten from PrismAudio.* to prismaudio_core.*, training-only imports stubbed, flash_attn made optional with HAS_FLASH_ATTN flag. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,830 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchaudio import transforms as T
|
||||
from alias_free_torch import Activation1d
|
||||
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||
from typing import Literal, Dict, Any
|
||||
|
||||
from .blocks import SnakeBeta
|
||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||
from .pretransforms import Pretransform
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||
"""Minimal stub for inference.utils.prepare_audio used by autoencoders."""
|
||||
import torchaudio.transforms as T
|
||||
import torch
|
||||
|
||||
if in_sr != target_sr:
|
||||
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
||||
audio = resample_tf(audio)
|
||||
|
||||
if audio.shape[0] > target_channels:
|
||||
audio = audio[:target_channels]
|
||||
elif audio.shape[0] < target_channels:
|
||||
audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]
|
||||
|
||||
if audio.shape[-1] < target_length:
|
||||
audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
|
||||
elif audio.shape[-1] > target_length:
|
||||
audio = audio[..., :target_length]
|
||||
|
||||
return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
|
||||
from prismaudio_core.factory import create_pretransform_from_config
|
||||
return create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
|
||||
def _lazy_create_bottleneck_from_config(bottleneck):
|
||||
from prismaudio_core.factory import create_bottleneck_from_config
|
||||
return create_bottleneck_from_config(bottleneck)
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
|
||||
kwargs.setdefault("use_reentrant", False)
|
||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
act = nn.ELU()
|
||||
elif activation == "snake":
|
||||
act = SnakeBeta(channels)
|
||||
elif activation == "none":
|
||||
act = nn.Identity()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
|
||||
return act
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.dilation = dilation
|
||||
|
||||
padding = (dilation * (7-1)) // 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=7, dilation=dilation, padding=padding),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||
kernel_size=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
|
||||
#x = checkpoint(self.layers, x)
|
||||
x = self.layers(x)
|
||||
|
||||
return x + res
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||
super().__init__()
|
||||
|
||||
if use_nearest_upsample:
|
||||
upsample_layer = nn.Sequential(
|
||||
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||
WNConv1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride,
|
||||
stride=1,
|
||||
bias=False,
|
||||
padding='same')
|
||||
)
|
||||
else:
|
||||
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
upsample_layer,
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=9, use_snake=use_snake),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||
]
|
||||
|
||||
for i in range(self.depth-1):
|
||||
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
|
||||
def __init__(self,
|
||||
out_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=True):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||
]
|
||||
|
||||
for i in range(self.depth-1, 0, -1):
|
||||
layers += [DecoderBlock(
|
||||
in_channels=c_mults[i]*channels,
|
||||
out_channels=c_mults[i-1]*channels,
|
||||
stride=strides[i-1],
|
||||
use_snake=use_snake,
|
||||
antialias_activation=antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample
|
||||
)
|
||||
]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||
nn.Tanh() if final_tanh else nn.Identity()
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class DACEncoderWrapper(nn.Module):
|
||||
def __init__(self, in_channels=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
from dac.model.dac import Encoder as DACEncoder
|
||||
|
||||
latent_dim = kwargs.pop("latent_dim", None)
|
||||
|
||||
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
|
||||
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
|
||||
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
|
||||
|
||||
if in_channels != 1:
|
||||
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
|
||||
class DACDecoderWrapper(nn.Module):
|
||||
def __init__(self, latent_dim, out_channels=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
from dac.model.dac import Decoder as DACDecoder
|
||||
|
||||
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
def forward(self, x):
|
||||
return self.decoder(x)
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder,
|
||||
latent_dim,
|
||||
downsampling_ratio,
|
||||
sample_rate,
|
||||
io_channels=2,
|
||||
bottleneck: Bottleneck = None,
|
||||
pretransform: Pretransform = None,
|
||||
in_channels = None,
|
||||
out_channels = None,
|
||||
soft_clip = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.downsampling_ratio = downsampling_ratio
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.io_channels = io_channels
|
||||
self.in_channels = io_channels
|
||||
self.out_channels = io_channels
|
||||
|
||||
self.min_length = self.downsampling_ratio
|
||||
|
||||
if in_channels is not None:
|
||||
self.in_channels = in_channels
|
||||
|
||||
if out_channels is not None:
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bottleneck = bottleneck
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
self.pretransform = pretransform
|
||||
|
||||
self.soft_clip = soft_clip
|
||||
|
||||
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
||||
|
||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||
|
||||
info = {}
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
if self.pretransform is not None and not skip_pretransform:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
|
||||
if self.encoder is not None:
|
||||
if iterate_batch:
|
||||
latents = []
|
||||
for i in range(audio.shape[0]):
|
||||
latents.append(self.encoder(audio[i:i+1]))
|
||||
latents = torch.cat(latents, dim=0)
|
||||
else:
|
||||
latents = self.encoder(audio)
|
||||
else:
|
||||
latents = audio
|
||||
|
||||
if self.bottleneck is not None:
|
||||
# TODO: Add iterate batch logic, needs to merge the info dicts
|
||||
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
||||
|
||||
info.update(bottleneck_info)
|
||||
|
||||
if return_info:
|
||||
return latents, info
|
||||
|
||||
return latents
|
||||
|
||||
def decode(self, latents, iterate_batch=False, **kwargs):
|
||||
|
||||
if self.bottleneck is not None:
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
||||
latents = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.decoder(latents[i:i+1]))
|
||||
decoded = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
decoded = self.decoder(latents, **kwargs)
|
||||
|
||||
if self.pretransform is not None:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(decoded.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(latents.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
|
||||
if self.soft_clip:
|
||||
decoded = torch.tanh(decoded)
|
||||
|
||||
return decoded
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
|
||||
'''
|
||||
Decode discrete tokens to audio
|
||||
Only works with discrete autoencoders
|
||||
'''
|
||||
|
||||
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
|
||||
|
||||
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
|
||||
def preprocess_audio_for_encoder(self, audio, in_sr):
|
||||
'''
|
||||
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
|
||||
If the model is mono, stereo audio will be converted to mono.
|
||||
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
|
||||
Audio will be resampled to the model's sample rate.
|
||||
The output will have batch size 1 and be shape (1 x Channels x Length)
|
||||
'''
|
||||
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
|
||||
|
||||
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
|
||||
'''
|
||||
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
|
||||
The audio in that list can be of different lengths and channels.
|
||||
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
|
||||
All audio will be resampled to the model's sample rate.
|
||||
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
|
||||
If the model is mono, all audio will be converted to mono.
|
||||
The output will be a tensor of shape (Batch x Channels x Length)
|
||||
'''
|
||||
batch_size = len(audio_list)
|
||||
if isinstance(in_sr_list, int):
|
||||
in_sr_list = [in_sr_list]*batch_size
|
||||
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
|
||||
new_audio = []
|
||||
max_length = 0
|
||||
# resample & find the max length
|
||||
for i in range(batch_size):
|
||||
audio = audio_list[i]
|
||||
in_sr = in_sr_list[i]
|
||||
if len(audio.shape) == 3 and audio.shape[0] == 1:
|
||||
# batchsize 1 was given by accident. Just squeeze it.
|
||||
audio = audio.squeeze(0)
|
||||
elif len(audio.shape) == 1:
|
||||
# Mono signal, channel dimension is missing, unsqueeze it in
|
||||
audio = audio.unsqueeze(0)
|
||||
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
|
||||
# Resample audio
|
||||
if in_sr != self.sample_rate:
|
||||
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
|
||||
audio = resample_tf(audio)
|
||||
new_audio.append(audio)
|
||||
if audio.shape[-1] > max_length:
|
||||
max_length = audio.shape[-1]
|
||||
# Pad every audio to the same length, multiple of model's downsampling ratio
|
||||
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
|
||||
for i in range(batch_size):
|
||||
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
|
||||
new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
|
||||
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
|
||||
# convert to tensor
|
||||
return torch.stack(new_audio)
|
||||
|
||||
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||
'''
|
||||
Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
|
||||
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
|
||||
Overlap and chunk_size params are both measured in number of latents (not audio samples)
|
||||
# and therefore you likely could use the same values with decode_audio.
|
||||
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
|
||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||
Smaller chunk_size uses less memory, but more compute.
|
||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||
'''
|
||||
if not chunked:
|
||||
# default behavior. Encode the entire audio in parallel
|
||||
return self.encode(audio, **kwargs)
|
||||
else:
|
||||
# CHUNKED ENCODING
|
||||
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
samples_per_latent = self.downsampling_ratio
|
||||
total_size = audio.shape[2] # in samples
|
||||
print(f'audio shape: {audio.shape}')
|
||||
batch_size = audio.shape[0]
|
||||
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||
overlap *= samples_per_latent # converting metric in latents to samples
|
||||
hop_size = chunk_size - overlap
|
||||
chunks = []
|
||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||
chunk = audio[:,:,i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
if i+chunk_size != total_size:
|
||||
# Final chunk
|
||||
chunk = audio[:,:,-chunk_size:]
|
||||
chunks.append(chunk)
|
||||
chunks = torch.stack(chunks)
|
||||
num_chunks = chunks.shape[0]
|
||||
# Note: y_size might be a different value from the latent length used in diffusion training
|
||||
# because we can encode audio of varying lengths
|
||||
# However, the audio should've been padded to a multiple of samples_per_latent by now.
|
||||
y_size = total_size // samples_per_latent
|
||||
# Create an empty latent, we will populate it with chunks as we encode them
|
||||
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||
print(f'y_final shape: {y_final.shape}')
|
||||
for i in range(num_chunks):
|
||||
x_chunk = chunks[i,:]
|
||||
# encode the chunk
|
||||
y_chunk = self.encode(x_chunk)
|
||||
print(f'y_chunk shape: {y_chunk.shape}')
|
||||
# figure out where to put the audio along the time domain
|
||||
if i == num_chunks-1:
|
||||
# final chunk always goes at the end
|
||||
t_end = y_size
|
||||
t_start = t_end - y_chunk.shape[2]
|
||||
else:
|
||||
t_start = i * hop_size // samples_per_latent
|
||||
t_end = t_start + chunk_size // samples_per_latent
|
||||
# remove the edges of the overlaps
|
||||
ol = overlap//samples_per_latent//2
|
||||
chunk_start = 0
|
||||
chunk_end = y_chunk.shape[2]
|
||||
if i > 0:
|
||||
# no overlap for the start of the first chunk
|
||||
t_start += ol
|
||||
chunk_start += ol
|
||||
if i < num_chunks-1:
|
||||
# no overlap for the end of the last chunk
|
||||
t_end -= ol
|
||||
chunk_end -= ol
|
||||
# paste the chunked audio into our y_final output audio
|
||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||
return y_final
|
||||
|
||||
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||
'''
|
||||
Decode latents to audio.
|
||||
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
|
||||
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
|
||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||
Smaller chunk_size uses less memory, but more compute.
|
||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||
'''
|
||||
if not chunked:
|
||||
# default behavior. Decode the entire latent in parallel
|
||||
return self.decode(latents, **kwargs)
|
||||
else:
|
||||
# chunked decoding
|
||||
hop_size = chunk_size - overlap
|
||||
total_size = latents.shape[2]
|
||||
batch_size = latents.shape[0]
|
||||
chunks = []
|
||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||
chunk = latents[:,:,i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
if i+chunk_size != total_size:
|
||||
# Final chunk
|
||||
chunk = latents[:,:,-chunk_size:]
|
||||
chunks.append(chunk)
|
||||
chunks = torch.stack(chunks)
|
||||
num_chunks = chunks.shape[0]
|
||||
# samples_per_latent is just the downsampling ratio
|
||||
samples_per_latent = self.downsampling_ratio
|
||||
# Create an empty waveform, we will populate it with chunks as decode them
|
||||
y_size = total_size * samples_per_latent
|
||||
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
|
||||
for i in range(num_chunks):
|
||||
x_chunk = chunks[i,:]
|
||||
# decode the chunk
|
||||
y_chunk = self.decode(x_chunk)
|
||||
# figure out where to put the audio along the time domain
|
||||
if i == num_chunks-1:
|
||||
# final chunk always goes at the end
|
||||
t_end = y_size
|
||||
t_start = t_end - y_chunk.shape[2]
|
||||
else:
|
||||
t_start = i * hop_size * samples_per_latent
|
||||
t_end = t_start + chunk_size * samples_per_latent
|
||||
# remove the edges of the overlaps
|
||||
ol = (overlap//2) * samples_per_latent
|
||||
chunk_start = 0
|
||||
chunk_end = y_chunk.shape[2]
|
||||
if i > 0:
|
||||
# no overlap for the start of the first chunk
|
||||
t_start += ol
|
||||
chunk_start += ol
|
||||
if i < num_chunks-1:
|
||||
# no overlap for the end of the last chunk
|
||||
t_end -= ol
|
||||
chunk_end -= ol
|
||||
# paste the chunked audio into our y_final output audio
|
||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||
return y_final
|
||||
|
||||
|
||||
class DiffusionAutoencoder(AudioAutoencoder):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion: ConditionedDiffusionModel,
|
||||
diffusion_downsampling_ratio,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.diffusion = diffusion
|
||||
|
||||
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
|
||||
|
||||
if self.encoder is not None:
|
||||
# Shrink the initial encoder parameters to avoid saturated latents
|
||||
with torch.no_grad():
|
||||
for param in self.encoder.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def decode(self, latents, steps=100):
|
||||
|
||||
upsampled_length = latents.shape[2] * self.downsampling_ratio
|
||||
|
||||
if self.bottleneck is not None:
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if self.decoder is not None:
|
||||
latents = self.decode(latents)
|
||||
|
||||
# Upsample latents to match diffusion length
|
||||
if latents.shape[2] != upsampled_length:
|
||||
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
|
||||
|
||||
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
|
||||
from prismaudio_core.inference.sampling import sample
|
||||
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
|
||||
|
||||
if self.pretransform is not None:
|
||||
if self.pretransform.enable_grad:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
|
||||
return decoded
|
||||
|
||||
# AE factories
|
||||
|
||||
def create_encoder_from_config(encoder_config: Dict[str, Any]):
|
||||
encoder_type = encoder_config.get("type", None)
|
||||
assert encoder_type is not None, "Encoder type must be specified"
|
||||
|
||||
if encoder_type == "oobleck":
|
||||
encoder = OobleckEncoder(
|
||||
**encoder_config["config"]
|
||||
)
|
||||
|
||||
elif encoder_type == "seanet":
|
||||
from encodec.modules import SEANetEncoder
|
||||
seanet_encoder_config = encoder_config["config"]
|
||||
|
||||
#SEANet encoder expects strides in reverse order
|
||||
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
|
||||
encoder = SEANetEncoder(
|
||||
**seanet_encoder_config
|
||||
)
|
||||
elif encoder_type == "dac":
|
||||
dac_config = encoder_config["config"]
|
||||
|
||||
encoder = DACEncoderWrapper(**dac_config)
|
||||
elif encoder_type == "local_attn":
|
||||
from .local_attention import TransformerEncoder1D
|
||||
|
||||
local_attn_config = encoder_config["config"]
|
||||
|
||||
encoder = TransformerEncoder1D(
|
||||
**local_attn_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder type {encoder_type}")
|
||||
|
||||
requires_grad = encoder_config.get("requires_grad", True)
|
||||
if not requires_grad:
|
||||
for param in encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return encoder
|
||||
|
||||
def create_decoder_from_config(decoder_config: Dict[str, Any]):
|
||||
decoder_type = decoder_config.get("type", None)
|
||||
assert decoder_type is not None, "Decoder type must be specified"
|
||||
|
||||
if decoder_type == "oobleck":
|
||||
decoder = OobleckDecoder(
|
||||
**decoder_config["config"]
|
||||
)
|
||||
elif decoder_type == "seanet":
|
||||
from encodec.modules import SEANetDecoder
|
||||
|
||||
decoder = SEANetDecoder(
|
||||
**decoder_config["config"]
|
||||
)
|
||||
elif decoder_type == "dac":
|
||||
dac_config = decoder_config["config"]
|
||||
|
||||
decoder = DACDecoderWrapper(**dac_config)
|
||||
elif decoder_type == "local_attn":
|
||||
from .local_attention import TransformerDecoder1D
|
||||
|
||||
local_attn_config = decoder_config["config"]
|
||||
|
||||
decoder = TransformerDecoder1D(
|
||||
**local_attn_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown decoder type {decoder_type}")
|
||||
|
||||
requires_grad = decoder_config.get("requires_grad", True)
|
||||
if not requires_grad:
|
||||
for param in decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return decoder
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||
|
||||
ae_config = config["model"]
|
||||
|
||||
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||
|
||||
bottleneck = ae_config.get("bottleneck", None)
|
||||
|
||||
latent_dim = ae_config.get("latent_dim", None)
|
||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||
io_channels = ae_config.get("io_channels", None)
|
||||
assert io_channels is not None, "io_channels must be specified in model config"
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||
|
||||
in_channels = ae_config.get("in_channels", None)
|
||||
out_channels = ae_config.get("out_channels", None)
|
||||
|
||||
pretransform = ae_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
if bottleneck is not None:
|
||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||
|
||||
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||
|
||||
return AudioAutoencoder(
|
||||
encoder,
|
||||
decoder,
|
||||
io_channels=io_channels,
|
||||
latent_dim=latent_dim,
|
||||
downsampling_ratio=downsampling_ratio,
|
||||
sample_rate=sample_rate,
|
||||
bottleneck=bottleneck,
|
||||
pretransform=pretransform,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
soft_clip=soft_clip
|
||||
)
|
||||
|
||||
def create_diffAE_from_config(config: Dict[str, Any]):
|
||||
|
||||
diffae_config = config["model"]
|
||||
|
||||
if "encoder" in diffae_config:
|
||||
encoder = create_encoder_from_config(diffae_config["encoder"])
|
||||
else:
|
||||
encoder = None
|
||||
|
||||
if "decoder" in diffae_config:
|
||||
decoder = create_decoder_from_config(diffae_config["decoder"])
|
||||
else:
|
||||
decoder = None
|
||||
|
||||
diffusion_model_type = diffae_config["diffusion"]["type"]
|
||||
|
||||
if diffusion_model_type == "DAU1d":
|
||||
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||
elif diffusion_model_type == "adp_1d":
|
||||
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||
elif diffusion_model_type == "dit":
|
||||
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
|
||||
|
||||
latent_dim = diffae_config.get("latent_dim", None)
|
||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
|
||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||
io_channels = diffae_config.get("io_channels", None)
|
||||
assert io_channels is not None, "io_channels must be specified in model config"
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||
|
||||
bottleneck = diffae_config.get("bottleneck", None)
|
||||
|
||||
pretransform = diffae_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
if bottleneck is not None:
|
||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||
|
||||
diffusion_downsampling_ratio = None,
|
||||
|
||||
if diffusion_model_type == "DAU1d":
|
||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||
elif diffusion_model_type == "adp_1d":
|
||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
|
||||
elif diffusion_model_type == "dit":
|
||||
diffusion_downsampling_ratio = 1
|
||||
|
||||
return DiffusionAutoencoder(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
diffusion=diffusion,
|
||||
io_channels=io_channels,
|
||||
sample_rate=sample_rate,
|
||||
latent_dim=latent_dim,
|
||||
downsampling_ratio=downsampling_ratio,
|
||||
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
|
||||
bottleneck=bottleneck,
|
||||
pretransform=pretransform
|
||||
)
|
||||
Reference in New Issue
Block a user