import math
from contextlib import nullcontext
from typing import Literal, Dict, Any

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchaudio import transforms as T
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d

from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .pretransforms import Pretransform


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Minimal stub for ``inference.utils.prepare_audio`` used by the autoencoders.

    Resamples ``audio`` from ``in_sr`` to ``target_sr`` (only when they differ), adapts
    the channel count to ``target_channels`` and zero-pads or truncates the last axis to
    ``target_length`` samples.

    Args:
        audio: Tensor indexed as (channels, length).
        in_sr: Sample rate of ``audio``.
        target_sr: Desired sample rate.
        target_length: Desired number of samples along the last axis.
        target_channels: Desired number of channels.
        device: Device the resampler module is moved to.

    Returns:
        Tensor of shape (1, target_channels, target_length).
    """
    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    num_channels = audio.shape[0]
    if num_channels > target_channels:
        if target_channels == 1:
            # Mix down to mono instead of silently keeping only the first channel.
            audio = audio.mean(dim=0, keepdim=True)
        else:
            audio = audio[:target_channels]
    elif num_channels < target_channels:
        # Tile the existing channels (e.g. mono -> stereo duplication), then trim.
        audio = audio.repeat(target_channels // num_channels + 1, 1)[:target_channels]

    length = audio.shape[-1]
    if length < target_length:
        audio = F.pad(audio, (0, target_length - length))
    elif length > target_length:
        audio = audio[..., :target_length]

    return audio.unsqueeze(0)


def _lazy_create_pretransform_from_config(pretransform, sample_rate):
    """Build a pretransform via the factory, imported lazily (presumably to avoid an import cycle)."""
    from prismaudio_core.factory import create_pretransform_from_config
    return create_pretransform_from_config(pretransform, sample_rate)


def _lazy_create_bottleneck_from_config(bottleneck):
    """Build a bottleneck via the factory, imported lazily (presumably to avoid an import cycle)."""
    from prismaudio_core.factory import create_bottleneck_from_config
    return create_bottleneck_from_config(bottleneck)


def checkpoint(function, *args, **kwargs):
    """Run ``function`` under activation checkpointing, non-reentrant unless overridden."""
    # Import the submodule explicitly; `import torch` alone does not guarantee
    # that `torch.utils.checkpoint` is loaded as an attribute.
    from torch.utils.checkpoint import checkpoint as _torch_checkpoint
    kwargs.setdefault("use_reentrant", False)
    return _torch_checkpoint(function, *args, **kwargs)


def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Return the named activation module, optionally wrapped in ``Activation1d``.

    Raises:
        ValueError: if ``activation`` is not one of "elu", "snake", "none".
    """
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    """act -> dilated conv(k=7) -> act -> conv(k=1), added back onto the input.

    NOTE(review): both activations are sized with ``out_channels`` and the skip adds the
    raw input, so this only works when ``in_channels == out_channels`` (true for every
    caller in this file).
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "Same" padding for a kernel of size 7 with the given dilation.
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x
        x = self.layers(x)
        return x + res


class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided downsampling conv.

    NOTE: ``antialias_activation`` only affects the final activation; the residual units
    are built without it. Left as-is because changing it alters the module tree.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    """Upsampling layer followed by three residual units (dilations 1, 3, 9).

    Upsampling is either a transposed conv, or nearest-neighbour upsampling plus a conv
    when ``use_nearest_upsample`` is set.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    """Convolutional waveform encoder: input conv, one EncoderBlock per stride, output conv to ``latent_dim``."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # Builds a new list; the (mutable) default argument is never modified.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels,
                                    stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    """Mirror of ``OobleckEncoder``: latent conv, DecoderBlocks in reverse stride order, output conv (+ optional tanh)."""

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Builds a new list; the (mutable) default argument is never modified.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class DACEncoderWrapper(nn.Module):
    """Wraps the DAC encoder with an optional 1x1 projection to ``latent_dim`` and configurable input channels."""

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        # NOTE(review): requires "d_model" and "strides" in kwargs, while the input-conv
        # replacement below falls back to d_model=64 — confirm the intended default.
        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)

        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
        self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()

        if in_channels != 1:
            # Replace the first conv so the encoder accepts multi-channel input.
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj_out(x)
        return x


class DACDecoderWrapper(nn.Module):
    """Wraps the DAC decoder so it reads ``latent_dim`` input channels and emits ``out_channels``."""

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)

        self.latent_dim = latent_dim

    def forward(self, x):
        return self.decoder(x)


class AudioAutoencoder(nn.Module):
    """Audio autoencoder: optional pretransform -> encoder -> optional bottleneck, and the reverse for decoding."""

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        # Inputs are padded to a multiple of this many samples before encoding.
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck

        self.encoder = encoder

        self.decoder = decoder

        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    @staticmethod
    def _map_batch(fn, x, iterate_batch):
        """Apply ``fn`` to the whole batch ``x``, or one item at a time (re-concatenated) when ``iterate_batch``."""
        if not iterate_batch:
            return fn(x)
        return torch.cat([fn(x[i:i + 1]) for i in range(x.shape[0])], dim=0)

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode audio to latents.

        Args:
            audio: Batched audio tensor (batch on dim 0).
            return_info: Also return the bottleneck info dict.
            skip_pretransform: Do not apply the pretransform even if one is set.
            iterate_batch: Run pretransform/encoder one batch item at a time (lower peak memory).
            **kwargs: Forwarded to ``bottleneck.encode``.

        Returns:
            ``latents``, or ``(latents, info)`` when ``return_info`` is true.
        """
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            # Frozen pretransforms run without autograd tracking.
            grad_ctx = nullcontext() if self.pretransform.enable_grad else torch.no_grad()
            with grad_ctx:
                audio = self._map_batch(self.pretransform.encode, audio, iterate_batch)

        if self.encoder is not None:
            latents = self._map_batch(self.encoder, audio, iterate_batch)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents back to audio.

        ``**kwargs`` are forwarded to the decoder only when ``iterate_batch`` is false
        (the per-item path calls the decoder without them).
        """
        if self.bottleneck is not None:
            latents = self._map_batch(self.bottleneck.decode, latents, iterate_batch)

        if iterate_batch:
            decoded = self._map_batch(self.decoder, latents, True)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            grad_ctx = nullcontext() if self.pretransform.enable_grad else torch.no_grad()
            with grad_ctx:
                # Iterates over `decoded`'s own batch (previously used latents.shape[0]).
                decoded = self._map_batch(self.pretransform.decode, decoded, iterate_batch)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio
        Only works with discrete autoencoders
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def preprocess_audio_for_encoder(self, audio, in_sr):
        '''
        Preprocess a single audio tensor (Channels x Length) to be compatible with the encoder.
        If the model is mono, stereo audio will be mixed down to mono.
        Audio will be silence-padded to a multiple of the model's downsampling ratio.
        Audio will be resampled to the model's sample rate.
        The output will have batch size 1 and be shape (1 x Channels x Length)
        '''
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        '''
        Preprocess a [list] of audio (Channels x Length) into a batch tensor compatible with the encoder.
        The audio in that list can be of different lengths and channels.
        in_sr can be an integer or list. If it's an integer it is assumed to be the input sample rate for every audio.
        All audio will be resampled to the model's sample rate.
        Audio will be silence-padded to the longest length, and further padded to a multiple of the model's downsampling ratio.
        If the model is mono, all audio will be mixed down to mono.
        The output will be a tensor of shape (Batch x Channels x Length)
        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list] * batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
        new_audio = []
        max_length = 0
        # resample & find the max length
        for i in range(batch_size):
            audio = audio_list[i]
            in_sr = in_sr_list[i]
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # batchsize 1 was given by accident. Just squeeze it.
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal, channel dimension is missing, unsqueeze it in
                audio = audio.unsqueeze(0)
            assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"
            # Resample audio
            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)
            new_audio.append(audio)
            if audio.shape[-1] > max_length:
                max_length = audio.shape[-1]
        # Pad every audio to the same length, a multiple of the model's downsampling ratio
        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
        for i in range(batch_size):
            # Pad and, if necessary, mix down / duplicate channels to match the model.
            # Everything is already at self.sample_rate, so no resampling happens here.
            new_audio[i] = prepare_audio(new_audio[i], in_sr=self.sample_rate, target_sr=self.sample_rate,
                                         target_length=padded_audio_length,
                                         target_channels=self.in_channels,
                                         device=new_audio[i].device).squeeze(0)
        # convert to tensor
        return torch.stack(new_audio)

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples),
        and therefore you can likely use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
        Audio shorter than one chunk is encoded in a single pass.

        Raises:
            ValueError: if chunked and overlap >= chunk_size (the hop size would be <= 0).
        '''
        if not chunked:
            # default behavior. Encode the entire audio in parallel
            return self.encode(audio, **kwargs)

        if overlap >= chunk_size:
            raise ValueError(f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})")

        # CHUNKED ENCODING
        # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
        samples_per_latent = self.downsampling_ratio
        total_size = audio.shape[2]  # in samples
        batch_size = audio.shape[0]
        chunk_size *= samples_per_latent  # converting metric in latents to samples
        overlap *= samples_per_latent  # converting metric in latents to samples
        hop_size = chunk_size - overlap

        if total_size < chunk_size:
            # Not even one full chunk: nothing to split (previously crashed with UnboundLocalError).
            return self.encode(audio)

        chunks = [audio[:, :, start:start + chunk_size]
                  for start in range(0, total_size - chunk_size + 1, hop_size)]
        if (len(chunks) - 1) * hop_size + chunk_size != total_size:
            # Final chunk, aligned to the end of the audio
            chunks.append(audio[:, :, -chunk_size:])
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]

        # Note: y_size might be a different value from the latent length used in diffusion training
        # because we can encode audio of varying lengths
        # However, the audio should've been padded to a multiple of samples_per_latent by now.
        y_size = total_size // samples_per_latent
        # Create an empty latent, we will populate it with chunks as we encode them
        y_final = torch.zeros((batch_size, self.latent_dim, y_size), device=audio.device)
        # Half the overlap (in latents) is trimmed from each interior chunk edge.
        ol = overlap // samples_per_latent // 2
        for i in range(num_chunks):
            y_chunk = self.encode(chunks[i, :])
            # figure out where to put the latents along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size // samples_per_latent
                t_end = t_start + chunk_size // samples_per_latent
            # remove the edges of the overlaps
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked latents into our y_final output
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
        Latents shorter than one chunk are decoded in a single pass.

        Raises:
            ValueError: if chunked and overlap >= chunk_size (the hop size would be <= 0).
        '''
        if not chunked:
            # default behavior. Decode the entire latent in parallel
            return self.decode(latents, **kwargs)

        if overlap >= chunk_size:
            raise ValueError(f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})")

        # chunked decoding
        hop_size = chunk_size - overlap
        total_size = latents.shape[2]
        batch_size = latents.shape[0]

        if total_size < chunk_size:
            # Not even one full chunk: nothing to split (previously crashed with UnboundLocalError).
            return self.decode(latents)

        chunks = [latents[:, :, start:start + chunk_size]
                  for start in range(0, total_size - chunk_size + 1, hop_size)]
        if (len(chunks) - 1) * hop_size + chunk_size != total_size:
            # Final chunk, aligned to the end of the latents
            chunks.append(latents[:, :, -chunk_size:])
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = self.downsampling_ratio
        # Create an empty waveform, we will populate it with chunks as we decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size, self.out_channels, y_size), device=latents.device)
        # Half the overlap (converted to samples) is trimmed from each interior chunk edge.
        ol = (overlap // 2) * samples_per_latent
        for i in range(num_chunks):
            y_chunk = self.decode(chunks[i, :])
            # figure out where to put the audio along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            # remove the edges of the overlaps
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final


class DiffusionAutoencoder(AudioAutoencoder):
    """Autoencoder whose decoder output conditions a diffusion model that generates the audio."""

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        # Inputs must also be divisible by the diffusion model's own downsampling.
        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):
        """Decode latents by sampling the diffusion model with ``steps`` steps, conditioned on the latents."""
        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            # Must call the decoder module; calling self.decode here recursed forever.
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
        from prismaudio_core.inference.sampling import sample
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            grad_ctx = nullcontext() if self.pretransform.enable_grad else torch.no_grad()
            with grad_ctx:
                decoded = self.pretransform.decode(decoded)

        return decoded


# AE factories

def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Build an encoder from ``{"type": ..., "config": {...}, "requires_grad": bool}``.

    Raises:
        ValueError: for an unknown encoder type.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        # Copy so the caller's config is not mutated; reversing in place meant a
        # second call on the same config flipped the stride order back.
        seanet_encoder_config = dict(encoder_config["config"])

        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder


def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Build a decoder from ``{"type": ..., "config": {...}, "requires_grad": bool}``.

    Raises:
        ValueError: for an unknown decoder type.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "dac":
        dac_config = decoder_config["config"]

        decoder = DACDecoderWrapper(**dac_config)
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        local_attn_config = decoder_config["config"]

        decoder = TransformerDecoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    requires_grad = decoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder


def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config (``config["model"]`` plus ``config["sample_rate"]``)."""
    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    pretransform = ae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip
    )


def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a ``DiffusionAutoencoder`` from a full model config.

    Raises:
        ValueError: for an unknown diffusion model type (previously a NameError further down).
    """
    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]
    diffusion_model_config = diffae_config["diffusion"]["config"]

    # Build the diffusion model and note how much it downsamples internally.
    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffusion_model_config)
        diffusion_downsampling_ratio = int(np.prod(diffusion_model_config["strides"]))
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffusion_model_config)
        diffusion_downsampling_ratio = int(np.prod(diffusion_model_config["factors"]))
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffusion_model_config)
        diffusion_downsampling_ratio = 1
    else:
        raise ValueError(f"Unknown diffusion type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )