fix: resolve critical bugs and quality issues in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from .blocks import SnakeBeta
|
||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||
from .pretransforms import Pretransform
|
||||
from .utils import checkpoint
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||
@@ -46,10 +47,6 @@ def _lazy_create_bottleneck_from_config(bottleneck):
|
||||
from prismaudio_core.factory import create_bottleneck_from_config
|
||||
return create_bottleneck_from_config(bottleneck)
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
|
||||
kwargs.setdefault("use_reentrant", False)
|
||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
act = nn.ELU()
|
||||
@@ -617,7 +614,7 @@ class DiffusionAutoencoder(AudioAutoencoder):
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if self.decoder is not None:
|
||||
latents = self.decode(latents)
|
||||
latents = self.decoder(latents)
|
||||
|
||||
# Upsample latents to match diffusion length
|
||||
if latents.shape[2] != upsampled_length:
|
||||
@@ -801,7 +798,7 @@ def create_diffAE_from_config(config: Dict[str, Any]):
|
||||
if bottleneck is not None:
|
||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||
|
||||
diffusion_downsampling_ratio = None,
|
||||
diffusion_downsampling_ratio = None
|
||||
|
||||
if diffusion_model_type == "DAU1d":
|
||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||
|
||||
Reference in New Issue
Block a user