fix: remove MMDiTWrapper import and dead code paths from factory.py

MMDiTWrapper was removed from diffusion.py during cleanup but the import
in factory.py was missed, causing ImportError on every model load.
Also stub wavelet and diffusion_prior paths that reference deleted modules.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 19:12:40 +01:00
parent 807f00417f
commit 9b1cb71b2a
+3 -16
View File
@@ -51,14 +51,7 @@ def create_pretransform_from_config(pretransform_config, sample_rate):
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
elif pretransform_type == 'wavelet': elif pretransform_type == 'wavelet':
from prismaudio_core.models.pretransforms import WaveletPretransform raise NotImplementedError("wavelet pretransform type is not supported")
wavelet_config = pretransform_config["config"]
channels = wavelet_config["channels"]
levels = wavelet_config["levels"]
wavelet = wavelet_config["wavelet"]
pretransform = WaveletPretransform(channels, levels, wavelet)
elif pretransform_type == 'pqmf': elif pretransform_type == 'pqmf':
from prismaudio_core.models.pretransforms import PQMFPretransform from prismaudio_core.models.pretransforms import PQMFPretransform
pqmf_config = pretransform_config["config"] pqmf_config = pretransform_config["config"]
@@ -327,7 +320,6 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
UNetCFG1DWrapper, UNetCFG1DWrapper,
UNet1DCondWrapper, UNet1DCondWrapper,
DiTWrapper, DiTWrapper,
MMDiTWrapper,
) )
model_config = config["model"] model_config = config["model"]
@@ -350,7 +342,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
elif diffusion_model_type == 'dit': elif diffusion_model_type == 'dit':
diffusion_model = DiTWrapper(**diffusion_model_config) diffusion_model = DiTWrapper(**diffusion_model_config)
elif diffusion_model_type == 'mmdit': elif diffusion_model_type == 'mmdit':
diffusion_model = MMDiTWrapper(**diffusion_model_config) raise NotImplementedError("mmdit diffusion model type is not supported")
io_channels = model_config.get('io_channels', None) io_channels = model_config.get('io_channels', None)
assert io_channels is not None, "Must specify io_channels in model config" assert io_channels is not None, "Must specify io_channels in model config"
@@ -401,12 +393,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
extra_kwargs["diffusion_objective"] = diffusion_objective extra_kwargs["diffusion_objective"] = diffusion_objective
elif model_type == "diffusion_prior": elif model_type == "diffusion_prior":
prior_type = model_config.get("prior_type", None) raise NotImplementedError("diffusion_prior model type is not supported")
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
if prior_type == "mono_stereo":
from prismaudio_core.models.diffusion_prior import MonoToStereoDiffusionPrior
wrapper_fn = MonoToStereoDiffusionPrior
return wrapper_fn( return wrapper_fn(
diffusion_model, diffusion_model,