Clean up debug logging and fix precision setting for autocast
Remove all [STAR DEBUG] print statements added during quality investigation. Fix autocast to actually use the selected precision dtype (fp16/bf16) instead of always defaulting to fp16. fp32 now properly disables autocast for full-precision inference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -271,14 +271,8 @@ def run_star_inference(
|
||||
|
||||
total_noise_levels = int(round(denoise * 1000))
|
||||
|
||||
def _dbg(name, t):
|
||||
print(f"[STAR DEBUG] {name}: shape={list(t.shape)} dtype={t.dtype} "
|
||||
f"min={t.min().item():.4f} max={t.max().item():.4f} "
|
||||
f"mean={t.float().mean().item():.4f} std={t.float().std().item():.4f}")
|
||||
|
||||
# -- Convert ComfyUI frames to STAR format --
|
||||
video_data = comfyui_to_star_frames(images) # [F, 3, H, W]
|
||||
_dbg("input_frames", video_data)
|
||||
|
||||
# Keep a copy at input resolution (on CPU) for colour correction later
|
||||
input_frames_star = video_data.clone().cpu()
|
||||
@@ -286,8 +280,6 @@ def run_star_inference(
|
||||
frames_num, _, orig_h, orig_w = video_data.shape
|
||||
target_h = orig_h * upscale
|
||||
target_w = orig_w * upscale
|
||||
print(f"[STAR DEBUG] orig={orig_h}x{orig_w} target={target_h}x{target_w} "
|
||||
f"frames={frames_num} noise_levels={total_noise_levels}")
|
||||
|
||||
# -- Bilinear upscale to target resolution --
|
||||
video_data = F.interpolate(video_data, size=(target_h, target_w), mode="bilinear", align_corners=False)
|
||||
@@ -296,7 +288,6 @@ def run_star_inference(
|
||||
# -- Pad to model-friendly resolution --
|
||||
padding = pad_to_fit(h, w)
|
||||
video_data = F.pad(video_data, padding, "constant", 1)
|
||||
print(f"[STAR DEBUG] padded={video_data.shape[2]}x{video_data.shape[3]} padding={padding}")
|
||||
|
||||
video_data = video_data.unsqueeze(0).to(device) # [1, F, 3, H_pad, W_pad]
|
||||
|
||||
@@ -305,9 +296,7 @@ def run_star_inference(
|
||||
text_encoder.model.to(device)
|
||||
text_encoder.device = device
|
||||
text = ((prompt.strip() + " ") if prompt.strip() else "") + cfg.positive_prompt
|
||||
print(f"[STAR DEBUG] prompt: {text[:80]}...")
|
||||
y = text_encoder(text).detach()
|
||||
_dbg("text_embedding", y)
|
||||
if offload != "disabled":
|
||||
text_encoder.model.to("cpu")
|
||||
text_encoder.device = "cpu"
|
||||
@@ -317,7 +306,6 @@ def run_star_inference(
|
||||
if offload != "disabled":
|
||||
_move(vae, device)
|
||||
video_data_feature = vae_encode(vae, video_data, chunk_size=vae_enc_chunk)
|
||||
_dbg("vae_latent", video_data_feature)
|
||||
if offload != "disabled":
|
||||
_move(vae, "cpu")
|
||||
# Free the full-res pixel tensor — only latents needed from here.
|
||||
@@ -325,10 +313,11 @@ def run_star_inference(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# -- Diffusion sampling + VAE decode (under autocast) --
|
||||
with torch.amp.autocast("cuda"):
|
||||
# Use the selected precision for autocast (fp16/bf16), or disable for fp32.
|
||||
use_autocast = dtype != torch.float32
|
||||
with torch.amp.autocast("cuda", dtype=dtype, enabled=use_autocast):
|
||||
t = torch.LongTensor([total_noise_levels - 1]).to(device)
|
||||
noised_lr = diffusion.diffuse(video_data_feature, t)
|
||||
_dbg("noised_lr", noised_lr)
|
||||
|
||||
model_kwargs = [{"y": y}, {"y": negative_y}, {"hint": video_data_feature}]
|
||||
|
||||
@@ -380,7 +369,6 @@ def run_star_inference(
|
||||
)
|
||||
finally:
|
||||
_solvers_mod.trange = _orig_trange
|
||||
_dbg("diffusion_output", gen_vid)
|
||||
if offload != "disabled":
|
||||
_move(generator, "cpu")
|
||||
|
||||
@@ -392,7 +380,6 @@ def run_star_inference(
|
||||
if offload != "disabled":
|
||||
_move(vae, device)
|
||||
vid_tensor_gen = vae_decode_chunk(vae, gen_vid, chunk_size=vae_dec_chunk)
|
||||
_dbg("vae_decoded", vid_tensor_gen)
|
||||
if offload != "disabled":
|
||||
_move(vae, "cpu")
|
||||
|
||||
@@ -403,7 +390,6 @@ def run_star_inference(
|
||||
# -- Reshape to [B, C, F, H, W] then convert to ComfyUI format --
|
||||
gen_video = rearrange(vid_tensor_gen, "(b f) c h w -> b c f h w", b=1)
|
||||
gen_video = gen_video.float().cpu()
|
||||
_dbg("final_output", gen_video)
|
||||
|
||||
result = star_output_to_comfyui(gen_video) # [F, H, W, 3]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user