diff --git a/video_to_video/modules/embedder.py b/video_to_video/modules/embedder.py
index 9b2e760..29cc0fd 100644
--- a/video_to_video/modules/embedder.py
+++ b/video_to_video/modules/embedder.py
@@ -54,9 +54,17 @@ class FrozenOpenCLIPEmbedder(nn.Module):
     def encode_with_transformer(self, text):
         x = self.model.token_embedding(text)
         x = x + self.model.positional_embedding
-        x = x.permute(1, 0, 2)
+        # Newer open_clip sets batch_first=True on MHA, so the resblocks
+        # expect [batch, seq, embed]. Older versions use batch_first=False
+        # and expect [seq, batch, embed]. Only permute for the old layout.
+        needs_permute = not getattr(
+            self.model.transformer.resblocks[0].attn, "batch_first", False
+        )
+        if needs_permute:
+            x = x.permute(1, 0, 2)
         x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
-        x = x.permute(1, 0, 2)
+        if needs_permute:
+            x = x.permute(1, 0, 2)
         x = self.model.ln_final(x)
         return x
 