Move VAE encode outside autocast to match original STAR pipeline
The original STAR code runs vae_encode() before the amp.autocast() block. Our code had it inside, which changes how the encoder processes tensors and can produce different latent representations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,9 +302,7 @@ def run_star_inference(
     text_encoder.device = "cpu"
     torch.cuda.empty_cache()

-    # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) --
-    with torch.amp.autocast("cuda"):
-        # ---- Stage 2: VAE encode ----
-        if offload != "disabled":
-            _move(vae, device)
-        video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
+    # ---- Stage 2: VAE encode (outside autocast, matches original STAR) ----
+    if offload != "disabled":
+        _move(vae, device)
+    video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
@@ -314,6 +312,8 @@ def run_star_inference(
     del video_data
     torch.cuda.empty_cache()

-    t = torch.LongTensor([total_noise_levels - 1]).to(device)
-    noised_lr = diffusion.diffuse(video_data_feature, t)
+    # -- Diffusion sampling + VAE decode (under autocast) --
+    with torch.amp.autocast("cuda"):
+        t = torch.LongTensor([total_noise_levels - 1]).to(device)
+        noised_lr = diffusion.diffuse(video_data_feature, t)

Reference in New Issue
Block a user