diff --git a/nodes/sampler.py b/nodes/sampler.py index 1aaf90f..fbafdd2 100644 --- a/nodes/sampler.py +++ b/nodes/sampler.py @@ -18,6 +18,7 @@ class PrismAudioSampler: "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}), "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}), "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), + "sync_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.05, "tooltip": "Scale factor for sync conditioning. Higher values tighten audio-visual sync at the cost of audio naturalness; 0.0 disables sync guidance entirely."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), }, } @@ -27,7 +28,7 @@ class PrismAudioSampler: FUNCTION = "generate" CATEGORY = PRISMAUDIO_CATEGORY - def generate(self, model, features, duration, steps, cfg_scale, seed): + def generate(self, model, features, duration, steps, cfg_scale, sync_strength, seed): device = get_device() dtype = model["dtype"] strategy = model["strategy"] @@ -43,6 +44,16 @@ class PrismAudioSampler: # Compute latent dimensions latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) + # Sync temporal coverage diagnostic + sync_frames = features["sync_features"].shape[0] + sync_duration_covered = sync_frames / 25.0 # Synchformer always extracts at 25fps + print(f"[PrismAudio] sync: {sync_frames} frames @ 25fps = {sync_duration_covered:.2f}s | " + f"audio target: {latent_length} latent frames = {duration:.2f}s", flush=True) + if abs(sync_duration_covered - duration) > 0.5: + print(f"[PrismAudio] Warning: sync coverage ({sync_duration_covered:.2f}s) differs from " + f"audio duration ({duration:.2f}s) by more than 0.5s — consider re-extracting features " + f"with the correct video duration.", flush=True) + # Note: no seq length config needed — the model adapts to input tensor shapes # dynamically via its transformer architecture. @@ -76,6 +87,13 @@ class PrismAudioSampler: if not has_video: _substitute_empty_features(diffusion, conditioning, device, dtype) + # Scale sync conditioning after the conditioner MLP (clean linear scale, + # avoids SiLU nonlinearity in Sync_MLP). The CFG null path always uses zeros, + # so this directly scales the sync guidance magnitude: cfg_scale * (strength*cond - 0). + # Only applied when video is present — T2A uses learned empty_sync_feat, not raw sync. + if has_video and sync_strength != 1.0 and 'sync_features' in conditioning: + conditioning['sync_features'][0] = conditioning['sync_features'][0] * sync_strength + # Assemble conditioning inputs for the DiT cond_inputs = diffusion.get_conditioning_inputs(conditioning)