feat: add debug_zero_video/sync toggles and feature stats logging to sampler

Allows isolating which feature set causes quality issues:
- debug_zero_video: replace video_features with zeros and force video_exist to False → text+sync only
- debug_zero_sync: replace sync_features with a fixed-shape (8, 768) zero tensor → text+video only
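Also logs shape/mean/std for all three feature tensors (video, sync, text) on every run; the stats are computed on the original input features, before any debug zeroing is applied.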

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:40:34 +01:00
parent 140cc5ee9a
commit 83a7f2787b
+25 -3
View File
@@ -20,6 +20,10 @@ class PrismAudioSampler:
"cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
}, },
"optional": {
"debug_zero_video": ("BOOLEAN", {"default": False, "tooltip": "Zero out video_features (keep text+sync) — isolates video feature issues"}),
"debug_zero_sync": ("BOOLEAN", {"default": False, "tooltip": "Zero out sync_features (keep text+video) — isolates sync feature issues"}),
},
} }
RETURN_TYPES = ("AUDIO",) RETURN_TYPES = ("AUDIO",)
@@ -27,7 +31,7 @@ class PrismAudioSampler:
FUNCTION = "generate" FUNCTION = "generate"
CATEGORY = PRISMAUDIO_CATEGORY CATEGORY = PRISMAUDIO_CATEGORY
def generate(self, model, features, duration, steps, cfg_scale, seed): def generate(self, model, features, duration, steps, cfg_scale, seed, debug_zero_video=False, debug_zero_sync=False):
device = get_device() device = get_device()
dtype = model["dtype"] dtype = model["dtype"]
strategy = model["strategy"] strategy = model["strategy"]
@@ -42,12 +46,30 @@ class PrismAudioSampler:
# Determine if video features are present (not all zeros) # Determine if video features are present (not all zeros)
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0 has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
video_feat = features["video_features"].to(device, dtype=dtype)
sync_feat = features["sync_features"].to(device, dtype=dtype)
if debug_zero_video:
print("[PrismAudio] DEBUG: zeroing video_features", flush=True)
video_feat = torch.zeros_like(video_feat)
has_video = False
if debug_zero_sync:
print("[PrismAudio] DEBUG: zeroing sync_features", flush=True)
sync_feat = torch.zeros(8, 768, device=device, dtype=dtype)
vf_stats = features["video_features"]
sf_stats = features["sync_features"]
tf_stats = features["text_features"]
print(f"[PrismAudio] feature stats — video: shape={tuple(vf_stats.shape)} mean={vf_stats.float().mean():.3f} std={vf_stats.float().std():.3f}", flush=True)
print(f"[PrismAudio] feature stats — sync: shape={tuple(sf_stats.shape)} mean={sf_stats.float().mean():.3f} std={sf_stats.float().std():.3f}", flush=True)
print(f"[PrismAudio] feature stats — text: shape={tuple(tf_stats.shape)} mean={tf_stats.float().mean():.3f} std={tf_stats.float().std():.3f}", flush=True)
# Build metadata as a TUPLE of dicts (one per batch sample) # Build metadata as a TUPLE of dicts (one per batch sample)
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
sample_meta = { sample_meta = {
"video_features": features["video_features"].to(device, dtype=dtype), "video_features": video_feat,
"text_features": features["text_features"].to(device, dtype=dtype), "text_features": features["text_features"].to(device, dtype=dtype),
"sync_features": features["sync_features"].to(device, dtype=dtype), "sync_features": sync_feat,
"video_exist": torch.tensor(has_video), "video_exist": torch.tensor(has_video),
} }
metadata = (sample_meta,) metadata = (sample_meta,)