From 30e85f0f99d391c4f42a7567e143e823cd4630cf Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 17:56:02 +0100 Subject: [PATCH] fix: resolve critical bugs and quality issues in prismaudio_core/models Co-Authored-By: Claude Opus 4.6 --- prismaudio_core/models/autoencoders.py | 9 +++------ prismaudio_core/models/blocks.py | 8 -------- prismaudio_core/models/diffusion.py | 3 ++- prismaudio_core/models/local_attention.py | 5 +---- prismaudio_core/models/transformer.py | 17 +++-------------- prismaudio_core/models/utils.py | 4 ++++ 6 files changed, 13 insertions(+), 33 deletions(-) diff --git a/prismaudio_core/models/autoencoders.py b/prismaudio_core/models/autoencoders.py index 52b9fbc..5a4b45d 100644 --- a/prismaudio_core/models/autoencoders.py +++ b/prismaudio_core/models/autoencoders.py @@ -13,6 +13,7 @@ from .blocks import SnakeBeta from .bottleneck import Bottleneck, DiscreteBottleneck from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper from .pretransforms import Pretransform +from .utils import checkpoint def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): @@ -46,10 +47,6 @@ def _lazy_create_bottleneck_from_config(bottleneck): from prismaudio_core.factory import create_bottleneck_from_config return create_bottleneck_from_config(bottleneck) -def checkpoint(function, *args, **kwargs): - kwargs.setdefault("use_reentrant", False) - return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) - def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: if activation == "elu": act = nn.ELU() @@ -617,7 +614,7 @@ class DiffusionAutoencoder(AudioAutoencoder): latents = self.bottleneck.decode(latents) if self.decoder is not None: - latents = self.decode(latents) + latents = self.decoder(latents) # Upsample latents to match diffusion length if latents.shape[2] != upsampled_length: @@ -801,7 +798,7 @@ def create_diffAE_from_config(config: Dict[str, Any]): if bottleneck is not None: bottleneck = _lazy_create_bottleneck_from_config(bottleneck) - diffusion_downsampling_ratio = None, + diffusion_downsampling_ratio = None if diffusion_model_type == "DAU1d": diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) diff --git a/prismaudio_core/models/blocks.py b/prismaudio_core/models/blocks.py index 3c827fd..dfc0466 100644 --- a/prismaudio_core/models/blocks.py +++ b/prismaudio_core/models/blocks.py @@ -200,14 +200,6 @@ def zero_init(layer): nn.init.zeros_(layer.bias) return layer -def rms_norm(x, scale, eps): - dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) - mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) - scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) - return x * scale.to(x.dtype) - -#rms_norm = torch.compile(rms_norm) - class AdaRMSNorm(nn.Module): def __init__(self, features, cond_features, eps=1e-6): super().__init__() diff --git a/prismaudio_core/models/diffusion.py b/prismaudio_core/models/diffusion.py index b66d115..8e3aee1 100644 --- a/prismaudio_core/models/diffusion.py +++ b/prismaudio_core/models/diffusion.py @@ -771,6 +771,7 @@ def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]): elif model_type == "adp_uncond_1d": + io_channels = model_config.get("io_channels", 64) model = UNet1DUncondWrapper( io_channels = io_channels, **model_config @@ -858,7 +859,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): extra_kwargs["diffusion_objective"] = diffusion_objective extra_kwargs["mm_cond_ids"] = mm_cond_ids - if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': + elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': wrapper_fn = ConditionedDiffusionModelWrapper extra_kwargs["diffusion_objective"] = diffusion_objective diff --git a/prismaudio_core/models/local_attention.py b/prismaudio_core/models/local_attention.py index 893ce11..5d6aa7d 100644 --- a/prismaudio_core/models/local_attention.py +++ b/prismaudio_core/models/local_attention.py @@ -5,10 +5,7 @@ from torch import nn from .blocks import AdaRMSNorm from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm - -def checkpoint(function, *args, **kwargs): - kwargs.setdefault("use_reentrant", False) - return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) +from .utils import checkpoint # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py class ContinuousLocalTransformer(nn.Module): diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py index 4bd371a..2ee75c4 100644 --- a/prismaudio_core/models/transformer.py +++ b/prismaudio_core/models/transformer.py @@ -18,16 +18,12 @@ except ImportError: flash_attn_kvpacked_func = None flash_attn_func = None -from .utils import compile +from .utils import compile, checkpoint try: import natten except ImportError: natten = None -def checkpoint(function, *args, **kwargs): - kwargs.setdefault("use_reentrant", False) - return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) - def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): return x * (1 + scale) + shift @@ -389,8 +385,6 @@ class Attention(nn.Module): self.lambda_hf = nn.Parameter(torch.zeros(dim)) self.causal = causal - if causal: - print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.') @compile def apply_qk_layernorm(self, q, k): @@ -409,14 +403,8 @@ class Attention(nn.Module): k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) flash_attn_available = HAS_FLASH_ATTN - if flash_attn_sliding_window is not None and (not flash_attn_available): - print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available") - - if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None: - print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention") if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): - print(f"Disabling FlexAttention because causal is set") flex_attention_block_mask = None flex_attention_score_mod = None @@ -606,7 +594,6 @@ class TransformerBlock(nn.Module): self.dim_context = dim_context self.causal = causal if layer_scale and zero_init_branch_outputs: - print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False') zero_init_branch_outputs = False self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) @@ -909,6 +896,8 @@ class ContinuousTransformer(nn.Module): model_dtype = next(self.parameters()).dtype x = x.to(model_dtype) + prepend_length = 0 + info = { "hidden_states": [], } diff --git a/prismaudio_core/models/utils.py b/prismaudio_core/models/utils.py index 4a29e62..c2c4f4f 100644 --- a/prismaudio_core/models/utils.py +++ b/prismaudio_core/models/utils.py @@ -162,6 +162,10 @@ def resample(video_feat, audio_latent): aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] return aligned_video.permute(0, 2, 1) # [B, Ta, D] +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + import os enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"