fix: interpolate sync_cond to match audio sequence length in transformer

Sync_MLP interpolates sync features based on the video duration, but the
audio latent length depends on the user-set audio duration. When the video
and audio durations differ, the sequence lengths of sync_cond and x no
longer match. Linearly resample sync_cond to x's sequence length before the
sync conditioning is applied (FiLM scale/shift or gated addition) so any
video/audio duration combination works.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 21:21:39 +01:00
parent 934a401633
commit f99d2666e8
+7 -1
View File
@@ -919,12 +919,18 @@ class ContinuousTransformer(nn.Module):
x = self.fusion_mlp(x) x = self.fusion_mlp(x)
if sync_cond is not None: if sync_cond is not None:
# Resample sync_cond to match audio sequence length if needed
if sync_cond.shape[1] != x.shape[1]:
sync_cond = torch.nn.functional.interpolate(
sync_cond.transpose(1, 2), size=x.shape[1],
mode='linear', align_corners=False,
).transpose(1, 2)
if self.sync_film_generator is not None: if self.sync_film_generator is not None:
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
elif self.sync_gate is not None: elif self.sync_gate is not None:
gate_value = torch.sigmoid(self.sync_gate) gate_value = torch.sigmoid(self.sync_gate)
x = x + gate_value * sync_cond x = x + gate_value * sync_cond
# else: # else:
# x = x + sync_cond # x = x + sync_cond