Move VAE encode outside autocast to match original STAR pipeline
The original STAR code runs vae_encode() before the amp.autocast() block. Our code had it inside the block, which changes how the encoder processes tensors and can produce different latent representations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,9 +302,7 @@ def run_star_inference(
     text_encoder.device = "cpu"
     torch.cuda.empty_cache()

-    # -- Diffusion sampling (autocast needed for fp16 VAE / UNet) --
-    with torch.amp.autocast("cuda"):
-        # ---- Stage 2: VAE encode ----
+    # ---- Stage 2: VAE encode (outside autocast, matches original STAR) ----
     if offload != "disabled":
         _move(vae, device)
     video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
@@ -314,6 +312,8 @@ def run_star_inference(
     del video_data
     torch.cuda.empty_cache()

+    # -- Diffusion sampling + VAE decode (under autocast) --
+    with torch.amp.autocast("cuda"):
     t = torch.LongTensor([total_noise_levels - 1]).to(device)
     noised_lr = diffusion.diffuse(video_data_feature, t)
Reference in New Issue
Block a user