fix: resolve critical bugs and quality issues in prismaudio_core/models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:56:02 +01:00
parent 6e1186d5bd
commit 30e85f0f99
6 changed files with 13 additions and 33 deletions
+3 -6
View File
@@ -13,6 +13,7 @@ from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .pretransforms import Pretransform from .pretransforms import Pretransform
from .utils import checkpoint
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
@@ -46,10 +47,6 @@ def _lazy_create_bottleneck_from_config(bottleneck):
from prismaudio_core.factory import create_bottleneck_from_config from prismaudio_core.factory import create_bottleneck_from_config
return create_bottleneck_from_config(bottleneck) return create_bottleneck_from_config(bottleneck)
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with a non-reentrant default.

    ``use_reentrant`` is set to ``False`` only when the caller did not pass
    it; every other positional and keyword argument is forwarded unchanged.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns for
        ``function(*args)``.
    """
    # Caller-supplied keywords win over the default because they are
    # unpacked last.
    forwarded = {"use_reentrant": False, **kwargs}
    return torch.utils.checkpoint.checkpoint(function, *args, **forwarded)
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu": if activation == "elu":
act = nn.ELU() act = nn.ELU()
@@ -617,7 +614,7 @@ class DiffusionAutoencoder(AudioAutoencoder):
latents = self.bottleneck.decode(latents) latents = self.bottleneck.decode(latents)
if self.decoder is not None: if self.decoder is not None:
latents = self.decode(latents) latents = self.decoder(latents)
# Upsample latents to match diffusion length # Upsample latents to match diffusion length
if latents.shape[2] != upsampled_length: if latents.shape[2] != upsampled_length:
@@ -801,7 +798,7 @@ def create_diffAE_from_config(config: Dict[str, Any]):
if bottleneck is not None: if bottleneck is not None:
bottleneck = _lazy_create_bottleneck_from_config(bottleneck) bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
diffusion_downsampling_ratio = None, diffusion_downsampling_ratio = None
if diffusion_model_type == "DAU1d": if diffusion_model_type == "DAU1d":
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
-8
View File
@@ -200,14 +200,6 @@ def zero_init(layer):
nn.init.zeros_(layer.bias) nn.init.zeros_(layer.bias)
return layer return layer
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale`` divided by the RMS of ``x`` over its last dim.

    The mean of squares and the reciprocal square root are computed in a
    dtype that is at least ``float32`` (the promotion of ``x.dtype``,
    ``scale.dtype`` and ``torch.float32``); the resulting factor is cast back
    to ``x.dtype`` before the final multiplication.

    Args:
        x: Input tensor; the reduction runs over ``dim=-1``.
        scale: Multiplicative gain; assumed broadcastable against ``x``
            (TODO confirm against callers).
        eps: Value added to the mean of squares before ``rsqrt``.

    Returns:
        ``x`` multiplied by the normalising factor.
    """
    # Same left-to-right promotion order as folding promote_types over
    # (x.dtype, scale.dtype, float32).
    compute_dtype = torch.promote_types(
        torch.promote_types(x.dtype, scale.dtype), torch.float32
    )
    squared_mean = x.to(compute_dtype).pow(2).mean(dim=-1, keepdim=True)
    gain = scale.to(compute_dtype) * torch.rsqrt(squared_mean + eps)
    return x * gain.to(x.dtype)
#rms_norm = torch.compile(rms_norm)
class AdaRMSNorm(nn.Module): class AdaRMSNorm(nn.Module):
def __init__(self, features, cond_features, eps=1e-6): def __init__(self, features, cond_features, eps=1e-6):
super().__init__() super().__init__()
+2 -1
View File
@@ -771,6 +771,7 @@ def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
elif model_type == "adp_uncond_1d": elif model_type == "adp_uncond_1d":
io_channels = model_config.get("io_channels", 64)
model = UNet1DUncondWrapper( model = UNet1DUncondWrapper(
io_channels = io_channels, io_channels = io_channels,
**model_config **model_config
@@ -858,7 +859,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
extra_kwargs["diffusion_objective"] = diffusion_objective extra_kwargs["diffusion_objective"] = diffusion_objective
extra_kwargs["mm_cond_ids"] = mm_cond_ids extra_kwargs["mm_cond_ids"] = mm_cond_ids
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
wrapper_fn = ConditionedDiffusionModelWrapper wrapper_fn = ConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective extra_kwargs["diffusion_objective"] = diffusion_objective
+1 -4
View File
@@ -5,10 +5,7 @@ from torch import nn
from .blocks import AdaRMSNorm from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
from .utils import checkpoint
def checkpoint(function, *args, **kwargs):
    """Run ``function`` through ``torch.utils.checkpoint.checkpoint``.

    If the caller does not specify ``use_reentrant`` it is passed as
    ``False``; all other arguments are handed through as given.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module): class ContinuousLocalTransformer(nn.Module):
+3 -14
View File
@@ -18,16 +18,12 @@ except ImportError:
flash_attn_kvpacked_func = None flash_attn_kvpacked_func = None
flash_attn_func = None flash_attn_func = None
from .utils import compile from .utils import compile, checkpoint
try: try:
import natten import natten
except ImportError: except ImportError:
natten = None natten = None
def checkpoint(function, *args, **kwargs):
    """Wrapper around ``torch.utils.checkpoint.checkpoint``.

    Supplies ``use_reentrant=False`` when the keyword is absent and forwards
    ``function``, ``*args`` and the remaining keywords untouched.
    """
    # An explicit use_reentrant from the caller overrides the default.
    options = {"use_reentrant": False, **kwargs}
    return torch.utils.checkpoint.checkpoint(function, *args, **options)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift return x * (1 + scale) + shift
@@ -389,8 +385,6 @@ class Attention(nn.Module):
self.lambda_hf = nn.Parameter(torch.zeros(dim)) self.lambda_hf = nn.Parameter(torch.zeros(dim))
self.causal = causal self.causal = causal
if causal:
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
@compile @compile
def apply_qk_layernorm(self, q, k): def apply_qk_layernorm(self, q, k):
@@ -409,14 +403,8 @@ class Attention(nn.Module):
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
flash_attn_available = HAS_FLASH_ATTN flash_attn_available = HAS_FLASH_ATTN
if flash_attn_sliding_window is not None and (not flash_attn_available):
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
print(f"Disabling FlexAttention because causal is set")
flex_attention_block_mask = None flex_attention_block_mask = None
flex_attention_score_mod = None flex_attention_score_mod = None
@@ -606,7 +594,6 @@ class TransformerBlock(nn.Module):
self.dim_context = dim_context self.dim_context = dim_context
self.causal = causal self.causal = causal
if layer_scale and zero_init_branch_outputs: if layer_scale and zero_init_branch_outputs:
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
zero_init_branch_outputs = False zero_init_branch_outputs = False
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
@@ -909,6 +896,8 @@ class ContinuousTransformer(nn.Module):
model_dtype = next(self.parameters()).dtype model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype) x = x.to(model_dtype)
prepend_length = 0
info = { info = {
"hidden_states": [], "hidden_states": [],
} }
+4
View File
@@ -162,6 +162,10 @@ def resample(video_feat, audio_latent):
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
return aligned_video.permute(0, 2, 1) # [B, Ta, D] return aligned_video.permute(0, 2, 1) # [B, Ta, D]
def checkpoint(function, *args, **kwargs):
    """Invoke ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False`` by default.

    Args:
        function: Callable handed to ``torch.utils.checkpoint.checkpoint``.
        *args: Positional arguments forwarded after ``function``.
        **kwargs: Keyword arguments forwarded as-is; ``use_reentrant`` is
            filled in as ``False`` only when the caller omitted it.

    Returns:
        The value returned by ``torch.utils.checkpoint.checkpoint``.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
import os import os
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"