# Patch overview (text before the first "diff --git" header is not part of
# any hunk).
#
# Adds four new files; every hunk starts from /dev/null.
#
#   __init__.py
#       Re-exports NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS from
#       the `nodes` subpackage and lists them in __all__.
#
#   nodes/__init__.py
#       Fills both mappings from the _NODES table, whose values are
#       (relative module path, class name, display name). Each module is
#       imported with importlib.import_module(..., package=__name__) and the
#       class is looked up with getattr. An entry that raises ImportError or
#       AttributeError is skipped and reported with a
#       "[PrismAudio] Skipping ..." print; the loop continues with the next
#       entry.
#
#   nodes/utils.py
#       Module constants (category name, sample rate, downsampling ratio,
#       I/O channel count) and helpers:
#         - get_prismaudio_model_dir: joins folder_paths.models_dir with
#           "prismaudio", creates the directory if missing, returns the path.
#         - register_model_folder: passes that path to
#           folder_paths.add_model_folder_path under the key "prismaudio".
#         - get_device / get_offload_device / get_free_memory /
#           soft_empty_cache: one-line wrappers around
#           comfy.model_management; get_free_memory falls back to
#           get_device() when device is None.
#         - determine_precision: maps "fp32" / "fp16" / "bf16" to the torch
#           dtype; for "auto" returns float32 on a CPU device, bfloat16 when
#           CUDA is available and reports bf16 support, otherwise float16.
#         - determine_offload_strategy: returns the preference unchanged
#           unless it is "auto"; for "auto" returns "keep_in_vram" when free
#           memory is at least 24 * 1024**3 bytes, else "offload_to_cpu".
#         - try_import_flash_attn: returns the flash_attn module, or None
#           when importing it raises ImportError.
#         - resolve_hf_token: returns the HF_TOKEN environment variable when
#           it is set and non-empty, else None.
#
#   requirements.txt
#       Eight dependency lines (einops, safetensors, huggingface_hub,
#       transformers, k-diffusion, alias-free-torch, descript-audio-codec,
#       tqdm).
#
# NOTE(review): determine_precision indexes a dict with `preference`, so any
#   value other than "auto", "fp32", "fp16" or "bf16" raises KeyError --
#   confirm callers restrict the input to those choices.
# NOTE(review): in the "auto" branch the CUDA bf16 probe runs for every
#   non-CPU device type -- confirm that is the intended result for non-CUDA
#   accelerators.
# NOTE(review): nodes/__init__.py catches only ImportError and
#   AttributeError; any other exception raised while importing a node module
#   propagates out of the package import -- confirm this is intended.
# NOTE(review): nothing in this patch calls register_model_folder();
#   presumably a node module outside this patch does -- verify.
#
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..99a5a88
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,6 @@
+"""
+ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
+"""
+from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/nodes/__init__.py b/nodes/__init__.py
new file mode 100644
index 0000000..51182bc
--- /dev/null
+++ b/nodes/__init__.py
@@ -0,0 +1,19 @@
+NODE_CLASS_MAPPINGS = {}
+NODE_DISPLAY_NAME_MAPPINGS = {}
+
+_NODES = {
+    "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
+    "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
+    "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
+    "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
+    "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
+}
+
+for key, (module_path, class_name, display_name) in _NODES.items():
+    try:
+        import importlib
+        mod = importlib.import_module(module_path, package=__name__)
+        NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
+        NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
+    except (ImportError, AttributeError) as e:
+        print(f"[PrismAudio] Skipping {key}: {e}")
diff --git a/nodes/utils.py b/nodes/utils.py
new file mode 100644
index 0000000..e016ad6
--- /dev/null
+++ b/nodes/utils.py
@@ -0,0 +1,64 @@
+import os
+import torch
+import folder_paths
+import comfy.model_management as mm
+
+PRISMAUDIO_CATEGORY = "PrismAudio"
+SAMPLE_RATE = 44100
+DOWNSAMPLING_RATIO = 2048
+IO_CHANNELS = 64
+
+def get_prismaudio_model_dir():
+    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
+    os.makedirs(model_dir, exist_ok=True)
+    return model_dir
+
+def register_model_folder():
+    model_dir = get_prismaudio_model_dir()
+    folder_paths.add_model_folder_path("prismaudio", model_dir)
+
+def get_device():
+    return mm.get_torch_device()
+
+def get_offload_device():
+    return mm.unet_offload_device()
+
+def get_free_memory(device=None):
+    if device is None:
+        device = get_device()
+    return mm.get_free_memory(device)
+
+def soft_empty_cache():
+    mm.soft_empty_cache()
+
+def determine_precision(preference, device):
+    if preference != "auto":
+        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
+    if device.type == "cpu":
+        return torch.float32
+    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+        return torch.bfloat16
+    return torch.float16
+
+def determine_offload_strategy(preference):
+    if preference != "auto":
+        return preference
+    free_mem = get_free_memory()
+    gb = free_mem / (1024 ** 3)
+    if gb >= 24:
+        return "keep_in_vram"
+    else:
+        return "offload_to_cpu"
+
+def try_import_flash_attn():
+    try:
+        import flash_attn
+        return flash_attn
+    except ImportError:
+        return None
+
+def resolve_hf_token():
+    env_token = os.environ.get("HF_TOKEN")
+    if env_token:
+        return env_token
+    return None
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a995cb3
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+einops>=0.7.0
+safetensors
+huggingface_hub
+transformers>=4.52.3
+k-diffusion>=0.1.1
+alias-free-torch
+descript-audio-codec
+tqdm