Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2d96d5aa5d | |||
| 0c62c6eef4 | |||
| 83e4b5dd98 | |||
| 2e75e2d076 | |||
| c08fe58fe7 | |||
| 9e84890877 | |||
| 2e98e453a4 | |||
| daf0304243 | |||
| 5ce7b0edcb | |||
| 8d8407ec9d |
+48
-6
@@ -4,6 +4,7 @@ import collections
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
cupy = None
|
cupy = None
|
||||||
@@ -14,12 +15,11 @@ def _ensure_cupy():
|
|||||||
try:
|
try:
|
||||||
import cupy as _cupy
|
import cupy as _cupy
|
||||||
cupy = _cupy
|
cupy = _cupy
|
||||||
except ImportError:
|
except Exception:
|
||||||
raise RuntimeError(
|
# Broad catch: an installed-but-broken cupy (e.g. incompatible
|
||||||
"cupy is required for BIM-VFI. Install it with:\n"
|
# NumPy) raises non-ImportError exceptions at import time. Treat any
|
||||||
" pip install cupy-cuda12x (or cupy-cuda11x for CUDA 11)\n"
|
# failure as "cupy unavailable"; the PyTorch fallback will be used.
|
||||||
"Or run install.py from the ComfyUI-Tween directory."
|
pass
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
@@ -246,6 +246,44 @@ def cuda_launch(strKey:str):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_costvol_impl(tenOne, tenTwo, intKernelSize):
|
||||||
|
"""Pure-PyTorch local cost volume via unfold + dot product."""
|
||||||
|
B, C, H, W = tenOne.shape
|
||||||
|
pad = (intKernelSize - 1) // 2
|
||||||
|
|
||||||
|
# Pad tenTwo so out-of-bounds yields 0 (matches CUDA kernel)
|
||||||
|
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
|
||||||
|
|
||||||
|
# Unfold into patches: (B, C, H, W, K, K)
|
||||||
|
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
|
||||||
|
# Reshape to (B, C, H, W, K*K)
|
||||||
|
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
|
||||||
|
|
||||||
|
# Dot product over C dimension: (B, H, W, K*K)
|
||||||
|
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
|
||||||
|
|
||||||
|
# Permute to (B, K*K, H, W) to match CUDA output layout
|
||||||
|
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
|
||||||
|
|
||||||
|
_costvol_fn = None
|
||||||
|
|
||||||
|
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
|
||||||
|
global _costvol_fn
|
||||||
|
if _costvol_fn is None:
|
||||||
|
try:
|
||||||
|
_costvol_fn = torch.compile(_pytorch_costvol_impl)
|
||||||
|
except Exception:
|
||||||
|
_costvol_fn = _pytorch_costvol_impl
|
||||||
|
try:
|
||||||
|
return _costvol_fn(tenOne, tenTwo, intKernelSize)
|
||||||
|
except Exception:
|
||||||
|
_costvol_fn = _pytorch_costvol_impl
|
||||||
|
return _costvol_fn(tenOne, tenTwo, intKernelSize)
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
@@ -253,6 +291,8 @@ class costvol_func(torch.autograd.Function):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||||
def forward(self, tenOne, tenTwo, intKernelSize):
|
def forward(self, tenOne, tenTwo, intKernelSize):
|
||||||
|
_ensure_cupy()
|
||||||
|
if tenOne.is_cuda and cupy is not None:
|
||||||
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
||||||
|
|
||||||
cuda_launch(cuda_kernel('costvol_out', '''
|
cuda_launch(cuda_kernel('costvol_out', '''
|
||||||
@@ -302,6 +342,8 @@ class costvol_func(torch.autograd.Function):
|
|||||||
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), intKernelSize, tenOut.data_ptr()],
|
||||||
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
|
||||||
|
|
||||||
self.save_for_backward(tenOne, tenTwo)
|
self.save_for_backward(tenOne, tenTwo)
|
||||||
self.intKernelSize = intKernelSize
|
self.intKernelSize = intKernelSize
|
||||||
|
|||||||
@@ -0,0 +1,297 @@
|
|||||||
|
# Pure-PyTorch Fallbacks for cupy Kernels
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Make BIM-VFI, SGM-VFI, and GIMM-VFI work without cupy by adding pure-PyTorch fallback implementations of softsplat and costvol.
|
||||||
|
|
||||||
|
**Architecture:** Each kernel file (`sgm_vfi_arch/softsplat.py`, `gimm_vfi_arch/.../softsplat.py`, `bim_vfi_arch/costvol.py`) gets a `_pytorch_*` fallback function. The `softsplat_func.forward()` and `costvol_func.forward()` methods dispatch to cupy when available, otherwise use the fallback. The `_check_cupy()` gate in `nodes.py` is removed so models can load on any backend.
|
||||||
|
|
||||||
|
**Tech Stack:** PyTorch (`scatter_add_`, `F.unfold`, `F.pad`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add pure-PyTorch softsplat fallback to SGM-VFI
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `sgm_vfi_arch/softsplat.py`
|
||||||
|
|
||||||
|
**Step 1: Add cupy availability flag and fallback function**
|
||||||
|
|
||||||
|
At the top of `sgm_vfi_arch/softsplat.py`, change the hard `import cupy` to a try/except, and add the fallback function after the `cuda_launch` function (before the `softsplat()` function).
|
||||||
|
|
||||||
|
Replace:
|
||||||
|
```python
|
||||||
|
import cupy
|
||||||
|
```
|
||||||
|
With:
|
||||||
|
```python
|
||||||
|
try:
|
||||||
|
import cupy
|
||||||
|
except ImportError:
|
||||||
|
cupy = None
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this fallback function (after `cuda_launch`, before `softsplat`):
|
||||||
|
```python
|
||||||
|
def _pytorch_softsplat(tenIn, tenFlow):
|
||||||
|
B, C, H, W = tenIn.shape
|
||||||
|
tenOut = tenIn.new_zeros(B, C, H, W)
|
||||||
|
|
||||||
|
# Build base grid: (x, y) for each pixel
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Target positions
|
||||||
|
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :] # (B, H, W)
|
||||||
|
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
|
||||||
|
|
||||||
|
# Filter non-finite
|
||||||
|
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
|
||||||
|
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
|
||||||
|
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
|
||||||
|
|
||||||
|
# Four neighbors (NW, NE, SW, SE)
|
||||||
|
nw_x = flt_x.floor().long()
|
||||||
|
nw_y = flt_y.floor().long()
|
||||||
|
|
||||||
|
# Bilinear weights
|
||||||
|
frac_x = flt_x - nw_x.float()
|
||||||
|
frac_y = flt_y - nw_y.float()
|
||||||
|
w_nw = (1.0 - frac_x) * (1.0 - frac_y)
|
||||||
|
w_ne = frac_x * (1.0 - frac_y)
|
||||||
|
w_sw = (1.0 - frac_x) * frac_y
|
||||||
|
w_se = frac_x * frac_y
|
||||||
|
|
||||||
|
# Zero out invalid pixels
|
||||||
|
w_nw = w_nw * valid
|
||||||
|
w_ne = w_ne * valid
|
||||||
|
w_sw = w_sw * valid
|
||||||
|
w_se = w_se * valid
|
||||||
|
|
||||||
|
# For each of the 4 neighbors, scatter into output
|
||||||
|
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
|
||||||
|
tx = nw_x + dx
|
||||||
|
ty = nw_y + dy
|
||||||
|
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
w_masked = w * in_bounds
|
||||||
|
|
||||||
|
# Flatten to 1D index for scatter_add
|
||||||
|
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1)) # (B, H, W)
|
||||||
|
idx = idx.unsqueeze(1).expand_as(tenIn) # (B, C, H, W)
|
||||||
|
weighted = tenIn * w_masked.unsqueeze(1) # (B, C, H, W)
|
||||||
|
tenOut.view(B, C, -1).scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Update softsplat_func.forward to use fallback**
|
||||||
|
|
||||||
|
In `softsplat_func.forward()`, replace the `elif tenIn.is_cuda != True: assert(False)` block so it dispatches to the fallback when cupy is unavailable or when not on CUDA:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Current:
|
||||||
|
if tenIn.is_cuda == True:
|
||||||
|
cuda_launch(cuda_kernel(...))(...)
|
||||||
|
elif tenIn.is_cuda != True:
|
||||||
|
assert(False)
|
||||||
|
|
||||||
|
# New:
|
||||||
|
if tenIn.is_cuda and cupy is not None:
|
||||||
|
cuda_launch(cuda_kernel(...))(...)
|
||||||
|
else:
|
||||||
|
tenOut = _pytorch_softsplat(tenIn, tenFlow)
|
||||||
|
```
|
||||||
|
|
||||||
|
Also guard the `@cupy.memoize` decorator on `cuda_launch`:
|
||||||
|
```python
|
||||||
|
# Current:
|
||||||
|
@cupy.memoize(for_each_device=True)
|
||||||
|
def cuda_launch(strKey:str):
|
||||||
|
|
||||||
|
# New:
|
||||||
|
def cuda_launch(strKey:str):
|
||||||
|
```
|
||||||
|
(The function already has its own dict-based caching via `objCudacache`, and the memoize is redundant anyway. But the real issue is it crashes at import when cupy=None.)
|
||||||
|
|
||||||
|
Wait - actually `cuda_launch` uses `cupy.RawKernel` inside, so it's only ever called on the cupy path. The `@cupy.memoize` decorator is the problem: it runs at import time. Replace it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Replace @cupy.memoize(for_each_device=True) with a simple cache dict
|
||||||
|
_cuda_launch_cache = {}
|
||||||
|
|
||||||
|
def cuda_launch(strKey:str):
|
||||||
|
if strKey not in _cuda_launch_cache:
|
||||||
|
if 'CUDA_HOME' not in os.environ:
|
||||||
|
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
||||||
|
_cuda_launch_cache[strKey] = cupy.RawKernel(
|
||||||
|
objCudacache[strKey]['strKernel'],
|
||||||
|
objCudacache[strKey]['strFunction'],
|
||||||
|
options=tuple(['-I ' + os.environ['CUDA_HOME'],
|
||||||
|
'-I ' + os.environ['CUDA_HOME'] + '/include'])
|
||||||
|
)
|
||||||
|
return _cuda_launch_cache[strKey]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add sgm_vfi_arch/softsplat.py
|
||||||
|
git commit -m "feat: add pure-PyTorch softsplat fallback for SGM-VFI"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Add pure-PyTorch softsplat fallback to GIMM-VFI
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `gimm_vfi_arch/generalizable_INR/modules/softsplat.py`
|
||||||
|
|
||||||
|
**Step 1: Add cupy availability flag and fallback function**
|
||||||
|
|
||||||
|
Same pattern as Task 1. Replace `import cupy` with try/except. Add the same `_pytorch_softsplat()` function. Replace `@cupy.memoize(for_each_device=True)` on `cuda_launch` with a dict cache.
|
||||||
|
|
||||||
|
The GIMM softsplat.py already has `@torch.compiler.disable()` on `cuda_launch` — keep that decorator.
|
||||||
|
|
||||||
|
**Step 2: Update softsplat_func.forward dispatch**
|
||||||
|
|
||||||
|
Same pattern: if `tenIn.is_cuda and cupy is not None` → cupy path, else → `_pytorch_softsplat`.
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add gimm_vfi_arch/generalizable_INR/modules/softsplat.py
|
||||||
|
git commit -m "feat: add pure-PyTorch softsplat fallback for GIMM-VFI"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Add pure-PyTorch costvol fallback to BIM-VFI
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `bim_vfi_arch/costvol.py`
|
||||||
|
|
||||||
|
**Step 1: Add the fallback function**
|
||||||
|
|
||||||
|
After the existing `cuda_launch` function, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _pytorch_costvol(tenOne, tenTwo, intKernelSize):
|
||||||
|
B, C, H, W = tenOne.shape
|
||||||
|
pad = (intKernelSize - 1) // 2
|
||||||
|
|
||||||
|
# Pad tenTwo with zeros so out-of-bounds accesses yield 0 (matches CUDA kernel)
|
||||||
|
tenTwo_padded = F.pad(tenTwo, [pad, pad, pad, pad])
|
||||||
|
|
||||||
|
# Unfold into (B, C, K*K, H, W) patches
|
||||||
|
patches = tenTwo_padded.unfold(2, intKernelSize, 1).unfold(3, intKernelSize, 1)
|
||||||
|
# patches shape: (B, C, H, W, K, K)
|
||||||
|
patches = patches.contiguous().view(B, C, H, W, intKernelSize * intKernelSize)
|
||||||
|
# -> (B, C, H, W, K^2)
|
||||||
|
|
||||||
|
# Dot product: sum over C
|
||||||
|
# tenOne: (B, C, H, W) -> (B, C, H, W, 1)
|
||||||
|
tenOut = (tenOne.unsqueeze(-1) * patches).sum(dim=1)
|
||||||
|
# tenOut: (B, H, W, K^2)
|
||||||
|
|
||||||
|
# Permute to (B, K^2, H, W) to match CUDA output layout
|
||||||
|
tenOut = tenOut.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
```
|
||||||
|
|
||||||
|
Add `import torch.nn.functional as F` at the top if not already present.
|
||||||
|
|
||||||
|
**Step 2: Update costvol_func.forward dispatch**
|
||||||
|
|
||||||
|
The current forward unconditionally calls `cuda_launch(cuda_kernel(...))`. Change to:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@staticmethod
|
||||||
|
@torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
||||||
|
def forward(self, tenOne, tenTwo, intKernelSize):
|
||||||
|
if tenOne.is_cuda and cupy is not None:
|
||||||
|
# existing cupy code (unchanged)
|
||||||
|
tenOut = tenOne.new_empty([tenOne.shape[0], intKernelSize ** 2, tenOne.shape[2], tenOne.shape[3]])
|
||||||
|
cuda_launch(cuda_kernel(...))(...)
|
||||||
|
else:
|
||||||
|
tenOut = _pytorch_costvol(tenOne, tenTwo, intKernelSize)
|
||||||
|
|
||||||
|
self.save_for_backward(tenOne, tenTwo)
|
||||||
|
self.intKernelSize = intKernelSize
|
||||||
|
return tenOut
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add bim_vfi_arch/costvol.py
|
||||||
|
git commit -m "feat: add pure-PyTorch costvol fallback for BIM-VFI"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Remove _check_cupy gate from nodes.py
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `nodes.py`
|
||||||
|
|
||||||
|
**Step 1: Remove the _check_cupy function and all its call sites**
|
||||||
|
|
||||||
|
Delete the `_check_cupy()` function definition (lines 22-41). Remove the three calls:
|
||||||
|
- Line 209: `_check_cupy("BIM-VFI")` (in BIM-VFI load)
|
||||||
|
- Line 1377: `_check_cupy("SGM-VFI")` (in SGM-VFI load)
|
||||||
|
- Line 1804: `_check_cupy("GIMM-VFI")` (in GIMM-VFI load)
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add nodes.py
|
||||||
|
git commit -m "feat: remove cupy requirement gate, models now fallback to pure PyTorch"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Make install.py not force cupy installation
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `install.py`
|
||||||
|
|
||||||
|
**Step 1: Change cupy from required to optional**
|
||||||
|
|
||||||
|
Make cupy a soft dependency — try to install it but don't fail if it can't be installed (ROCm users, no CUDA toolkit, etc.). Change `install()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def install():
|
||||||
|
# Install core requirements first
|
||||||
|
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||||
|
subprocess.check_call([
|
||||||
|
sys.executable, "-m", "pip", "install", "-r", requirements_path
|
||||||
|
])
|
||||||
|
|
||||||
|
# Try to install cupy for NVIDIA users (optional, improves performance)
|
||||||
|
cupy_pkg = get_cupy_package()
|
||||||
|
if cupy_pkg:
|
||||||
|
try:
|
||||||
|
subprocess.check_call([
|
||||||
|
sys.executable, "-m", "pip", "install", cupy_pkg
|
||||||
|
])
|
||||||
|
print(f"[Tween] cupy installed successfully ({cupy_pkg})")
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print(f"[Tween] WARNING: Could not install {cupy_pkg}. "
|
||||||
|
f"BIM-VFI, SGM-VFI, and GIMM-VFI will use slower PyTorch fallback.")
|
||||||
|
else:
|
||||||
|
print("[Tween] cupy not available (no NVIDIA CUDA). "
|
||||||
|
"BIM-VFI, SGM-VFI, and GIMM-VFI will use PyTorch fallback.")
|
||||||
|
```
|
||||||
|
|
||||||
|
Also stop writing cupy into `requirements.txt` — remove the `update_requirements` call and function.
|
||||||
|
|
||||||
|
**Step 2: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add install.py
|
||||||
|
git commit -m "feat: make cupy optional in install.py"
|
||||||
|
```
|
||||||
@@ -9,7 +9,13 @@
|
|||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import cupy
|
try:
|
||||||
|
import cupy
|
||||||
|
except Exception:
|
||||||
|
# Broad catch: an installed-but-broken cupy (e.g. incompatible NumPy)
|
||||||
|
# raises non-ImportError exceptions at import time. Treat any failure as
|
||||||
|
# "cupy unavailable" and fall back to the pure-PyTorch implementation.
|
||||||
|
cupy = None
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
@@ -260,31 +266,94 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
@cupy.memoize(for_each_device=True)
|
_cuda_launch_cache = {}
|
||||||
|
|
||||||
@torch.compiler.disable()
|
@torch.compiler.disable()
|
||||||
def cuda_launch(strKey: str):
|
def cuda_launch(strKey: str):
|
||||||
try:
|
if strKey not in _cuda_launch_cache:
|
||||||
os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
|
|
||||||
except Exception:
|
|
||||||
if "CUDA_HOME" not in os.environ:
|
if "CUDA_HOME" not in os.environ:
|
||||||
raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")
|
try:
|
||||||
|
cuda_path = cupy.cuda.get_cuda_path()
|
||||||
|
except Exception:
|
||||||
|
cuda_path = None
|
||||||
|
if cuda_path is None:
|
||||||
|
cuda_path = "/usr/local/cuda"
|
||||||
|
os.environ["CUDA_HOME"] = cuda_path
|
||||||
strKernel = objCudacache[strKey]["strKernel"]
|
strKernel = objCudacache[strKey]["strKernel"]
|
||||||
strFunction = objCudacache[strKey]["strFunction"]
|
strFunction = objCudacache[strKey]["strFunction"]
|
||||||
|
_cuda_launch_cache[strKey] = cupy.RawModule(
|
||||||
return cupy.RawModule(
|
|
||||||
code=strKernel,
|
code=strKernel,
|
||||||
options=(
|
options=(
|
||||||
"-I " + os.environ["CUDA_HOME"],
|
"-I " + os.environ["CUDA_HOME"],
|
||||||
"-I " + os.environ["CUDA_HOME"] + "/include",
|
"-I " + os.environ["CUDA_HOME"] + "/include",
|
||||||
),
|
),
|
||||||
).get_function(strFunction)
|
).get_function(strFunction)
|
||||||
|
return _cuda_launch_cache[strKey]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_softsplat_impl(tenIn, tenFlow):
|
||||||
|
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
|
||||||
|
B, C, H, W = tenIn.shape
|
||||||
|
tenOut = tenIn.new_zeros(B, C, H, W)
|
||||||
|
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
|
||||||
|
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :]
|
||||||
|
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
|
||||||
|
|
||||||
|
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
|
||||||
|
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
|
||||||
|
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
|
||||||
|
|
||||||
|
nw_x = flt_x.floor().long()
|
||||||
|
nw_y = flt_y.floor().long()
|
||||||
|
frac_x = flt_x - nw_x.to(flt_x.dtype)
|
||||||
|
frac_y = flt_y - nw_y.to(flt_y.dtype)
|
||||||
|
|
||||||
|
w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid
|
||||||
|
w_ne = frac_x * (1.0 - frac_y) * valid
|
||||||
|
w_sw = (1.0 - frac_x) * frac_y * valid
|
||||||
|
w_se = frac_x * frac_y * valid
|
||||||
|
|
||||||
|
out_flat = tenOut.view(B, C, -1)
|
||||||
|
|
||||||
|
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
|
||||||
|
tx = nw_x + dx
|
||||||
|
ty = nw_y + dy
|
||||||
|
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
w_masked = w * in_bounds
|
||||||
|
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
|
||||||
|
idx = idx.unsqueeze(1).expand_as(tenIn)
|
||||||
|
weighted = tenIn * w_masked.unsqueeze(1)
|
||||||
|
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
|
||||||
|
|
||||||
|
_softsplat_fn = None
|
||||||
|
|
||||||
|
def _pytorch_softsplat(tenIn, tenFlow):
|
||||||
|
global _softsplat_fn
|
||||||
|
if _softsplat_fn is None:
|
||||||
|
try:
|
||||||
|
_softsplat_fn = torch.compile(_pytorch_softsplat_impl)
|
||||||
|
except Exception:
|
||||||
|
_softsplat_fn = _pytorch_softsplat_impl
|
||||||
|
try:
|
||||||
|
return _softsplat_fn(tenIn, tenFlow)
|
||||||
|
except Exception:
|
||||||
|
_softsplat_fn = _pytorch_softsplat_impl
|
||||||
|
return _softsplat_fn(tenIn, tenFlow)
|
||||||
|
|
||||||
|
|
||||||
@torch.compiler.disable()
|
@torch.compiler.disable()
|
||||||
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
|
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
|
||||||
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
|
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
|
||||||
@@ -366,7 +435,7 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
||||||
)
|
)
|
||||||
|
|
||||||
if tenIn.is_cuda == True:
|
if tenIn.is_cuda and cupy is not None:
|
||||||
cuda_launch(
|
cuda_launch(
|
||||||
cuda_kernel(
|
cuda_kernel(
|
||||||
"softsplat_out",
|
"softsplat_out",
|
||||||
@@ -439,8 +508,8 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tenIn.is_cuda != True:
|
else:
|
||||||
assert False
|
tenOut = _pytorch_softsplat(tenIn, tenFlow)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|||||||
+17
-22
@@ -8,44 +8,39 @@ def get_cupy_package():
|
|||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("[Tween] WARNING: CUDA not available. cupy requires CUDA.")
|
|
||||||
return None
|
return None
|
||||||
cuda_version = torch.version.cuda
|
cuda_version = torch.version.cuda
|
||||||
if cuda_version is None:
|
if cuda_version is None:
|
||||||
print("[Tween] WARNING: PyTorch has no CUDA version info.")
|
|
||||||
return None
|
return None
|
||||||
major = int(cuda_version.split(".")[0])
|
major = int(cuda_version.split(".")[0])
|
||||||
cupy_pkg = f"cupy-cuda{major}x"
|
cupy_pkg = f"cupy-cuda{major}x"
|
||||||
print(f"[Tween] Detected CUDA {cuda_version}, will use {cupy_pkg}")
|
|
||||||
return cupy_pkg
|
return cupy_pkg
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"[Tween] WARNING: Could not detect CUDA version: {e}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def update_requirements(cupy_pkg):
|
|
||||||
"""Write the correct cupy package into requirements.txt."""
|
|
||||||
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
|
||||||
lines = []
|
|
||||||
if os.path.exists(requirements_path):
|
|
||||||
with open(requirements_path, "r") as f:
|
|
||||||
lines = [l.rstrip() for l in f if not l.strip().startswith("cupy")]
|
|
||||||
if cupy_pkg and cupy_pkg not in lines:
|
|
||||||
lines.append(cupy_pkg)
|
|
||||||
with open(requirements_path, "w") as f:
|
|
||||||
f.write("\n".join(lines) + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def install():
|
def install():
|
||||||
cupy_pkg = get_cupy_package()
|
# Install core requirements first
|
||||||
if cupy_pkg:
|
|
||||||
update_requirements(cupy_pkg)
|
|
||||||
|
|
||||||
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||||
subprocess.check_call([
|
subprocess.check_call([
|
||||||
sys.executable, "-m", "pip", "install", "-r", requirements_path
|
sys.executable, "-m", "pip", "install", "-r", requirements_path
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# Try to install cupy for NVIDIA users (optional, improves performance)
|
||||||
|
cupy_pkg = get_cupy_package()
|
||||||
|
if cupy_pkg:
|
||||||
|
try:
|
||||||
|
subprocess.check_call([
|
||||||
|
sys.executable, "-m", "pip", "install", cupy_pkg
|
||||||
|
])
|
||||||
|
print(f"[Tween] cupy installed ({cupy_pkg}) — fast CUDA kernels enabled")
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print(f"[Tween] WARNING: Could not install {cupy_pkg}. "
|
||||||
|
f"BIM-VFI, SGM-VFI, and GIMM-VFI will use slower PyTorch fallback.")
|
||||||
|
else:
|
||||||
|
print("[Tween] cupy skipped (no NVIDIA CUDA). "
|
||||||
|
"BIM-VFI, SGM-VFI, and GIMM-VFI will use PyTorch fallback.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
install()
|
install()
|
||||||
|
|||||||
@@ -19,26 +19,6 @@ from .gimm_vfi_arch import clear_gimm_caches
|
|||||||
logger = logging.getLogger("Tween")
|
logger = logging.getLogger("Tween")
|
||||||
|
|
||||||
|
|
||||||
def _check_cupy(model_name):
|
|
||||||
"""Raise a clear error if cupy is not installed."""
|
|
||||||
try:
|
|
||||||
import cupy # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
cuda_ver = torch.version.cuda or "unknown"
|
|
||||||
major = int(cuda_ver.split(".")[0])
|
|
||||||
cupy_pkg = f"cupy-cuda{major}x"
|
|
||||||
except Exception:
|
|
||||||
cuda_ver = "unknown"
|
|
||||||
cupy_pkg = "cupy-cuda12x # adjust to your CUDA version"
|
|
||||||
raise RuntimeError(
|
|
||||||
f"{model_name} requires cupy but it is not installed.\n\n"
|
|
||||||
f"Your PyTorch CUDA version: {cuda_ver}\n\n"
|
|
||||||
f"Install it with:\n"
|
|
||||||
f" pip install {cupy_pkg}\n\n"
|
|
||||||
f"If you are unsure of your CUDA version, run:\n"
|
|
||||||
f" python -c \"import torch; print(torch.version.cuda)\""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_system_ram_gb():
|
def _get_system_ram_gb():
|
||||||
@@ -206,7 +186,6 @@ class LoadBIMVFIModel:
|
|||||||
CATEGORY = "video/BIM-VFI"
|
CATEGORY = "video/BIM-VFI"
|
||||||
|
|
||||||
def load_model(self, model_path, auto_pyr_level, pyr_level):
|
def load_model(self, model_path, auto_pyr_level, pyr_level):
|
||||||
_check_cupy("BIM-VFI")
|
|
||||||
full_path = os.path.join(MODEL_DIR, model_path)
|
full_path = os.path.join(MODEL_DIR, model_path)
|
||||||
|
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
@@ -1374,7 +1353,6 @@ class LoadSGMVFIModel:
|
|||||||
CATEGORY = "video/SGM-VFI"
|
CATEGORY = "video/SGM-VFI"
|
||||||
|
|
||||||
def load_model(self, model_path, tta, num_key_points):
|
def load_model(self, model_path, tta, num_key_points):
|
||||||
_check_cupy("SGM-VFI")
|
|
||||||
full_path = os.path.join(SGM_MODEL_DIR, model_path)
|
full_path = os.path.join(SGM_MODEL_DIR, model_path)
|
||||||
|
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
@@ -1801,7 +1779,6 @@ class LoadGIMMVFIModel:
|
|||||||
CATEGORY = "video/GIMM-VFI"
|
CATEGORY = "video/GIMM-VFI"
|
||||||
|
|
||||||
def load_model(self, model_path, ds_factor):
|
def load_model(self, model_path, ds_factor):
|
||||||
_check_cupy("GIMM-VFI")
|
|
||||||
full_path = os.path.join(GIMM_MODEL_DIR, model_path)
|
full_path = os.path.join(GIMM_MODEL_DIR, model_path)
|
||||||
|
|
||||||
# Auto-download main model if missing
|
# Auto-download main model if missing
|
||||||
|
|||||||
+88
-11
@@ -1,7 +1,13 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import cupy
|
try:
|
||||||
|
import cupy
|
||||||
|
except Exception:
|
||||||
|
# Broad catch: an installed-but-broken cupy (e.g. incompatible NumPy)
|
||||||
|
# raises non-ImportError exceptions at import time. Treat any failure as
|
||||||
|
# "cupy unavailable" and fall back to the pure-PyTorch implementation.
|
||||||
|
cupy = None
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
@@ -216,20 +222,91 @@ def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
|
|||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
@cupy.memoize(for_each_device=True)
|
_cuda_launch_cache = {}
|
||||||
def cuda_launch(strKey:str):
|
|
||||||
if 'CUDA_HOME' not in os.environ:
|
|
||||||
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
|
||||||
# end
|
|
||||||
|
|
||||||
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],
|
def cuda_launch(strKey:str):
|
||||||
options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
|
if strKey not in _cuda_launch_cache:
|
||||||
|
if 'CUDA_HOME' not in os.environ:
|
||||||
|
try:
|
||||||
|
cuda_path = cupy.cuda.get_cuda_path()
|
||||||
|
except Exception:
|
||||||
|
cuda_path = None
|
||||||
|
if cuda_path is None:
|
||||||
|
cuda_path = '/usr/local/cuda'
|
||||||
|
os.environ['CUDA_HOME'] = cuda_path
|
||||||
|
_cuda_launch_cache[strKey] = cupy.RawKernel(
|
||||||
|
objCudacache[strKey]['strKernel'],
|
||||||
|
objCudacache[strKey]['strFunction'],
|
||||||
|
options=tuple(['-I ' + os.environ['CUDA_HOME'],
|
||||||
|
'-I ' + os.environ['CUDA_HOME'] + '/include'])
|
||||||
|
)
|
||||||
|
return _cuda_launch_cache[strKey]
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
def _pytorch_softsplat_impl(tenIn, tenFlow):
|
||||||
|
"""Pure-PyTorch forward warp via bilinear splatting (scatter_add)."""
|
||||||
|
B, C, H, W = tenIn.shape
|
||||||
|
tenOut = tenIn.new_zeros(B, C, H, W)
|
||||||
|
|
||||||
|
grid_y, grid_x = torch.meshgrid(
|
||||||
|
torch.arange(H, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
torch.arange(W, device=tenIn.device, dtype=tenIn.dtype),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
|
||||||
|
flt_x = grid_x.unsqueeze(0) + tenFlow[:, 0, :, :]
|
||||||
|
flt_y = grid_y.unsqueeze(0) + tenFlow[:, 1, :, :]
|
||||||
|
|
||||||
|
valid = torch.isfinite(flt_x) & torch.isfinite(flt_y)
|
||||||
|
flt_x = torch.where(valid, flt_x, torch.zeros_like(flt_x))
|
||||||
|
flt_y = torch.where(valid, flt_y, torch.zeros_like(flt_y))
|
||||||
|
|
||||||
|
nw_x = flt_x.floor().long()
|
||||||
|
nw_y = flt_y.floor().long()
|
||||||
|
frac_x = flt_x - nw_x.to(flt_x.dtype)
|
||||||
|
frac_y = flt_y - nw_y.to(flt_y.dtype)
|
||||||
|
|
||||||
|
w_nw = (1.0 - frac_x) * (1.0 - frac_y) * valid
|
||||||
|
w_ne = frac_x * (1.0 - frac_y) * valid
|
||||||
|
w_sw = (1.0 - frac_x) * frac_y * valid
|
||||||
|
w_se = frac_x * frac_y * valid
|
||||||
|
|
||||||
|
out_flat = tenOut.view(B, C, -1)
|
||||||
|
|
||||||
|
for dx, dy, w in [(0, 0, w_nw), (1, 0, w_ne), (0, 1, w_sw), (1, 1, w_se)]:
|
||||||
|
tx = nw_x + dx
|
||||||
|
ty = nw_y + dy
|
||||||
|
in_bounds = (tx >= 0) & (tx < W) & (ty >= 0) & (ty < H)
|
||||||
|
w_masked = w * in_bounds
|
||||||
|
idx = (ty.clamp(0, H - 1) * W + tx.clamp(0, W - 1))
|
||||||
|
idx = idx.unsqueeze(1).expand_as(tenIn)
|
||||||
|
weighted = tenIn * w_masked.unsqueeze(1)
|
||||||
|
out_flat.scatter_add_(2, idx.reshape(B, C, -1), weighted.reshape(B, C, -1))
|
||||||
|
|
||||||
|
return tenOut
|
||||||
|
|
||||||
|
|
||||||
|
_softsplat_fn = None
|
||||||
|
|
||||||
|
def _pytorch_softsplat(tenIn, tenFlow):
|
||||||
|
global _softsplat_fn
|
||||||
|
if _softsplat_fn is None:
|
||||||
|
try:
|
||||||
|
_softsplat_fn = torch.compile(_pytorch_softsplat_impl)
|
||||||
|
except Exception:
|
||||||
|
_softsplat_fn = _pytorch_softsplat_impl
|
||||||
|
try:
|
||||||
|
return _softsplat_fn(tenIn, tenFlow)
|
||||||
|
except Exception:
|
||||||
|
_softsplat_fn = _pytorch_softsplat_impl
|
||||||
|
return _softsplat_fn(tenIn, tenFlow)
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
|
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
|
||||||
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
|
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
|
||||||
|
|
||||||
@@ -281,7 +358,7 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
def forward(self, tenIn, tenFlow):
|
def forward(self, tenIn, tenFlow):
|
||||||
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
|
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
|
||||||
|
|
||||||
if tenIn.is_cuda == True:
|
if tenIn.is_cuda and cupy is not None:
|
||||||
cuda_launch(cuda_kernel('softsplat_out', '''
|
cuda_launch(cuda_kernel('softsplat_out', '''
|
||||||
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
||||||
const int n,
|
const int n,
|
||||||
@@ -345,8 +422,8 @@ class softsplat_func(torch.autograd.Function):
|
|||||||
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tenIn.is_cuda != True:
|
else:
|
||||||
assert(False)
|
tenOut = _pytorch_softsplat(tenIn, tenFlow)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user