Fix warp cache buildup when all_on_gpu is enabled

The all_on_gpu guard was preventing warp cache clearing and
torch.cuda.empty_cache() from ever running, causing unbounded
VRAM growth during long interpolation runs. Cache clearing now
runs on the clear_cache_after_n_frames interval regardless of
the all_on_gpu setting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-18 21:27:47 +01:00
parent 13a89c5831
commit 396dafeefc

View File

@@ -173,7 +173,7 @@ class BIMVFIInterpolate:
                 }),
                 "clear_cache_after_n_frames": ("INT", {
                     "default": 10, "min": 1, "max": 100, "step": 1,
-                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
+                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower.",
                 }),
                 "keep_device": ("BOOLEAN", {
                     "default": True,
@@ -248,7 +248,7 @@ class BIMVFIInterpolate:
                 pbar.update_absolute(step_ref[0])
                 pairs_since_clear += actual_batch
-                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
+                if pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                     clear_backwarp_cache()
                     torch.cuda.empty_cache()
                     pairs_since_clear = 0
@@ -256,7 +256,7 @@ class BIMVFIInterpolate:
             new_frames.append(frames[-1:])
         frames = torch.cat(new_frames, dim=0)
-        if not all_on_gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             clear_backwarp_cache()
             torch.cuda.empty_cache()
@@ -731,7 +731,7 @@ class EMAVFIInterpolate:
                 }),
                 "clear_cache_after_n_frames": ("INT", {
                     "default": 10, "min": 1, "max": 100, "step": 1,
-                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
+                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower.",
                 }),
                 "keep_device": ("BOOLEAN", {
                     "default": True,
@@ -799,7 +799,7 @@ class EMAVFIInterpolate:
                 pbar.update_absolute(step_ref[0])
                 pairs_since_clear += actual_batch
-                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
+                if pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                     clear_ema_warp_cache()
                     torch.cuda.empty_cache()
                     pairs_since_clear = 0
@@ -807,7 +807,7 @@ class EMAVFIInterpolate:
             new_frames.append(frames[-1:])
         frames = torch.cat(new_frames, dim=0)
-        if not all_on_gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
            clear_ema_warp_cache()
            torch.cuda.empty_cache()
@@ -1145,7 +1145,7 @@ class SGMVFIInterpolate:
                 }),
                 "clear_cache_after_n_frames": ("INT", {
                     "default": 10, "min": 1, "max": 100, "step": 1,
-                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
+                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower.",
                 }),
                 "keep_device": ("BOOLEAN", {
                     "default": True,
@@ -1213,7 +1213,7 @@ class SGMVFIInterpolate:
                 pbar.update_absolute(step_ref[0])
                 pairs_since_clear += actual_batch
-                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
+                if pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                     clear_sgm_warp_cache()
                     torch.cuda.empty_cache()
                     pairs_since_clear = 0
@@ -1221,7 +1221,7 @@ class SGMVFIInterpolate:
             new_frames.append(frames[-1:])
         frames = torch.cat(new_frames, dim=0)
-        if not all_on_gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             clear_sgm_warp_cache()
             torch.cuda.empty_cache()
@@ -1576,7 +1576,7 @@ class GIMMVFIInterpolate:
                 }),
                 "clear_cache_after_n_frames": ("INT", {
                     "default": 10, "min": 1, "max": 100, "step": 1,
-                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
+                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower.",
                 }),
                 "keep_device": ("BOOLEAN", {
                     "default": True,
@@ -1641,7 +1641,7 @@ class GIMMVFIInterpolate:
                 pbar.update_absolute(step_ref[0])
                 pairs_since_clear += 1
-                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
+                if pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                     clear_gimm_caches()
                     torch.cuda.empty_cache()
                     pairs_since_clear = 0
@@ -1649,7 +1649,7 @@ class GIMMVFIInterpolate:
             new_frames.append(frames[-1:])
         result = torch.cat(new_frames, dim=0)
-        if not all_on_gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             clear_gimm_caches()
             torch.cuda.empty_cache()
@@ -1689,7 +1689,7 @@ class GIMMVFIInterpolate:
                 pbar.update_absolute(step_ref[0])
                 pairs_since_clear += actual_batch
-                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
+                if pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                     clear_gimm_caches()
                     torch.cuda.empty_cache()
                     pairs_since_clear = 0
@@ -1697,7 +1697,7 @@ class GIMMVFIInterpolate:
             new_frames.append(frames[-1:])
         frames = torch.cat(new_frames, dim=0)
-        if not all_on_gpu and torch.cuda.is_available():
+        if torch.cuda.is_available():
             clear_gimm_caches()
             torch.cuda.empty_cache()