diff --git a/prismaudio_core/factory.py b/prismaudio_core/factory.py index c621026..7eeef44 100644 --- a/prismaudio_core/factory.py +++ b/prismaudio_core/factory.py @@ -51,14 +51,7 @@ def create_pretransform_from_config(pretransform_config, sample_rate): pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) elif pretransform_type == 'wavelet': - from prismaudio_core.models.pretransforms import WaveletPretransform - - wavelet_config = pretransform_config["config"] - channels = wavelet_config["channels"] - levels = wavelet_config["levels"] - wavelet = wavelet_config["wavelet"] - - pretransform = WaveletPretransform(channels, levels, wavelet) + raise NotImplementedError("wavelet pretransform type is not supported") elif pretransform_type == 'pqmf': from prismaudio_core.models.pretransforms import PQMFPretransform pqmf_config = pretransform_config["config"] @@ -327,7 +320,6 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): UNetCFG1DWrapper, UNet1DCondWrapper, DiTWrapper, - MMDiTWrapper, ) model_config = config["model"] @@ -350,7 +342,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): elif diffusion_model_type == 'dit': diffusion_model = DiTWrapper(**diffusion_model_config) elif diffusion_model_type == 'mmdit': - diffusion_model = MMDiTWrapper(**diffusion_model_config) + raise NotImplementedError("mmdit diffusion model type is not supported") io_channels = model_config.get('io_channels', None) assert io_channels is not None, "Must specify io_channels in model config" @@ -401,12 +393,7 @@ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): extra_kwargs["diffusion_objective"] = diffusion_objective elif model_type == "diffusion_prior": - prior_type = model_config.get("prior_type", None) - assert prior_type is not None, "Must specify prior_type in diffusion prior model config" - - if prior_type == "mono_stereo": - from prismaudio_core.models.diffusion_prior import MonoToStereoDiffusionPrior - wrapper_fn = MonoToStereoDiffusionPrior + raise NotImplementedError("diffusion_prior model type is not supported") return wrapper_fn( diffusion_model,