fix: resolve critical bugs and quality issues in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from .blocks import SnakeBeta
|
|||||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||||
from .pretransforms import Pretransform
|
from .pretransforms import Pretransform
|
||||||
|
from .utils import checkpoint
|
||||||
|
|
||||||
|
|
||||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||||
@@ -46,10 +47,6 @@ def _lazy_create_bottleneck_from_config(bottleneck):
|
|||||||
from prismaudio_core.factory import create_bottleneck_from_config
|
from prismaudio_core.factory import create_bottleneck_from_config
|
||||||
return create_bottleneck_from_config(bottleneck)
|
return create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
def checkpoint(function, *args, **kwargs):
|
|
||||||
kwargs.setdefault("use_reentrant", False)
|
|
||||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
|
||||||
|
|
||||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
if activation == "elu":
|
if activation == "elu":
|
||||||
act = nn.ELU()
|
act = nn.ELU()
|
||||||
@@ -617,7 +614,7 @@ class DiffusionAutoencoder(AudioAutoencoder):
|
|||||||
latents = self.bottleneck.decode(latents)
|
latents = self.bottleneck.decode(latents)
|
||||||
|
|
||||||
if self.decoder is not None:
|
if self.decoder is not None:
|
||||||
latents = self.decode(latents)
|
latents = self.decoder(latents)
|
||||||
|
|
||||||
# Upsample latents to match diffusion length
|
# Upsample latents to match diffusion length
|
||||||
if latents.shape[2] != upsampled_length:
|
if latents.shape[2] != upsampled_length:
|
||||||
@@ -801,7 +798,7 @@ def create_diffAE_from_config(config: Dict[str, Any]):
|
|||||||
if bottleneck is not None:
|
if bottleneck is not None:
|
||||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
diffusion_downsampling_ratio = None,
|
diffusion_downsampling_ratio = None
|
||||||
|
|
||||||
if diffusion_model_type == "DAU1d":
|
if diffusion_model_type == "DAU1d":
|
||||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||||
|
|||||||
@@ -200,14 +200,6 @@ def zero_init(layer):
|
|||||||
nn.init.zeros_(layer.bias)
|
nn.init.zeros_(layer.bias)
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
def rms_norm(x, scale, eps):
|
|
||||||
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
|
||||||
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
|
||||||
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
|
||||||
return x * scale.to(x.dtype)
|
|
||||||
|
|
||||||
#rms_norm = torch.compile(rms_norm)
|
|
||||||
|
|
||||||
class AdaRMSNorm(nn.Module):
|
class AdaRMSNorm(nn.Module):
|
||||||
def __init__(self, features, cond_features, eps=1e-6):
|
def __init__(self, features, cond_features, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -771,6 +771,7 @@ def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
|
|
||||||
elif model_type == "adp_uncond_1d":
|
elif model_type == "adp_uncond_1d":
|
||||||
|
|
||||||
|
io_channels = model_config.get("io_channels", 64)
|
||||||
model = UNet1DUncondWrapper(
|
model = UNet1DUncondWrapper(
|
||||||
io_channels = io_channels,
|
io_channels = io_channels,
|
||||||
**model_config
|
**model_config
|
||||||
@@ -858,7 +859,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
||||||
|
|
||||||
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
||||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ from torch import nn
|
|||||||
|
|
||||||
from .blocks import AdaRMSNorm
|
from .blocks import AdaRMSNorm
|
||||||
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
||||||
|
from .utils import checkpoint
|
||||||
def checkpoint(function, *args, **kwargs):
|
|
||||||
kwargs.setdefault("use_reentrant", False)
|
|
||||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
|
||||||
|
|
||||||
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||||
class ContinuousLocalTransformer(nn.Module):
|
class ContinuousLocalTransformer(nn.Module):
|
||||||
|
|||||||
@@ -18,16 +18,12 @@ except ImportError:
|
|||||||
flash_attn_kvpacked_func = None
|
flash_attn_kvpacked_func = None
|
||||||
flash_attn_func = None
|
flash_attn_func = None
|
||||||
|
|
||||||
from .utils import compile
|
from .utils import compile, checkpoint
|
||||||
try:
|
try:
|
||||||
import natten
|
import natten
|
||||||
except ImportError:
|
except ImportError:
|
||||||
natten = None
|
natten = None
|
||||||
|
|
||||||
def checkpoint(function, *args, **kwargs):
|
|
||||||
kwargs.setdefault("use_reentrant", False)
|
|
||||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
|
||||||
|
|
||||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||||
return x * (1 + scale) + shift
|
return x * (1 + scale) + shift
|
||||||
|
|
||||||
@@ -389,8 +385,6 @@ class Attention(nn.Module):
|
|||||||
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
if causal:
|
|
||||||
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
|
|
||||||
|
|
||||||
@compile
|
@compile
|
||||||
def apply_qk_layernorm(self, q, k):
|
def apply_qk_layernorm(self, q, k):
|
||||||
@@ -409,14 +403,8 @@ class Attention(nn.Module):
|
|||||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||||
|
|
||||||
flash_attn_available = HAS_FLASH_ATTN
|
flash_attn_available = HAS_FLASH_ATTN
|
||||||
if flash_attn_sliding_window is not None and (not flash_attn_available):
|
|
||||||
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
|
|
||||||
|
|
||||||
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
|
|
||||||
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
|
|
||||||
|
|
||||||
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
||||||
print(f"Disabling FlexAttention because causal is set")
|
|
||||||
flex_attention_block_mask = None
|
flex_attention_block_mask = None
|
||||||
flex_attention_score_mod = None
|
flex_attention_score_mod = None
|
||||||
|
|
||||||
@@ -606,7 +594,6 @@ class TransformerBlock(nn.Module):
|
|||||||
self.dim_context = dim_context
|
self.dim_context = dim_context
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
if layer_scale and zero_init_branch_outputs:
|
if layer_scale and zero_init_branch_outputs:
|
||||||
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
|
|
||||||
zero_init_branch_outputs = False
|
zero_init_branch_outputs = False
|
||||||
|
|
||||||
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||||
@@ -909,6 +896,8 @@ class ContinuousTransformer(nn.Module):
|
|||||||
model_dtype = next(self.parameters()).dtype
|
model_dtype = next(self.parameters()).dtype
|
||||||
x = x.to(model_dtype)
|
x = x.to(model_dtype)
|
||||||
|
|
||||||
|
prepend_length = 0
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"hidden_states": [],
|
"hidden_states": [],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,6 +162,10 @@ def resample(video_feat, audio_latent):
|
|||||||
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
|
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
|
||||||
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
|
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
|
||||||
|
|
||||||
|
def checkpoint(function, *args, **kwargs):
|
||||||
|
kwargs.setdefault("use_reentrant", False)
|
||||||
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
|
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user