From b60ff4111bbacec4510b74fd1bf7884464e4e504 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Fri, 27 Mar 2026 17:05:57 +0100
Subject: [PATCH] feat: extract prismaudio_core config and model factory

Co-Authored-By: Claude Opus 4.6
---
 prismaudio_core/__init__.py             |   5 +
 prismaudio_core/configs/prismaudio.json | 141 ++++++++
 prismaudio_core/factory.py              | 463 ++++++++++++++++++++++++
 3 files changed, 609 insertions(+)
 create mode 100644 prismaudio_core/__init__.py
 create mode 100644 prismaudio_core/configs/prismaudio.json
 create mode 100644 prismaudio_core/factory.py

diff --git a/prismaudio_core/__init__.py b/prismaudio_core/__init__.py
new file mode 100644
index 0000000..064b327
--- /dev/null
+++ b/prismaudio_core/__init__.py
@@ -0,0 +1,5 @@
+"""
+PrismAudio core inference modules.
+Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
+Only inference-critical code — no training, no JAX/TF dependencies.
+"""
diff --git a/prismaudio_core/configs/prismaudio.json b/prismaudio_core/configs/prismaudio.json
new file mode 100644
index 0000000..19d24a0
--- /dev/null
+++ b/prismaudio_core/configs/prismaudio.json
@@ -0,0 +1,141 @@
+{
+    "model_type": "diffusion_cond",
+    "sample_size": 397312,
+    "sample_rate": 44100,
+    "audio_channels": 2,
+    "model": {
+        "pretransform": {
+            "type": "autoencoder",
+            "iterate_batch": true,
+            "config": {
+                "encoder": {
+                    "type": "oobleck",
+                    "config": {
+                        "in_channels": 2,
+                        "channels": 128,
+                        "c_mults": [1, 2, 4, 8, 16],
+                        "strides": [2, 4, 4, 8, 8],
+                        "latent_dim": 128,
+                        "use_snake": true
+                    }
+                },
+                "decoder": {
+                    "type": "oobleck",
+                    "config": {
+                        "out_channels": 2,
+                        "channels": 128,
+                        "c_mults": [1, 2, 4, 8, 16],
+                        "strides": [2, 4, 4, 8, 8],
+                        "latent_dim": 64,
+                        "use_snake": true,
+                        "final_tanh": false
+                    }
+                },
+                "bottleneck": {
+                    "type": "vae"
+                },
+                "latent_dim": 64,
+                "downsampling_ratio": 2048,
+                "io_channels": 2
+            }
+        },
+        "conditioning": {
+            "configs": [
+                {
+                    "id": "video_features",
+                    "type": "cond_mlp",
+                    "config": {
+                        "dim": 1024,
+                        "output_dim": 1024
+                    }
+                },
+                {
+                    "id": "text_features",
+                    "type": "cond_mlp",
+                    "config": {
+                        "dim": 1024,
+                        "output_dim": 1024
+                    }
+                },
+                {
+                    "id": "sync_features",
+                    "type": "sync_mlp",
+                    "config": {
+                        "dim": 768,
+                        "output_dim": 1024
+                    }
+                }
+            ],
+            "cond_dim": 768
+        },
+        "diffusion": {
+            "cross_attention_cond_ids": ["video_features","text_features"],
+            "add_cond_ids": ["video_features"],
+            "sync_cond_ids": ["sync_features"],
+            "type": "dit",
+            "diffusion_objective": "rectified_flow",
+            "config": {
+                "io_channels": 64,
+                "embed_dim": 1024,
+                "depth": 24,
+                "num_heads": 16,
+                "cond_token_dim": 1024,
+                "add_token_dim": 1024,
+                "sync_token_dim": 1024,
+                "project_cond_tokens": false,
+                "transformer_type": "continuous_transformer",
+                "attn_kwargs":{
+                    "qk_norm": "rns"
+                },
+                "use_gated": true,
+                "use_sync_gated": true
+            }
+        },
+        "io_channels": 64
+    },
+    "training": {
+        "use_ema": true,
+        "log_loss_info": false,
+        "cfg_dropout_prob": 0.1,
+        "pre_encoded": true,
+        "timestep_sampler": "trunc_logit_normal",
+        "optimizer_configs": {
+            "diffusion": {
+                "optimizer": {
+                    "type": "AdamW",
+                    "config": {
+                        "lr": 1e-4,
+                        "betas": [0.9, 0.999],
+                        "weight_decay": 1e-3
+                    }
+                },
+                "scheduler": {
+                    "type": "InverseLR",
+                    "config": {
+                        "inv_gamma": 100000,
+                        "power": 0.5,
+                        "warmup": 0.99
+                    }
+                }
+            }
+        },
+        "demo": {
+            "demo_every": 5000,
+            "demo_steps": 24,
+            "num_demos": 10,
+            "demo_cond": [
+                "dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
+                "dataset/videoprism/test/bmKtI808DsU_000009.npz",
+                "dataset/videoprism/test/VC0c22cJTbM_000424.npz",
+                "dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
+                "dataset/videoprism/test/WatvT8A8iug_000100.npz",
+                "dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
+                "dataset/videoprism/test/3-PFuDkTM48_000080.npz",
+                "dataset/videoprism/test/luSAuu-BoPs_000232.npz",
+                "dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
+                "dataset/videoprism/test/_0m_YMpQayA_000168.npz"
+            ],
+            "demo_cfg_scales": [5]
+        }
+    }
+}
\ No newline at end of file
diff --git a/prismaudio_core/factory.py b/prismaudio_core/factory.py
new file mode 100644
index 0000000..c621026
--- /dev/null
+++ b/prismaudio_core/factory.py
@@ -0,0 +1,463 @@
+"""
+Model factory functions for PrismAudio inference.
+
+Extracted from:
+  - PrismAudio/models/factory.py
+  - PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
+  - PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
+  - PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
+
+Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
+Only inference-critical factory functions are retained.
+"""
+
+import json
+import typing as tp
+from typing import Dict, Any
+
+import numpy as np
+
+
+def create_model_from_config(model_config):
+    """Build a model from a top-level config, dispatching on ``model_type``.
+
+    ``'autoencoder'`` is routed to ``create_autoencoder_from_config``; the
+    diffusion variants are routed to ``create_diffusion_cond_from_config``.
+
+    Raises:
+        NotImplementedError: If ``model_type`` is not a known type.
+    """
+    model_type = model_config.get('model_type', None)
+
+    assert model_type is not None, 'model_type must be specified in model config'
+
+    if model_type == 'autoencoder':
+        return create_autoencoder_from_config(model_config)
+    elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
+        return create_diffusion_cond_from_config(model_config)
+    else:
+        raise NotImplementedError(f'Unknown model type: {model_type}')
+
+
+def create_pretransform_from_config(pretransform_config, sample_rate):
+    """Build a pretransform from its config and put it in eval mode.
+
+    ``pretransform_config['type']`` selects the implementation; ``sample_rate``
+    is only forwarded to the autoencoder variant. Gradients are enabled only
+    when ``enable_grad`` is set in the config (default ``False``).
+
+    Raises:
+        NotImplementedError: If the pretransform type is unknown.
+    """
+    pretransform_type = pretransform_config.get('type', None)
+
+    assert pretransform_type is not None, 'type must be specified in pretransform config'
+
+    if pretransform_type == 'autoencoder':
+        from prismaudio_core.models.pretransforms import AutoencoderPretransform
+
+        # Create fake top-level config to pass sample rate to autoencoder constructor
+        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
+        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
+        autoencoder = create_autoencoder_from_config(autoencoder_config)
+
+        scale = pretransform_config.get("scale", 1.0)
+        model_half = pretransform_config.get("model_half", False)
+        iterate_batch = pretransform_config.get("iterate_batch", False)
+        chunked = pretransform_config.get("chunked", False)
+
+        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
+    elif pretransform_type == 'wavelet':
+        from prismaudio_core.models.pretransforms import WaveletPretransform
+
+        wavelet_config = pretransform_config["config"]
+        channels = wavelet_config["channels"]
+        levels = wavelet_config["levels"]
+        wavelet = wavelet_config["wavelet"]
+
+        pretransform = WaveletPretransform(channels, levels, wavelet)
+    elif pretransform_type == 'pqmf':
+        from prismaudio_core.models.pretransforms import PQMFPretransform
+        pqmf_config = pretransform_config["config"]
+        pretransform = PQMFPretransform(**pqmf_config)
+    elif pretransform_type == 'dac_pretrained':
+        from prismaudio_core.models.pretransforms import PretrainedDACPretransform
+        pretrained_dac_config = pretransform_config["config"]
+        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
+    elif pretransform_type == "audiocraft_pretrained":
+        from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform
+
+        audiocraft_config = pretransform_config["config"]
+        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
+    else:
+        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
+
+    enable_grad = pretransform_config.get('enable_grad', False)
+    pretransform.enable_grad = enable_grad
+
+    pretransform.eval().requires_grad_(pretransform.enable_grad)
+
+    return pretransform
+
+
+def create_bottleneck_from_config(bottleneck_config):
+    """Build a bottleneck module from its config, dispatching on ``type``.
+
+    For ``'rvq'`` and ``'rvq_vae'`` the config values override built-in
+    quantizer defaults. If ``requires_grad`` is false in the config, every
+    parameter of the bottleneck is frozen.
+
+    Raises:
+        NotImplementedError: If the bottleneck type is unknown.
+    """
+    bottleneck_type = bottleneck_config.get('type', None)
+
+    assert bottleneck_type is not None, 'type must be specified in bottleneck config'
+
+    if bottleneck_type == 'tanh':
+        from prismaudio_core.models.bottleneck import TanhBottleneck
+        bottleneck = TanhBottleneck()
+    elif bottleneck_type == 'vae':
+        from prismaudio_core.models.bottleneck import VAEBottleneck
+        bottleneck = VAEBottleneck()
+    elif bottleneck_type == 'rvq':
+        from prismaudio_core.models.bottleneck import RVQBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQBottleneck(**quantizer_params)
+    elif bottleneck_type == "dac_rvq":
+        from prismaudio_core.models.bottleneck import DACRVQBottleneck
+
+        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
+
+    elif bottleneck_type == 'rvq_vae':
+        from prismaudio_core.models.bottleneck import RVQVAEBottleneck
+
+        quantizer_params = {
+            "dim": 128,
+            "codebook_size": 1024,
+            "num_quantizers": 8,
+            "decay": 0.99,
+            "kmeans_init": True,
+            "kmeans_iters": 50,
+            "threshold_ema_dead_code": 2,
+        }
+
+        quantizer_params.update(bottleneck_config["config"])
+
+        bottleneck = RVQVAEBottleneck(**quantizer_params)
+
+    elif bottleneck_type == 'dac_rvq_vae':
+        from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
+        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
+    elif bottleneck_type == 'l2_norm':
+        from prismaudio_core.models.bottleneck import L2Bottleneck
+        bottleneck = L2Bottleneck()
+    elif bottleneck_type == "wasserstein":
+        from prismaudio_core.models.bottleneck import WassersteinBottleneck
+        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
+    elif bottleneck_type == "fsq":
+        from prismaudio_core.models.bottleneck import FSQBottleneck
+        bottleneck = FSQBottleneck(**bottleneck_config["config"])
+    else:
+        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
+
+    requires_grad = bottleneck_config.get('requires_grad', True)
+    if not requires_grad:
+        for param in bottleneck.parameters():
+            param.requires_grad = False
+
+    return bottleneck
+
+
+def create_autoencoder_from_config(config: Dict[str, Any]):
+    """Create an AudioAutoencoder from a config dictionary.
+
+    Originally in PrismAudio/models/autoencoders.py.
+    """
+    from prismaudio_core.models.autoencoders import (
+        AudioAutoencoder,
+        create_encoder_from_config,
+        create_decoder_from_config,
+    )
+
+    ae_config = config["model"]
+
+    encoder = create_encoder_from_config(ae_config["encoder"])
+    decoder = create_decoder_from_config(ae_config["decoder"])
+
+    bottleneck = ae_config.get("bottleneck", None)
+
+    latent_dim = ae_config.get("latent_dim", None)
+    assert latent_dim is not None, "latent_dim must be specified in model config"
+    downsampling_ratio = ae_config.get("downsampling_ratio", None)
+    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+    io_channels = ae_config.get("io_channels", None)
+    assert io_channels is not None, "io_channels must be specified in model config"
+    sample_rate = config.get("sample_rate", None)
+    assert sample_rate is not None, "sample_rate must be specified in model config"
+
+    in_channels = ae_config.get("in_channels", None)
+    out_channels = ae_config.get("out_channels", None)
+
+    pretransform = ae_config.get("pretransform", None)
+
+    if pretransform is not None:
+        pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+    if bottleneck is not None:
+        bottleneck = create_bottleneck_from_config(bottleneck)
+
+    soft_clip = ae_config["decoder"].get("soft_clip", False)
+
+    return AudioAutoencoder(
+        encoder,
+        decoder,
+        io_channels=io_channels,
+        latent_dim=latent_dim,
+        downsampling_ratio=downsampling_ratio,
+        sample_rate=sample_rate,
+        bottleneck=bottleneck,
+        pretransform=pretransform,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        soft_clip=soft_clip
+    )
+
+
+def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
+    """Create a MultiConditioner from a conditioning config dictionary.
+
+    Originally in PrismAudio/models/conditioners.py.
+    """
+    from prismaudio_core.models.conditioners import (
+        MultiConditioner,
+        T5Conditioner,
+        CLAPTextConditioner,
+        CLIPTextConditioner,
+        MetaCLIPTextConditioner,
+        CLAPAudioConditioner,
+        Cond_MLP,
+        Global_MLP,
+        Sync_MLP,
+        Cond_MLP_1,
+        Cond_ConvMLP,
+        Cond_MLP_Global,
+        Cond_MLP_Global_1,
+        Cond_MLP_Global_2,
+        Video_Global,
+        Video_Sync,
+        Text_Linear,
+        CLIPConditioner,
+        IntConditioner,
+        NumberConditioner,
+        PhonemeConditioner,
+        TokenizerLUTConditioner,
+        PretransformConditioner,
+        mm_unchang,
+    )
+    from prismaudio_core.models.utils import load_ckpt_state_dict
+
+    conditioners = {}
+    cond_dim = config["cond_dim"]
+
+    default_keys = config.get("default_keys", {})
+
+    for conditioner_info in config["configs"]:
+        cond_id = conditioner_info["id"]
+
+        conditioner_type = conditioner_info["type"]
+
+        # cond_dim is the default output_dim; the per-conditioner config can override it
+        conditioner_config = {"output_dim": cond_dim}
+
+        conditioner_config.update(conditioner_info["config"])
+        if conditioner_type == "t5":
+            conditioners[cond_id] = T5Conditioner(**conditioner_config)
+        elif conditioner_type == "clap_text":
+            conditioners[cond_id] = CLAPTextConditioner(**conditioner_config)
+        elif conditioner_type == "clip_text":
+            conditioners[cond_id] = CLIPTextConditioner(**conditioner_config)
+        elif conditioner_type == "metaclip_text":
+            conditioners[cond_id] = MetaCLIPTextConditioner(**conditioner_config)
+        elif conditioner_type == "clap_audio":
+            conditioners[cond_id] = CLAPAudioConditioner(**conditioner_config)
+        elif conditioner_type == "cond_mlp":
+            conditioners[cond_id] = Cond_MLP(**conditioner_config)
+        elif conditioner_type == "global_mlp":
+            conditioners[cond_id] = Global_MLP(**conditioner_config)
+        elif conditioner_type == "sync_mlp":
+            conditioners[cond_id] = Sync_MLP(**conditioner_config)
+        elif conditioner_type == "cond_mlp_1":
+            conditioners[cond_id] = Cond_MLP_1(**conditioner_config)
+        elif conditioner_type == "cond_convmlp":
+            conditioners[cond_id] = Cond_ConvMLP(**conditioner_config)
+        elif conditioner_type == "cond_mlp_global":
+            conditioners[cond_id] = Cond_MLP_Global(**conditioner_config)
+        elif conditioner_type == "cond_mlp_global_1":
+            conditioners[cond_id] = Cond_MLP_Global_1(**conditioner_config)
+        elif conditioner_type == "cond_mlp_global_2":
+            conditioners[cond_id] = Cond_MLP_Global_2(**conditioner_config)
+        elif conditioner_type == "video_global":
+            conditioners[cond_id] = Video_Global(**conditioner_config)
+        elif conditioner_type == "video_sync":
+            conditioners[cond_id] = Video_Sync(**conditioner_config)
+        elif conditioner_type == "text_linear":
+            conditioners[cond_id] = Text_Linear(**conditioner_config)
+        elif conditioner_type == "video_clip":
+            conditioners[cond_id] = CLIPConditioner(**conditioner_config)
+        elif conditioner_type == "int":
+            conditioners[cond_id] = IntConditioner(**conditioner_config)
+        elif conditioner_type == "number":
+            conditioners[cond_id] = NumberConditioner(**conditioner_config)
+        elif conditioner_type == "phoneme":
+            conditioners[cond_id] = PhonemeConditioner(**conditioner_config)
+        elif conditioner_type == "lut":
+            conditioners[cond_id] = TokenizerLUTConditioner(**conditioner_config)
+        elif conditioner_type == "pretransform":
+            sample_rate = conditioner_config.pop("sample_rate", None)
+            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
+
+            pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
+
+            if conditioner_config.get("pretransform_ckpt_path", None) is not None:
+                pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
+
+            conditioners[cond_id] = PretransformConditioner(pretransform, **conditioner_config)
+        elif conditioner_type == "mm_unchang":
+            conditioners[cond_id] = mm_unchang(**conditioner_config)
+        else:
+            raise ValueError(f"Unknown conditioner type: {conditioner_type}")
+
+    return MultiConditioner(conditioners, default_keys=default_keys)
+
+
+def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
+    """Create a ConditionedDiffusionModelWrapper from a config dictionary.
+
+    Originally in PrismAudio/models/diffusion.py.
+
+    Raises:
+        NotImplementedError: If the diffusion model type, ``model_type`` or
+            ``prior_type`` is not recognised.
+    """
+    from prismaudio_core.models.diffusion import (
+        ConditionedDiffusionModelWrapper,
+        MMConditionedDiffusionModelWrapper,
+        UNetCFG1DWrapper,
+        UNet1DCondWrapper,
+        DiTWrapper,
+        MMDiTWrapper,
+    )
+
+    model_config = config["model"]
+
+    model_type = config["model_type"]
+
+    diffusion_config = model_config.get('diffusion', None)
+    assert diffusion_config is not None, "Must specify diffusion config"
+
+    diffusion_model_type = diffusion_config.get('type', None)
+    assert diffusion_model_type is not None, "Must specify diffusion model type"
+
+    diffusion_model_config = diffusion_config.get('config', None)
+    assert diffusion_model_config is not None, "Must specify diffusion model config"
+
+    if diffusion_model_type == 'adp_cfg_1d':
+        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
+    elif diffusion_model_type == 'adp_1d':
+        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
+    elif diffusion_model_type == 'dit':
+        diffusion_model = DiTWrapper(**diffusion_model_config)
+    elif diffusion_model_type == 'mmdit':
+        diffusion_model = MMDiTWrapper(**diffusion_model_config)
+    else:
+        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
+
+    io_channels = model_config.get('io_channels', None)
+    assert io_channels is not None, "Must specify io_channels in model config"
+
+    sample_rate = config.get('sample_rate', None)
+    assert sample_rate is not None, "Must specify sample_rate in config"
+
+    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
+
+    conditioning_config = model_config.get('conditioning', None)
+
+    conditioner = None
+    if conditioning_config is not None:
+        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
+
+    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
+    add_cond_ids = diffusion_config.get('add_cond_ids', [])
+    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
+    global_cond_ids = diffusion_config.get('global_cond_ids', [])
+    input_concat_ids = diffusion_config.get('input_concat_ids', [])
+    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
+    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
+    zero_init = diffusion_config.get('zero_init', False)
+    pretransform = model_config.get("pretransform", None)
+
+    if pretransform is not None:
+        pretransform = create_pretransform_from_config(pretransform, sample_rate)
+        min_input_length = pretransform.downsampling_ratio
+    else:
+        min_input_length = 1
+
+    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
+        min_input_length *= np.prod(diffusion_model_config["factors"])
+    elif diffusion_model_type == "dit":
+        min_input_length *= diffusion_model.model.patch_size
+
+    # Get the proper wrapper class
+
+    extra_kwargs = {}
+
+    if model_type == "mm_diffusion_cond":
+        wrapper_fn = MMConditionedDiffusionModelWrapper
+        extra_kwargs["diffusion_objective"] = diffusion_objective
+        extra_kwargs["mm_cond_ids"] = mm_cond_ids
+
+    elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
+        wrapper_fn = ConditionedDiffusionModelWrapper
+        extra_kwargs["diffusion_objective"] = diffusion_objective
+
+    elif model_type == "diffusion_prior":
+        prior_type = model_config.get("prior_type", None)
+        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
+
+        if prior_type == "mono_stereo":
+            from prismaudio_core.models.diffusion_prior import MonoToStereoDiffusionPrior
+            wrapper_fn = MonoToStereoDiffusionPrior
+        else:
+            raise NotImplementedError(f'Unknown prior type: {prior_type}')
+    else:
+        raise NotImplementedError(f'Unknown model type: {model_type}')
+
+    return wrapper_fn(
+        diffusion_model,
+        conditioner,
+        min_input_length=min_input_length,
+        sample_rate=sample_rate,
+        cross_attn_cond_ids=cross_attention_ids,
+        global_cond_ids=global_cond_ids,
+        input_concat_ids=input_concat_ids,
+        prepend_cond_ids=prepend_cond_ids,
+        add_cond_ids=add_cond_ids,
+        sync_cond_ids=sync_cond_ids,
+        pretransform=pretransform,
+        io_channels=io_channels,
+        zero_init=zero_init,
+        **extra_kwargs
+    )