Move VAE encode outside autocast to match original STAR pipeline

The original STAR code runs vae_encode() before the amp.autocast() block.
Our code had it inside, which changes how the encoder processes tensors
and can produce different latent representations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 02:14:25 +01:00
parent 0537d9d8a5
commit e2025c6ca0

View File

@@ -302,18 +302,18 @@ def run_star_inference(
     text_encoder.device = "cpu"
     torch.cuda.empty_cache()
-    # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) --
-    with torch.amp.autocast("cuda"):
-        # ---- Stage 2: VAE encode ----
-        if offload != "disabled":
-            _move(vae, device)
-        video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
-        if offload != "disabled":
-            _move(vae, "cpu")
-        # Free the full-res pixel tensor — only latents needed from here.
-        del video_data
-        torch.cuda.empty_cache()
+    # ---- Stage 2: VAE encode (outside autocast, matches original STAR) ----
+    if offload != "disabled":
+        _move(vae, device)
+    video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
+    if offload != "disabled":
+        _move(vae, "cpu")
+    # Free the full-res pixel tensor — only latents needed from here.
+    del video_data
+    torch.cuda.empty_cache()
+
+    # -- Diffusion sampling + VAE decode (under autocast) --
+    with torch.amp.autocast("cuda"):
         t = torch.LongTensor([total_noise_levels - 1]).to(device)
         noised_lr = diffusion.diffuse(video_data_feature, t)