Files
ComfyUI-SelVA/nodes/feature_loader.py
T
2026-03-27 18:04:32 +01:00

54 lines
1.7 KiB
Python

import os
import numpy as np
import torch
from .utils import PRISMAUDIO_CATEGORY
# Feature arrays read by the conditioners: video_features, text_features and
# sync_features. global_video_features and global_text_features are NOT
# consumed by any conditioner in the prismaudio.json config, so they are
# left out of this list.
REQUIRED_KEYS = ["video_features", "text_features", "sync_features"]
class PrismAudioFeatureLoader:
    """ComfyUI node that loads pre-computed PrismAudio conditioning features.

    Reads the arrays named in ``REQUIRED_KEYS`` from an ``.npz`` file and
    returns them as torch tensors in a single dict, together with an optional
    ``duration`` value.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: the feature file path."""
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        """Load conditioning features from ``npz_path``.

        Args:
            npz_path: Path to the pre-computed ``.npz`` feature file.

        Returns:
            A one-element tuple holding a dict with one tensor per entry of
            ``REQUIRED_KEYS`` (converted with ``.float()`` when read from the
            file, zero-filled when the key is missing) and, if the file
            stores it, ``"duration"`` as a Python float.

        Raises:
            FileNotFoundError: If ``npz_path`` is not an existing regular file.
        """
        # isfile (rather than exists) so that a directory path is rejected
        # here with a clear message instead of failing inside np.load.
        if not os.path.isfile(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load feature files
        # from trusted sources; switch to allow_pickle=False if no existing
        # feature file depends on pickled objects.
        data = np.load(npz_path, allow_pickle=True)
        try:
            features = {}
            for key in REQUIRED_KEYS:
                if key in data:
                    features[key] = torch.from_numpy(data[key]).float()
                    continue

                print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                # Provide a zero tensor rather than None — Cond_MLP/Sync_MLP
                # crash on None. Sync_MLP requires a length divisible by 8
                # (segments of 8 frames).
                if key == "sync_features":
                    features[key] = torch.zeros(8, 768)
                else:
                    features[key] = torch.zeros(1, 1024)

            # Load duration if present.
            if "duration" in data:
                features["duration"] = float(data["duration"])
        finally:
            # An .npz archive keeps its underlying file open until it is
            # closed, so release it explicitly. The arrays read above are
            # already in memory. A non-archive result has no close().
            close = getattr(data, "close", None)
            if close is not None:
                close()

        return (features,)