diff --git a/star_pipeline.py b/star_pipeline.py index 3ce44e3..0cfb2eb 100644 --- a/star_pipeline.py +++ b/star_pipeline.py @@ -302,18 +302,18 @@ def run_star_inference( text_encoder.device = "cpu" torch.cuda.empty_cache() - # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) -- - with torch.amp.autocast("cuda"): - # ---- Stage 2: VAE encode ---- - if offload != "disabled": - _move(vae, device) - video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk) - if offload != "disabled": - _move(vae, "cpu") - # Free the full-res pixel tensor — only latents needed from here. - del video_data - torch.cuda.empty_cache() + # ---- Stage 2: VAE encode (outside autocast, matches original STAR) ---- + if offload != "disabled": + _move(vae, device) + video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk) + if offload != "disabled": + _move(vae, "cpu") + # Free the full-res pixel tensor — only latents needed from here. + del video_data + torch.cuda.empty_cache() + # -- Diffusion sampling + VAE decode (under autocast) -- + with torch.amp.autocast("cuda"): t = torch.LongTensor([total_noise_levels - 1]).to(device) noised_lr = diffusion.diffuse(video_data_feature, t)