fix: interpolate sync_cond to match audio sequence length in transformer
Sync_MLP interpolates sync features based on the video duration, but the audio latent length depends on the user-set audio duration. When the video and audio durations differ, sync_cond and x end up with different sequence lengths. Linearly resample sync_cond to x's sequence length before the FiLM modulation or gated addition, so any combination of video and audio durations works.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -919,12 +919,18 @@ class ContinuousTransformer(nn.Module):
|
|||||||
x = self.fusion_mlp(x)
|
x = self.fusion_mlp(x)
|
||||||
|
|
||||||
if sync_cond is not None:
|
if sync_cond is not None:
|
||||||
|
# Resample sync_cond to match audio sequence length if needed
|
||||||
|
if sync_cond.shape[1] != x.shape[1]:
|
||||||
|
sync_cond = torch.nn.functional.interpolate(
|
||||||
|
sync_cond.transpose(1, 2), size=x.shape[1],
|
||||||
|
mode='linear', align_corners=False,
|
||||||
|
).transpose(1, 2)
|
||||||
if self.sync_film_generator is not None:
|
if self.sync_film_generator is not None:
|
||||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||||
x = x * (1 + scale) + shift
|
x = x * (1 + scale) + shift
|
||||||
elif self.sync_gate is not None:
|
elif self.sync_gate is not None:
|
||||||
gate_value = torch.sigmoid(self.sync_gate)
|
gate_value = torch.sigmoid(self.sync_gate)
|
||||||
x = x + gate_value * sync_cond
|
x = x + gate_value * sync_cond
|
||||||
# else:
|
# else:
|
||||||
# x = x + sync_cond
|
# x = x + sync_cond
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user