fix: guarantee offload cleanup on exception with try/finally
Both nodes moved models to GPU before work then back to CPU after. Any exception (OOM, cancellation, bad input) would skip the cleanup, leaving models on GPU permanently until ComfyUI restarts. Wrap the entire work block in try/finally so offload_to_cpu cleanup always runs regardless of how the node exits. Also removes the unused `mode` variable in SelvaSampler. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -158,50 +158,52 @@ class SelvaFeatureExtractor:
|
||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||
pbar = comfy.utils.ProgressBar(3)
|
||||
|
||||
with torch.no_grad():
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
if mask is not None:
|
||||
clip_frames = _apply_mask(clip_frames, mask)
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
if mask is not None:
|
||||
clip_frames = _apply_mask(clip_frames, mask)
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
|
||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||
if mask is not None:
|
||||
sync_frames = _apply_mask(sync_frames, mask)
|
||||
# Pad to minimum 16 frames (TextSynchformer segment size)
|
||||
if sync_frames.shape[0] < 16:
|
||||
pad = 16 - sync_frames.shape[0]
|
||||
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
||||
# Normalize [0,1] → [-1,1]
|
||||
mean = _SYNC_MEAN.to(sync_frames.device)
|
||||
std = _SYNC_STD.to(sync_frames.device)
|
||||
sync_frames = (sync_frames - mean) / std
|
||||
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||
if mask is not None:
|
||||
sync_frames = _apply_mask(sync_frames, mask)
|
||||
# Pad to minimum 16 frames (TextSynchformer segment size)
|
||||
if sync_frames.shape[0] < 16:
|
||||
pad = 16 - sync_frames.shape[0]
|
||||
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
||||
# Normalize [0,1] → [-1,1]
|
||||
mean = _SYNC_MEAN.to(sync_frames.device)
|
||||
std = _SYNC_STD.to(sync_frames.device)
|
||||
sync_frames = (sync_frames - mean) / std
|
||||
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
|
||||
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
||||
pbar.update(1)
|
||||
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
|
||||
sync_features = net_video_enc.encode_video_with_sync(
|
||||
sync_input, text_f=text_f, text_mask=text_mask
|
||||
) # [1, T_sync, 768]
|
||||
pbar.update(1)
|
||||
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
||||
pbar.update(1)
|
||||
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
|
||||
sync_features = net_video_enc.encode_video_with_sync(
|
||||
sync_input, text_f=text_f, text_mask=text_mask
|
||||
) # [1, T_sync, 768]
|
||||
pbar.update(1)
|
||||
|
||||
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
||||
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
||||
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
||||
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to(get_offload_device())
|
||||
net_video_enc.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
finally:
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to(get_offload_device())
|
||||
net_video_enc.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
np.savez(
|
||||
cached_path,
|
||||
|
||||
+59
-58
@@ -54,7 +54,6 @@ class SelvaSampler:
|
||||
strategy = model["strategy"]
|
||||
net_generator = model["generator"]
|
||||
feature_utils = model["feature_utils"]
|
||||
mode = model["mode"]
|
||||
|
||||
# Validate that features were extracted with the same model variant
|
||||
feat_variant = features.get("variant")
|
||||
@@ -88,76 +87,78 @@ class SelvaSampler:
|
||||
feature_utils.to(device)
|
||||
soft_empty_cache()
|
||||
|
||||
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
||||
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
||||
try:
|
||||
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
||||
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
||||
|
||||
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
||||
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
||||
|
||||
# Update model rotary position embeddings for actual feature shapes and duration.
|
||||
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
|
||||
net_generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=clip_f.shape[1],
|
||||
sync_seq_len=sync_f.shape[1],
|
||||
)
|
||||
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
|
||||
|
||||
with torch.no_grad():
|
||||
# Encode text conditioning
|
||||
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
||||
|
||||
# Encode negative prompt (or use empty conditions)
|
||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||
if negative_prompt.strip() else None
|
||||
|
||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||
empty_conditions = net_generator.get_empty_conditions(
|
||||
bs=1, negative_text_features=neg_text_clip
|
||||
# Update model rotary position embeddings for actual feature shapes and duration.
|
||||
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
|
||||
net_generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=clip_f.shape[1],
|
||||
sync_seq_len=sync_f.shape[1],
|
||||
)
|
||||
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
|
||||
|
||||
# Initial noise (MPS doesn't support torch.Generator on device)
|
||||
gen_device = "cpu" if device.type == "mps" else device
|
||||
rng = torch.Generator(device=gen_device).manual_seed(seed)
|
||||
x0 = torch.randn(
|
||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||
device=gen_device, dtype=dtype, generator=rng,
|
||||
).to(device)
|
||||
with torch.no_grad():
|
||||
# Encode text conditioning
|
||||
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
||||
|
||||
# Flow matching ODE (Euler)
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
# Encode negative prompt (or use empty conditions)
|
||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||
if negative_prompt.strip() else None
|
||||
|
||||
def ode_wrapper_tracked(t, x):
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
pbar.update(1)
|
||||
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||
empty_conditions = net_generator.get_empty_conditions(
|
||||
bs=1, negative_text_features=neg_text_clip
|
||||
)
|
||||
|
||||
# Initial noise (MPS doesn't support torch.Generator on device)
|
||||
gen_device = "cpu" if device.type == "mps" else device
|
||||
rng = torch.Generator(device=gen_device).manual_seed(seed)
|
||||
x0 = torch.randn(
|
||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||
device=gen_device, dtype=dtype, generator=rng,
|
||||
).to(device)
|
||||
|
||||
# Flow matching ODE (Euler)
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
def ode_wrapper_tracked(t, x):
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
pbar.update(1)
|
||||
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||
|
||||
try:
|
||||
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
raise RuntimeError(
|
||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||
)
|
||||
|
||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||
|
||||
# Decode: latent → mel → audio
|
||||
try:
|
||||
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||
with torch.no_grad():
|
||||
x1_unnorm = net_generator.unnormalize(x1)
|
||||
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
|
||||
audio = feature_utils.vocode(spec) # mel → waveform
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
raise RuntimeError(
|
||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
|
||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||
)
|
||||
|
||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||
|
||||
# Decode: latent → mel → audio
|
||||
try:
|
||||
with torch.no_grad():
|
||||
x1_unnorm = net_generator.unnormalize(x1)
|
||||
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
|
||||
audio = feature_utils.vocode(spec) # mel → waveform
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
raise RuntimeError(
|
||||
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
|
||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||
)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
net_generator.to(get_offload_device())
|
||||
feature_utils.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
finally:
|
||||
if strategy == "offload_to_cpu":
|
||||
net_generator.to(get_offload_device())
|
||||
feature_utils.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
# Ensure [1, 1, samples] and normalize to [-1,1]
|
||||
audio = audio.float()
|
||||
|
||||
Reference in New Issue
Block a user