feat: add debug_zero_video/sync toggles and feature stats logging to sampler
Allows isolating which feature set causes quality issues:
- debug_zero_video: zero video_features → text+sync only
- debug_zero_sync: zero sync_features → text+video only

Also logs mean/std/shape for all three feature tensors on every run. The stats are computed from the incoming features, i.e. before any debug zeroing is applied.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+25
-3
@@ -20,6 +20,10 @@ class PrismAudioSampler:
|
||||
"cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
"optional": {
|
||||
"debug_zero_video": ("BOOLEAN", {"default": False, "tooltip": "Zero out video_features (keep text+sync) — isolates video feature issues"}),
|
||||
"debug_zero_sync": ("BOOLEAN", {"default": False, "tooltip": "Zero out sync_features (keep text+video) — isolates sync feature issues"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
@@ -27,7 +31,7 @@ class PrismAudioSampler:
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def generate(self, model, features, duration, steps, cfg_scale, seed):
|
||||
def generate(self, model, features, duration, steps, cfg_scale, seed, debug_zero_video=False, debug_zero_sync=False):
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
strategy = model["strategy"]
|
||||
@@ -42,12 +46,30 @@ class PrismAudioSampler:
|
||||
# Determine if video features are present (not all zeros)
|
||||
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
||||
|
||||
video_feat = features["video_features"].to(device, dtype=dtype)
|
||||
sync_feat = features["sync_features"].to(device, dtype=dtype)
|
||||
|
||||
if debug_zero_video:
|
||||
print("[PrismAudio] DEBUG: zeroing video_features", flush=True)
|
||||
video_feat = torch.zeros_like(video_feat)
|
||||
has_video = False
|
||||
if debug_zero_sync:
|
||||
print("[PrismAudio] DEBUG: zeroing sync_features", flush=True)
|
||||
sync_feat = torch.zeros(8, 768, device=device, dtype=dtype)
|
||||
|
||||
vf_stats = features["video_features"]
|
||||
sf_stats = features["sync_features"]
|
||||
tf_stats = features["text_features"]
|
||||
print(f"[PrismAudio] feature stats — video: shape={tuple(vf_stats.shape)} mean={vf_stats.float().mean():.3f} std={vf_stats.float().std():.3f}", flush=True)
|
||||
print(f"[PrismAudio] feature stats — sync: shape={tuple(sf_stats.shape)} mean={sf_stats.float().mean():.3f} std={sf_stats.float().std():.3f}", flush=True)
|
||||
print(f"[PrismAudio] feature stats — text: shape={tuple(tf_stats.shape)} mean={tf_stats.float().mean():.3f} std={tf_stats.float().std():.3f}", flush=True)
|
||||
|
||||
# Build metadata as a TUPLE of dicts (one per batch sample)
|
||||
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
||||
sample_meta = {
|
||||
"video_features": features["video_features"].to(device, dtype=dtype),
|
||||
"video_features": video_feat,
|
||||
"text_features": features["text_features"].to(device, dtype=dtype),
|
||||
"sync_features": features["sync_features"].to(device, dtype=dtype),
|
||||
"sync_features": sync_feat,
|
||||
"video_exist": torch.tensor(has_video),
|
||||
}
|
||||
metadata = (sample_meta,)
|
||||
|
||||
Reference in New Issue
Block a user