feat: add debug_zero_video/sync toggles and feature stats logging to sampler
Allows isolating which feature set causes quality issues: - debug_zero_video: zero video_features → text+sync only - debug_zero_sync: zero sync_features → text+video only Also logs mean/std/shape for all three feature tensors on every run. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+25
-3
@@ -20,6 +20,10 @@ class PrismAudioSampler:
|
|||||||
"cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
"cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
},
|
},
|
||||||
|
"optional": {
|
||||||
|
"debug_zero_video": ("BOOLEAN", {"default": False, "tooltip": "Zero out video_features (keep text+sync) — isolates video feature issues"}),
|
||||||
|
"debug_zero_sync": ("BOOLEAN", {"default": False, "tooltip": "Zero out sync_features (keep text+video) — isolates sync feature issues"}),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
RETURN_TYPES = ("AUDIO",)
|
||||||
@@ -27,7 +31,7 @@ class PrismAudioSampler:
|
|||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
def generate(self, model, features, duration, steps, cfg_scale, seed):
|
def generate(self, model, features, duration, steps, cfg_scale, seed, debug_zero_video=False, debug_zero_sync=False):
|
||||||
device = get_device()
|
device = get_device()
|
||||||
dtype = model["dtype"]
|
dtype = model["dtype"]
|
||||||
strategy = model["strategy"]
|
strategy = model["strategy"]
|
||||||
@@ -42,12 +46,30 @@ class PrismAudioSampler:
|
|||||||
# Determine if video features are present (not all zeros)
|
# Determine if video features are present (not all zeros)
|
||||||
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
||||||
|
|
||||||
|
video_feat = features["video_features"].to(device, dtype=dtype)
|
||||||
|
sync_feat = features["sync_features"].to(device, dtype=dtype)
|
||||||
|
|
||||||
|
if debug_zero_video:
|
||||||
|
print("[PrismAudio] DEBUG: zeroing video_features", flush=True)
|
||||||
|
video_feat = torch.zeros_like(video_feat)
|
||||||
|
has_video = False
|
||||||
|
if debug_zero_sync:
|
||||||
|
print("[PrismAudio] DEBUG: zeroing sync_features", flush=True)
|
||||||
|
sync_feat = torch.zeros(8, 768, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
vf_stats = features["video_features"]
|
||||||
|
sf_stats = features["sync_features"]
|
||||||
|
tf_stats = features["text_features"]
|
||||||
|
print(f"[PrismAudio] feature stats — video: shape={tuple(vf_stats.shape)} mean={vf_stats.float().mean():.3f} std={vf_stats.float().std():.3f}", flush=True)
|
||||||
|
print(f"[PrismAudio] feature stats — sync: shape={tuple(sf_stats.shape)} mean={sf_stats.float().mean():.3f} std={sf_stats.float().std():.3f}", flush=True)
|
||||||
|
print(f"[PrismAudio] feature stats — text: shape={tuple(tf_stats.shape)} mean={tf_stats.float().mean():.3f} std={tf_stats.float().std():.3f}", flush=True)
|
||||||
|
|
||||||
# Build metadata as a TUPLE of dicts (one per batch sample)
|
# Build metadata as a TUPLE of dicts (one per batch sample)
|
||||||
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
||||||
sample_meta = {
|
sample_meta = {
|
||||||
"video_features": features["video_features"].to(device, dtype=dtype),
|
"video_features": video_feat,
|
||||||
"text_features": features["text_features"].to(device, dtype=dtype),
|
"text_features": features["text_features"].to(device, dtype=dtype),
|
||||||
"sync_features": features["sync_features"].to(device, dtype=dtype),
|
"sync_features": sync_feat,
|
||||||
"video_exist": torch.tensor(has_video),
|
"video_exist": torch.tensor(has_video),
|
||||||
}
|
}
|
||||||
metadata = (sample_meta,)
|
metadata = (sample_meta,)
|
||||||
|
|||||||
Reference in New Issue
Block a user