From f99d2666e80bbcf69521509eda959739b8ed41e8 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 27 Mar 2026 21:21:39 +0100 Subject: [PATCH] fix: interpolate sync_cond to match audio sequence length in transformer Sync_MLP interpolates sync features based on video duration, but audio latent length depends on the user-set audio duration. When video != audio duration, the sequences diverge. Resample sync_cond to x's length before the FiLM modulation or gated addition so any video/audio duration combo works. Co-Authored-By: Claude Sonnet 4.6 --- prismaudio_core/models/transformer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py index 2ee75c4..16e465a 100644 --- a/prismaudio_core/models/transformer.py +++ b/prismaudio_core/models/transformer.py @@ -919,12 +919,18 @@ class ContinuousTransformer(nn.Module): x = self.fusion_mlp(x) if sync_cond is not None: + # Resample sync_cond to match audio sequence length if needed + if sync_cond.shape[1] != x.shape[1]: + sync_cond = torch.nn.functional.interpolate( + sync_cond.transpose(1, 2), size=x.shape[1], + mode='linear', align_corners=False, + ).transpose(1, 2) if self.sync_film_generator is not None: scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) x = x * (1 + scale) + shift elif self.sync_gate is not None: gate_value = torch.sigmoid(self.sync_gate) - x = x + gate_value * sync_cond + x = x + gate_value * sync_cond # else: # x = x + sync_cond