feat: PrismAudioFeatureLoader node for pre-computed .npz files
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
||||||
|
# global_video_features and global_text_features are NOT consumed by any conditioner
|
||||||
|
# in the prismaudio.json config — they are unused.
|
||||||
|
REQUIRED_KEYS = [
|
||||||
|
"video_features",
|
||||||
|
"text_features",
|
||||||
|
"sync_features",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioFeatureLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
|
||||||
|
RETURN_NAMES = ("features",)
|
||||||
|
FUNCTION = "load_features"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_features(self, npz_path):
|
||||||
|
if not os.path.exists(npz_path):
|
||||||
|
raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")
|
||||||
|
|
||||||
|
data = np.load(npz_path, allow_pickle=True)
|
||||||
|
|
||||||
|
features = {}
|
||||||
|
for key in REQUIRED_KEYS:
|
||||||
|
if key in data:
|
||||||
|
features[key] = torch.from_numpy(data[key]).float()
|
||||||
|
else:
|
||||||
|
print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
|
||||||
|
# Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
|
||||||
|
# Sync_MLP requires length divisible by 8 (segments of 8 frames)
|
||||||
|
if key == "sync_features":
|
||||||
|
features[key] = torch.zeros(8, 768)
|
||||||
|
else:
|
||||||
|
features[key] = torch.zeros(1, 1024)
|
||||||
|
|
||||||
|
# Load duration if present
|
||||||
|
if "duration" in data:
|
||||||
|
features["duration"] = float(data["duration"])
|
||||||
|
|
||||||
|
return (features,)
|
||||||
Reference in New Issue
Block a user