fix: weights_only=False for SelVA checkpoints (PyTorch 2.6 compat)
PyTorch 2.6 changed the default to weights_only=True. SelVA checkpoints contain non-tensor types (numpy scalars etc.) that fail strict unpickling. All weights come from trusted sources (jnwnlee/selva HF repo). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,7 +19,7 @@ class AutoEncoderModule(nn.Module):
|
||||
need_vae_encoder: bool = True):
|
||||
super().__init__()
|
||||
self.vae: VAE = get_my_vae(mode).eval()
|
||||
vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
|
||||
vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
|
||||
self.vae.load_state_dict(vae_state_dict)
|
||||
self.vae.remove_weight_norm()
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class BigVGAN(nn.Module):
|
||||
super().__init__()
|
||||
vocoder_cfg = OmegaConf.load(config_path)
|
||||
self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
|
||||
vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
|
||||
vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)['generator']
|
||||
self.vocoder.load_state_dict(vocoder_ckpt)
|
||||
|
||||
self.weight_norm_removed = False
|
||||
|
||||
@@ -57,7 +57,7 @@ class FeaturesUtils(nn.Module):
|
||||
|
||||
self.synchformer = Synchformer(video=True, audio=False)
|
||||
self.synchformer.load_state_dict(
|
||||
torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
|
||||
torch.load(synchformer_ckpt, weights_only=False, map_location='cpu'))
|
||||
|
||||
self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base')
|
||||
self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base')
|
||||
|
||||
Reference in New Issue
Block a user