fix: weights_only=False for SelVA checkpoints (PyTorch 2.6 compat)

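PyTorch 2.6 changed the default of torch.load to weights_only=True. SelVA
checkpoints contain non-tensor types (numpy scalars, etc.) that the restricted
weights-only unpickler rejects, so they are now loaded with weights_only=False.
Full unpickling can execute arbitrary code embedded in a checkpoint; this is
acceptable here only because all weights come from a trusted source (the
jnwnlee/selva HF repo).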

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 16:38:31 +02:00
parent 40388ba6de
commit 614a2e02aa
4 changed files with 5 additions and 5 deletions
+2 -2
View File
@@ -74,14 +74,14 @@ class SelvaModelLoader:
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True) print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval() net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
net_video_enc.load_weights( net_video_enc.load_weights(
torch.load(video_enc_path, map_location="cpu", weights_only=True) torch.load(video_enc_path, map_location="cpu", weights_only=False)
) )
print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True) print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
net_generator = get_my_mmaudio(variant).to(device, dtype).eval() net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
net_generator.load_weights( net_generator.load_weights(
torch.load(gen_path, map_location="cpu", weights_only=True) torch.load(gen_path, map_location="cpu", weights_only=False)
) )
print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True) print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
+1 -1
View File
@@ -19,7 +19,7 @@ class AutoEncoderModule(nn.Module):
need_vae_encoder: bool = True): need_vae_encoder: bool = True):
super().__init__() super().__init__()
self.vae: VAE = get_my_vae(mode).eval() self.vae: VAE = get_my_vae(mode).eval()
vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu') vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
self.vae.load_state_dict(vae_state_dict) self.vae.load_state_dict(vae_state_dict)
self.vae.remove_weight_norm() self.vae.remove_weight_norm()
+1 -1
View File
@@ -15,7 +15,7 @@ class BigVGAN(nn.Module):
super().__init__() super().__init__()
vocoder_cfg = OmegaConf.load(config_path) vocoder_cfg = OmegaConf.load(config_path)
self.vocoder = BigVGANVocoder(vocoder_cfg).eval() self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator'] vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)['generator']
self.vocoder.load_state_dict(vocoder_ckpt) self.vocoder.load_state_dict(vocoder_ckpt)
self.weight_norm_removed = False self.weight_norm_removed = False
+1 -1
View File
@@ -57,7 +57,7 @@ class FeaturesUtils(nn.Module):
self.synchformer = Synchformer(video=True, audio=False) self.synchformer = Synchformer(video=True, audio=False)
self.synchformer.load_state_dict( self.synchformer.load_state_dict(
torch.load(synchformer_ckpt, weights_only=True, map_location='cpu')) torch.load(synchformer_ckpt, weights_only=False, map_location='cpu'))
self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base') self.text_encoder_t5 = T5EncoderModel.from_pretrained('google/flan-t5-base')
self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base') self.tokenizer_t5 = T5TokenizerFast.from_pretrained('google/flan-t5-base')