diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py index 2ee75c4..16e465a 100644 --- a/prismaudio_core/models/transformer.py +++ b/prismaudio_core/models/transformer.py @@ -919,12 +919,18 @@ class ContinuousTransformer(nn.Module): x = self.fusion_mlp(x) if sync_cond is not None: + # Resample sync_cond to match audio sequence length if needed + if sync_cond.shape[1] != x.shape[1]: + sync_cond = torch.nn.functional.interpolate( + sync_cond.transpose(1, 2), size=x.shape[1], + mode='linear', align_corners=False, + ).transpose(1, 2) if self.sync_film_generator is not None: scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) x = x * (1 + scale) + shift elif self.sync_gate is not None: gate_value = torch.sigmoid(self.sync_gate) - x = x + gate_value * sync_cond + x = x + gate_value * sync_cond # else: # x = x + sync_cond