fix: guarantee offload cleanup on exception with try/finally

Both nodes moved models to GPU before work then back to CPU after.
Any exception (OOM, cancellation, bad input) would skip the cleanup,
leaving models on GPU permanently until ComfyUI restarts.

Wrap the entire work block in try/finally so offload_to_cpu cleanup
always runs regardless of how the node exits. Also removes the unused
`mode` variable in SelvaSampler.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 08:40:39 +02:00
parent 8bb2fb7015
commit 3dd6badfd9
2 changed files with 100 additions and 97 deletions
+41 -39
View File
@@ -158,50 +158,52 @@ class SelvaFeatureExtractor:
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3)
with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None:
clip_frames = _apply_mask(clip_frames, mask)
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
try:
with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None:
clip_frames = _apply_mask(clip_frames, mask)
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
pbar.update(1)
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
pbar.update(1)
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None:
sync_frames = _apply_mask(sync_frames, mask)
# Pad to minimum 16 frames (TextSynchformer segment size)
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
# Normalize [0,1] → [-1,1]
mean = _SYNC_MEAN.to(sync_frames.device)
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None:
sync_frames = _apply_mask(sync_frames, mask)
# Pad to minimum 16 frames (TextSynchformer segment size)
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
# Normalize [0,1] → [-1,1]
mean = _SYNC_MEAN.to(sync_frames.device)
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
pbar.update(1)
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
sync_features = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_f, text_mask=text_mask
) # [1, T_sync, 768]
pbar.update(1)
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
pbar.update(1)
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
sync_features = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_f, text_mask=text_mask
) # [1, T_sync, 768]
pbar.update(1)
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
if strategy == "offload_to_cpu":
feature_utils.to(get_offload_device())
net_video_enc.to(get_offload_device())
soft_empty_cache()
finally:
if strategy == "offload_to_cpu":
feature_utils.to(get_offload_device())
net_video_enc.to(get_offload_device())
soft_empty_cache()
np.savez(
cached_path,
+59 -58
View File
@@ -54,7 +54,6 @@ class SelvaSampler:
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
mode = model["mode"]
# Validate that features were extracted with the same model variant
feat_variant = features.get("variant")
@@ -88,76 +87,78 @@ class SelvaSampler:
feature_utils.to(device)
soft_empty_cache()
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
try:
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
# Update model rotary position embeddings for actual feature shapes and duration.
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
with torch.no_grad():
# Encode text conditioning
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
# Encode negative prompt (or use empty conditions)
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
# Update model rotary position embeddings for actual feature shapes and duration.
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
# Initial noise (MPS doesn't support torch.Generator on device)
gen_device = "cpu" if device.type == "mps" else device
rng = torch.Generator(device=gen_device).manual_seed(seed)
x0 = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=gen_device, dtype=dtype, generator=rng,
).to(device)
with torch.no_grad():
# Encode text conditioning
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
# Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
pbar = comfy.utils.ProgressBar(steps)
# Encode negative prompt (or use empty conditions)
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
def ode_wrapper_tracked(t, x):
comfy.model_management.throw_exception_if_processing_interrupted()
pbar.update(1)
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Initial noise (MPS doesn't support torch.Generator on device)
gen_device = "cpu" if device.type == "mps" else device
rng = torch.Generator(device=gen_device).manual_seed(seed)
x0 = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=gen_device, dtype=dtype, generator=rng,
).to(device)
# Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
pbar = comfy.utils.ProgressBar(steps)
def ode_wrapper_tracked(t, x):
comfy.model_management.throw_exception_if_processing_interrupted()
pbar.update(1)
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
try:
x1 = fm.to_data(ode_wrapper_tracked, x0)
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
# Decode: latent → mel → audio
try:
x1 = fm.to_data(ode_wrapper_tracked, x0)
with torch.no_grad():
x1_unnorm = net_generator.unnormalize(x1)
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
audio = feature_utils.vocode(spec) # mel → waveform
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
# Decode: latent → mel → audio
try:
with torch.no_grad():
x1_unnorm = net_generator.unnormalize(x1)
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
audio = feature_utils.vocode(spec) # mel → waveform
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
finally:
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
# Ensure [1, 1, samples] and normalize to [-1,1]
audio = audio.float()