diff --git a/nodes/sampler.py b/nodes/sampler.py index 1d89b84..58dd712 100644 --- a/nodes/sampler.py +++ b/nodes/sampler.py @@ -144,22 +144,28 @@ class PrismAudioSampler: def _substitute_empty_features(diffusion, conditioning, device, dtype): - """Replace sync conditioning with learned empty embedding when video is absent. + """Replace video/sync conditioning with learned empty embeddings when video is absent. - Only substitutes sync_features — NOT video_features. The reference code - (predict.py/app.py) checks for 'metaclip_features' which doesn't exist in the - prismaudio.json config, so video substitution never runs. Cond_MLP with zero - input + bias-free linear layers naturally produces near-zero output. + empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner + output space (1024-dim). Passing zero features through bias-free Cond_MLP produces + near-zero activations, NOT the learned null signal the model was trained with. The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim]. """ dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model - # Only substitute sync_features (matching reference behavior for prismaudio config) + # Substitute video_features with learned empty_clip_feat + if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning: + empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024] + batch_size = conditioning['video_features'][0].shape[0] + empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024] + conditioning['video_features'][0] = empty_expanded + conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device) + + # Substitute sync_features with learned empty_sync_feat if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning: - empty = dit.empty_sync_feat.to(device, dtype=dtype) - cond_tensor = conditioning['sync_features'][0] - batch_size = cond_tensor.shape[0] - empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) + empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024] + batch_size = conditioning['sync_features'][0].shape[0] + empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024] conditioning['sync_features'][0] = empty_expanded conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device) diff --git a/nodes/text_only.py b/nodes/text_only.py index f4434a8..336c3b8 100644 --- a/nodes/text_only.py +++ b/nodes/text_only.py @@ -15,7 +15,7 @@ class PrismAudioTextOnly: return { "required": { "model": ("PRISMAUDIO_MODEL",), - "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}), + "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}), "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), "steps": ("INT", {"default": 24, "min": 1, "max": 100}), "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}),