fix: remove MMDiTWrapper import and dead code paths from factory.py
MMDiTWrapper was removed from diffusion.py during cleanup but the import in factory.py was missed, causing ImportError on every model load. Also stub wavelet and diffusion_prior paths that reference deleted modules. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -51,14 +51,7 @@ def create_pretransform_from_config(pretransform_config, sample_rate):
|
|||||||
|
|
||||||
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
|
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
|
||||||
elif pretransform_type == 'wavelet':
|
elif pretransform_type == 'wavelet':
|
||||||
from prismaudio_core.models.pretransforms import WaveletPretransform
|
raise NotImplementedError("wavelet pretransform type is not supported")
|
||||||
|
|
||||||
wavelet_config = pretransform_config["config"]
|
|
||||||
channels = wavelet_config["channels"]
|
|
||||||
levels = wavelet_config["levels"]
|
|
||||||
wavelet = wavelet_config["wavelet"]
|
|
||||||
|
|
||||||
pretransform = WaveletPretransform(channels, levels, wavelet)
|
|
||||||
elif pretransform_type == 'pqmf':
|
elif pretransform_type == 'pqmf':
|
||||||
from prismaudio_core.models.pretransforms import PQMFPretransform
|
from prismaudio_core.models.pretransforms import PQMFPretransform
|
||||||
pqmf_config = pretransform_config["config"]
|
pqmf_config = pretransform_config["config"]
|
||||||
@@ -327,7 +320,6 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
UNetCFG1DWrapper,
|
UNetCFG1DWrapper,
|
||||||
UNet1DCondWrapper,
|
UNet1DCondWrapper,
|
||||||
DiTWrapper,
|
DiTWrapper,
|
||||||
MMDiTWrapper,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = config["model"]
|
model_config = config["model"]
|
||||||
@@ -350,7 +342,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
elif diffusion_model_type == 'dit':
|
elif diffusion_model_type == 'dit':
|
||||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||||
elif diffusion_model_type == 'mmdit':
|
elif diffusion_model_type == 'mmdit':
|
||||||
diffusion_model = MMDiTWrapper(**diffusion_model_config)
|
raise NotImplementedError("mmdit diffusion model type is not supported")
|
||||||
|
|
||||||
io_channels = model_config.get('io_channels', None)
|
io_channels = model_config.get('io_channels', None)
|
||||||
assert io_channels is not None, "Must specify io_channels in model config"
|
assert io_channels is not None, "Must specify io_channels in model config"
|
||||||
@@ -401,12 +393,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
elif model_type == "diffusion_prior":
|
elif model_type == "diffusion_prior":
|
||||||
prior_type = model_config.get("prior_type", None)
|
raise NotImplementedError("diffusion_prior model type is not supported")
|
||||||
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
|
|
||||||
|
|
||||||
if prior_type == "mono_stereo":
|
|
||||||
from prismaudio_core.models.diffusion_prior import MonoToStereoDiffusionPrior
|
|
||||||
wrapper_fn = MonoToStereoDiffusionPrior
|
|
||||||
|
|
||||||
return wrapper_fn(
|
return wrapper_fn(
|
||||||
diffusion_model,
|
diffusion_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user