diff --git a/nodes/sampler.py b/nodes/sampler.py
index 2991253..1d89b84 100644
--- a/nodes/sampler.py
+++ b/nodes/sampler.py
@@ -20,6 +20,10 @@ class PrismAudioSampler:
                 "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
             },
+            "optional": {
+                "debug_zero_video": ("BOOLEAN", {"default": False, "tooltip": "Zero out video_features (keep text+sync) — isolates video feature issues"}),
+                "debug_zero_sync": ("BOOLEAN", {"default": False, "tooltip": "Zero out sync_features (keep text+video) — isolates sync feature issues"}),
+            },
         }
 
     RETURN_TYPES = ("AUDIO",)
@@ -27,7 +31,7 @@ class PrismAudioSampler:
     FUNCTION = "generate"
     CATEGORY = PRISMAUDIO_CATEGORY
 
-    def generate(self, model, features, duration, steps, cfg_scale, seed):
+    def generate(self, model, features, duration, steps, cfg_scale, seed, debug_zero_video=False, debug_zero_sync=False):
         device = get_device()
         dtype = model["dtype"]
         strategy = model["strategy"]
@@ -42,12 +46,31 @@ class PrismAudioSampler:
         # Determine if video features are present (not all zeros)
         has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
 
+        video_feat = features["video_features"].to(device, dtype=dtype)
+        sync_feat = features["sync_features"].to(device, dtype=dtype)
+
+        if debug_zero_video:
+            print("[PrismAudio] DEBUG: zeroing video_features", flush=True)
+            video_feat = torch.zeros_like(video_feat)
+            has_video = False
+        if debug_zero_sync:
+            print("[PrismAudio] DEBUG: zeroing sync_features", flush=True)
+            # zeros_like keeps the incoming shape/device/dtype instead of assuming a fixed (8, 768) layout
+            sync_feat = torch.zeros_like(sync_feat)
+
+        vf_stats = features["video_features"]
+        sf_stats = features["sync_features"]
+        tf_stats = features["text_features"]
+        print(f"[PrismAudio] feature stats — video: shape={tuple(vf_stats.shape)} mean={vf_stats.float().mean():.3f} std={vf_stats.float().std():.3f}", flush=True)
+        print(f"[PrismAudio] feature stats — sync: shape={tuple(sf_stats.shape)} mean={sf_stats.float().mean():.3f} std={sf_stats.float().std():.3f}", flush=True)
+        print(f"[PrismAudio] feature stats — text: shape={tuple(tf_stats.shape)} mean={tf_stats.float().mean():.3f} std={tf_stats.float().std():.3f}", flush=True)
+
         # Build metadata as a TUPLE of dicts (one per batch sample)
         # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
         sample_meta = {
-            "video_features": features["video_features"].to(device, dtype=dtype),
-            "text_features": features["text_features"].to(device, dtype=dtype),
-            "sync_features": features["sync_features"].to(device, dtype=dtype),
+            "video_features": video_feat,
+            "text_features": features["text_features"].to(device, dtype=dtype),
+            "sync_features": sync_feat,
             "video_exist": torch.tensor(has_video),
         }
         metadata = (sample_meta,)