diff --git a/prismaudio_core/models/autoencoders.py b/prismaudio_core/models/autoencoders.py index b279823..52b9fbc 100644 --- a/prismaudio_core/models/autoencoders.py +++ b/prismaudio_core/models/autoencoders.py @@ -304,8 +304,7 @@ class AudioAutoencoder(nn.Module): def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): info = {} - # import ipdb - # ipdb.set_trace() + if self.pretransform is not None and not skip_pretransform: if self.pretransform.enable_grad: if iterate_batch: @@ -476,11 +475,8 @@ class AudioAutoencoder(nn.Module): else: # CHUNKED ENCODING # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) - # import ipdb - # ipdb.set_trace() samples_per_latent = self.downsampling_ratio total_size = audio.shape[2] # in samples - print(f'audio shape: {audio.shape}') batch_size = audio.shape[0] chunk_size *= samples_per_latent # converting metric in latents to samples overlap *= samples_per_latent # converting metric in latents to samples @@ -501,12 +497,10 @@ class AudioAutoencoder(nn.Module): y_size = total_size // samples_per_latent # Create an empty latent, we will populate it with chunks as we encode them y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) - print(f'y_final shape: {y_final.shape}') for i in range(num_chunks): x_chunk = chunks[i,:] # encode the chunk y_chunk = self.encode(x_chunk) - print(f'y_chunk shape: {y_chunk.shape}') # figure out where to put the audio along the time domain if i == num_chunks-1: # final chunk always goes at the end diff --git a/prismaudio_core/models/conditioners.py b/prismaudio_core/models/conditioners.py index 8a5e7ee..3351f47 100644 --- a/prismaudio_core/models/conditioners.py +++ b/prismaudio_core/models/conditioners.py @@ -213,8 +213,6 @@ class Video_Global(Conditioner): self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim)) def forward(self, x, device: tp.Any = "cuda"): - # import ipdb - # ipdb.set_trace() if not isinstance(x[0], torch.Tensor): video_feats = [] for path in x: @@ -242,8 +240,6 @@ class Video_Sync(Conditioner): self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) def forward(self, x, device: tp.Any = "cuda"): - # import ipdb - # ipdb.set_trace() if not isinstance(x[0], torch.Tensor): video_feats = [] for path in x: @@ -269,8 +265,6 @@ class Text_Linear(Conditioner): self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) def forward(self, x, device: tp.Any = "cuda"): - # import ipdb - # ipdb.set_trace() if not isinstance(x[0], torch.Tensor): video_feats = [] for path in x: @@ -295,8 +289,6 @@ class mm_unchang(Conditioner): super().__init__(dim, output_dim) def forward(self, x, device: tp.Any = "cuda"): - # import ipdb - # ipdb.set_trace() if not isinstance(x[0], torch.Tensor): video_feats = [] for path in x: @@ -349,8 +341,6 @@ class CLIPConditioner(Conditioner): self.model.to(device) self.proj_out.to(device) - # import ipdb - # ipdb.set_trace() self.model.eval() if not isinstance(images[0], torch.Tensor): diff --git a/prismaudio_core/models/diffusion.py b/prismaudio_core/models/diffusion.py index 8233c09..b66d115 100644 --- a/prismaudio_core/models/diffusion.py +++ b/prismaudio_core/models/diffusion.py @@ -21,25 +21,6 @@ def _get_create_multi_conditioner_from_conditioning_config(): from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config return create_multi_conditioner_from_conditioning_config -from time import time - -class Profiler: - - def __init__(self): - self.ticks = [[time(), None]] - - def tick(self, msg): - self.ticks.append([time(), msg]) - - def __repr__(self): - rep = 80 * "=" + "\n" - for i in range(1, len(self.ticks)): - msg = self.ticks[i][1] - ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] - rep += msg + f": {ellapsed*1000:.2f}ms\n" - rep += 80 * "=" + "\n\n\n" - return rep - class DiffusionModel(nn.Module): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -176,8 +157,7 @@ class ConditionedDiffusionModelWrapper(nn.Module): cross_attention_input.append(cross_attn_in) cross_attention_masks.append(cross_attn_mask) - # import ipdb - # ipdb.set_trace() + cross_attention_input = torch.cat(cross_attention_input, dim=1) cross_attention_masks = torch.cat(cross_attention_masks, dim=1) @@ -314,10 +294,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel): prepend_cond=None, prepend_cond_mask=None, **kwargs): - p = Profiler() - - p.tick("start") - channels_list = None if input_concat_cond is not None: channels_list = [input_concat_cond] @@ -337,9 +313,6 @@ class UNetCFG1DWrapper(ConditionedDiffusionModel): negative_embedding_mask=negative_cross_attn_mask, **kwargs) - p.tick("UNetCFG1D forward") - - #print(f"Profiler: {p}") return outputs class UNet1DCondWrapper(ConditionedDiffusionModel): @@ -613,54 +586,6 @@ class DiTWrapper(ConditionedDiffusionModel): global_embed=global_cond, **kwargs) -class MMDiTWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): - super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) - - self.model = MMAudio(*args, **kwargs) - - # with torch.no_grad(): - # for param in self.model.parameters(): - # param *= 0.5 - - def forward(self, - x, - t, - clip_f, - sync_f, - text_f, - inpaint_masked_input=None, - t5_features=None, - metaclip_global_text_features=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = True, - rescale_cfg: bool = False, - scale_phi: float = 0.0, - **kwargs): - - # breakpoint() - assert batch_cfg, "batch_cfg must be True for DiTWrapper" - #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" - - return self.model( - latent=x, - t=t, - clip_f=clip_f, - sync_f=sync_f, - text_f=text_f, - inpaint_masked_input=inpaint_masked_input, - t5_features=t5_features, - metaclip_global_text_features=metaclip_global_text_features, - cfg_scale=cfg_scale, - cfg_dropout_prob=cfg_dropout_prob, - scale_phi=scale_phi, - **kwargs) - class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): """ A diffusion model that takes in conditioning @@ -739,8 +664,6 @@ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): } def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): - # breakpoint() - # print(kwargs) return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs) def generate(self, *args, **kwargs): @@ -888,8 +811,8 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): diffusion_model = UNet1DCondWrapper(**diffusion_model_config) elif diffusion_model_type == 'dit': diffusion_model = DiTWrapper(**diffusion_model_config) - elif diffusion_model_type == 'mmdit': - diffusion_model = MMDiTWrapper(**diffusion_model_config) + else: + raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}') io_channels = model_config.get('io_channels', None) assert io_channels is not None, "Must specify io_channels in model config" @@ -939,14 +862,9 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): wrapper_fn = ConditionedDiffusionModelWrapper extra_kwargs["diffusion_objective"] = diffusion_objective - elif model_type == "diffusion_prior": - prior_type = model_config.get("prior_type", None) - assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + else: + raise NotImplementedError(f'Unknown model type: {model_type}') - if prior_type == "mono_stereo": - from .diffusion_prior import MonoToStereoDiffusionPrior - wrapper_fn = MonoToStereoDiffusionPrior - return wrapper_fn( diffusion_model, conditioner, diff --git a/prismaudio_core/models/dit.py b/prismaudio_core/models/dit.py index 5e614d7..ec282a9 100644 --- a/prismaudio_core/models/dit.py +++ b/prismaudio_core/models/dit.py @@ -223,8 +223,6 @@ class DiffusionTransformer(nn.Module): # Get the batch of timestep embeddings timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) - # import ipdb - # ipdb.set_trace() # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists if self.timestep_cond_type == "global": if global_embed is not None: diff --git a/prismaudio_core/models/pretransforms.py b/prismaudio_core/models/pretransforms.py index c9942db..89edf3f 100644 --- a/prismaudio_core/models/pretransforms.py +++ b/prismaudio_core/models/pretransforms.py @@ -89,25 +89,6 @@ class AutoencoderPretransform(Pretransform): def load_state_dict(self, state_dict, strict=True): self.model.load_state_dict(state_dict, strict=strict) -class WaveletPretransform(Pretransform): - def __init__(self, channels, levels, wavelet): - super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) - - from .wavelets import WaveletEncode1d, WaveletDecode1d - - self.encoder = WaveletEncode1d(channels, levels, wavelet) - self.decoder = WaveletDecode1d(channels, levels, wavelet) - - self.downsampling_ratio = 2 ** levels - self.io_channels = channels - self.encoded_channels = channels * self.downsampling_ratio - - def encode(self, x): - return self.encoder(x) - - def decode(self, z): - return self.decoder(z) - class PQMFPretransform(Pretransform): def __init__(self, attenuation=100, num_bands=16): # TODO: Fix PQMF to take in in-channels diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py index 495a34f..4bd371a 100644 --- a/prismaudio_core/models/transformer.py +++ b/prismaudio_core/models/transformer.py @@ -421,9 +421,10 @@ class Attention(nn.Module): flex_attention_score_mod = None if flex_attention_block_mask is not None or flex_attention_score_mod is not None: - out = flex_attention_compiled(q,k,v, - block_mask = flex_attention_block_mask, - score_mod = flex_attention_score_mod) + raise NotImplementedError( + "FlexAttention is not available in this build. " + "flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments." + ) elif flash_attn_available: fa_dtype_in = q.dtype q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) diff --git a/prismaudio_core/models/utils.py b/prismaudio_core/models/utils.py index a90f1e2..4a29e62 100644 --- a/prismaudio_core/models/utils.py +++ b/prismaudio_core/models/utils.py @@ -18,7 +18,6 @@ def load_ckpt_state_dict(ckpt_path, prefix=None): def remove_weight_norm_from_model(model): for module in model.modules(): if hasattr(module, "weight"): - print(f"Removing weight norm from {module}") remove_weight_norm(module) return model