Add SGM-VFI (CVPR 2024) frame interpolation support
SGM-VFI combines local flow estimation with sparse global matching (GMFlow) to handle large motion and occlusion-heavy scenes.

Adds 3 new nodes: Load SGM-VFI Model, SGM-VFI Interpolate, and SGM-VFI Segment Interpolate.

Architecture files are vendored from MCG-NJU/SGM-VFI with device-awareness fixes (no hardcoded .cuda()), relative imports, and debug code removed. README updated with a model comparison table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:

README.md (69 changed lines)

@@ -1,6 +1,21 @@

# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI

ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), and [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM.

## Which model should I use?

| | BIM-VFI | EMA-VFI | SGM-VFI |
|---|---------|---------|---------|
| **Best for** | General-purpose, non-uniform motion | Fast inference, light VRAM | Large motion, occlusion-heavy scenes |
| **Quality** | Highest overall | Good | Best on large motion |
| **Speed** | Moderate | Fastest | Slowest |
| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair |
| **Params** | ~17M | ~14–65M | ~15M + GMFlow |
| **Arbitrary timestep** | Yes | Yes (with `_t` checkpoint) | No (fixed 0.5) |
| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 |
| **License** | Research only | Apache 2.0 | Apache 2.0 |

**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with.

## Nodes

@@ -66,7 +81,32 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.

Same as EMA-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.

### SGM-VFI

#### Load SGM-VFI Model

Loads an SGM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/sgm-vfi/`. Variant (base/small) is auto-detected from the filename (default is small).

| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint file from `models/sgm-vfi/` |
| **tta** | Test-time augmentation: flip the input and average with the unflipped result (~2x slower, slightly better quality) |
| **num_key_points** | Sparsity of global matching (0.0 = global matching everywhere, 0.5 = default balance, higher = faster) |

Available checkpoints:

| Checkpoint | Variant | Params |
|-----------|---------|--------|
| `ours-1-2-points.pth` | Small | ~15M + GMFlow |

#### SGM-VFI Interpolate

Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.

#### SGM-VFI Segment Interpolate

Same as SGM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.

**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
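
The counts above follow from each interpolation pass inserting one frame between every adjacent pair, turning N frames into 2N-1; 4x and 8x apply the pass recursively. A standalone sketch of that arithmetic (`output_frame_count` is a hypothetical helper for illustration, not part of the node code):

```python
def output_frame_count(n: int, multiplier: int) -> int:
    """Frames produced from n input frames at a given multiplier.

    Each pass inserts one frame between every adjacent pair (N -> 2N - 1);
    4x and 8x apply the pass 2 and 3 times recursively.
    """
    passes = {2: 1, 4: 2, 8: 3}[multiplier]
    for _ in range(passes):
        n = 2 * n - 1
    return n

# Matches the stated formulas: 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
assert output_frame_count(10, 2) == 2 * 10 - 1
assert output_frame_count(10, 4) == 4 * 10 - 3
assert output_frame_count(10, 8) == 8 * 10 - 7
```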
## Installation

@@ -94,8 +134,8 @@ python install.py

### Requirements

- PyTorch with CUDA
- `cupy` (matching your CUDA version, for BIM-VFI and SGM-VFI)
- `timm` (for EMA-VFI and SGM-VFI)
- `gdown` (for model auto-download)

## VRAM Guide

@@ -109,7 +149,7 @@ python install.py

## Acknowledgments

This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab) and the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) and [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementations by MCG-NJU. Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, and `sgm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths).

**BiM-VFI:**

> Wonyong Seo, Jihyong Oh, and Munchurl Kim.

@@ -141,8 +181,25 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF

}
```

**SGM-VFI:**
> Guozhen Zhang, Yuhan Zhu, Evan Zheran Liu, Haonan Wang, Mingzhen Sun, Gangshan Wu, and Limin Wang.
> "Sparse Global Matching for Video Frame Interpolation with Large Motion."
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2024.
> [[arXiv]](https://arxiv.org/abs/2404.06913) [[GitHub]](https://github.com/MCG-NJU/SGM-VFI)

```bibtex
@inproceedings{zhang2024sgmvfi,
  title={Sparse Global Matching for Video Frame Interpolation with Large Motion},
  author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2024}
}
```

## License

The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.

The EMA-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/EMA-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/EMA-VFI) for details.

The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.

__init__.py

@@ -40,6 +40,7 @@ _auto_install_deps()
from .nodes import (
    LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, BIMVFIConcatVideos,
    LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
    LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
)

NODE_CLASS_MAPPINGS = {

@@ -50,6 +51,9 @@ NODE_CLASS_MAPPINGS = {

    "LoadEMAVFIModel": LoadEMAVFIModel,
    "EMAVFIInterpolate": EMAVFIInterpolate,
    "EMAVFISegmentInterpolate": EMAVFISegmentInterpolate,
    "LoadSGMVFIModel": LoadSGMVFIModel,
    "SGMVFIInterpolate": SGMVFIInterpolate,
    "SGMVFISegmentInterpolate": SGMVFISegmentInterpolate,
}

NODE_DISPLAY_NAME_MAPPINGS = {

@@ -60,4 +64,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {

    "LoadEMAVFIModel": "Load EMA-VFI Model",
    "EMAVFIInterpolate": "EMA-VFI Interpolate",
    "EMAVFISegmentInterpolate": "EMA-VFI Segment Interpolate",
    "LoadSGMVFIModel": "Load SGM-VFI Model",
    "SGMVFIInterpolate": "SGM-VFI Interpolate",
    "SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate",
}

inference.py (159 changed lines)

@@ -7,6 +7,8 @@ import torch.nn as nn

from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
from .ema_vfi_arch import MultiScaleFlow as EMAMultiScaleFlow
from .sgm_vfi_arch import feature_extractor as sgm_feature_extractor
from .sgm_vfi_arch import MultiScaleFlow as SGMMultiScaleFlow
from .utils.padder import InputPadder

logger = logging.getLogger("BIM-VFI")

@@ -282,3 +284,160 @@ class EMAVFIModel:

        pred = self._inference(img0, img1, timestep=time_step)
        pred = padder.unpad(pred)
        return torch.clamp(pred, 0, 1)


# ---------------------------------------------------------------------------
# SGM-VFI model wrapper
# ---------------------------------------------------------------------------

def _sgm_init_model_config(F=16, W=7, depth=[2, 2, 2, 4], num_key_points=0.5):
    """Build SGM-VFI model config dicts (backbone + multiscale)."""
    return {
        'embed_dims': [F, 2*F, 4*F, 8*F],
        'num_heads': [8*F//32],
        'mlp_ratios': [4],
        'qkv_bias': True,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
        'depths': depth,
        'window_sizes': [W]
    }, {
        'embed_dims': [F, 2*F, 4*F, 8*F],
        'motion_dims': [0, 0, 0, 8*F//depth[-1]],
        'depths': depth,
        'scales': [8],
        'hidden_dims': [4*F],
        'c': F,
        'num_key_points': num_key_points,
    }

def _sgm_detect_variant(filename):
    """Auto-detect SGM-VFI model variant from filename.

    Returns (F, depth).
    Default is small (F=16) since the primary checkpoint (ours-1-2-points)
    is a small model. Only detect base when "base" is in the filename.
    """
    name = filename.lower()
    is_base = "base" in name
    if is_base:
        return 32, [2, 2, 2, 6]
    else:
        return 16, [2, 2, 2, 4]

class SGMVFIModel:
    """Clean inference wrapper around SGM-VFI for ComfyUI integration."""

    def __init__(self, checkpoint_path, variant="auto", num_key_points=0.5, tta=False, device="cpu"):
        import os
        filename = os.path.basename(checkpoint_path)

        if variant == "auto":
            F_dim, depth = _sgm_detect_variant(filename)
        elif variant == "small":
            F_dim, depth = 16, [2, 2, 2, 4]
        else:  # base
            F_dim, depth = 32, [2, 2, 2, 6]

        self.tta = tta
        self.device = device
        self.variant_name = "small" if F_dim == 16 else "base"

        backbone_cfg, multiscale_cfg = _sgm_init_model_config(
            F=F_dim, depth=depth, num_key_points=num_key_points)
        backbone = sgm_feature_extractor(**backbone_cfg)
        self.model = SGMMultiScaleFlow(backbone, **multiscale_cfg)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()
        self.model.to(device)

    def _load_checkpoint(self, checkpoint_path):
        """Load checkpoint with module prefix stripping and buffer filtering."""
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        # Handle wrapped checkpoint formats
        if isinstance(state_dict, dict):
            if "model" in state_dict:
                state_dict = state_dict["model"]
            elif "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # Strip "module." prefix and filter out attn_mask/HW buffers
        cleaned = {}
        for k, v in state_dict.items():
            if "attn_mask" in k or k.endswith(".HW"):
                continue
            key = k
            if key.startswith("module."):
                key = key[len("module."):]
            cleaned[key] = v

        self.model.load_state_dict(cleaned, strict=False)

    def to(self, device):
        """Move model to device (returns self for chaining)."""
        self.device = device
        self.model.to(device)
        return self

    @torch.no_grad()
    def _inference(self, img0, img1, timestep=0.5):
        """Run single inference pass. Inputs already padded, on device."""
        B = img0.shape[0]
        imgs = torch.cat((img0, img1), 1)

        if self.tta:
            imgs_ = imgs.flip(2).flip(3)
            input_batch = torch.cat((imgs, imgs_), 0)
            _, _, _, preds, _ = self.model(input_batch, timestep=timestep)
            return (preds[:B] + preds[B:].flip(2).flip(3)) / 2.
        else:
            _, _, _, pred, _ = self.model(imgs, timestep=timestep)
            return pred

    @torch.no_grad()
    def interpolate_pair(self, frame0, frame1, time_step=0.5):
        """Interpolate a single frame between two input frames.

        Args:
            frame0: [1, C, H, W] tensor, float32, range [0, 1]
            frame1: [1, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1)

        Returns:
            Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        img0 = frame0.to(device)
        img1 = frame1.to(device)

        padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
        img0, img1 = padder.pad(img0, img1)

        pred = self._inference(img0, img1, timestep=time_step)
        pred = padder.unpad(pred)
        return torch.clamp(pred, 0, 1)

    @torch.no_grad()
    def interpolate_batch(self, frames0, frames1, time_step=0.5):
        """Interpolate multiple frame pairs at once.

        Args:
            frames0: [B, C, H, W] tensor, float32, range [0, 1]
            frames1: [B, C, H, W] tensor, float32, range [0, 1]
            time_step: float in (0, 1)

        Returns:
            Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
        """
        device = next(self.model.parameters()).device
        img0 = frames0.to(device)
        img1 = frames1.to(device)

        padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
        img0, img1 = padder.pad(img0, img1)

        pred = self._inference(img0, img1, timestep=time_step)
        pred = padder.unpad(pred)
        return torch.clamp(pred, 0, 1)

nodes.py (318 changed lines)

@@ -8,9 +8,10 @@ import torch

import folder_paths
from comfy.utils import ProgressBar

from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel
from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache

logger = logging.getLogger("BIM-VFI")

@@ -31,6 +32,14 @@ EMA_MODEL_DIR = os.path.join(folder_paths.models_dir, "ema-vfi")

if not os.path.exists(EMA_MODEL_DIR):
    os.makedirs(EMA_MODEL_DIR, exist_ok=True)

# Google Drive folder ID for SGM-VFI pretrained models
SGM_GDRIVE_FOLDER_ID = "1S5O6W0a7XQDHgBtP9HnmoxYEzWBIzSJq"
SGM_DEFAULT_MODEL = "ours-1-2-points.pth"

SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi")
if not os.path.exists(SGM_MODEL_DIR):
    os.makedirs(SGM_MODEL_DIR, exist_ok=True)


def get_available_models():
    """List available checkpoint files in the bim-vfi model directory."""

@@ -767,3 +776,310 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):

        result = result[1:]  # skip duplicate boundary frame

        return (result, model)


# ---------------------------------------------------------------------------
# SGM-VFI nodes
# ---------------------------------------------------------------------------

def get_available_sgm_models():
    """List available checkpoint files in the sgm-vfi model directory."""
    models = []
    if os.path.isdir(SGM_MODEL_DIR):
        for f in os.listdir(SGM_MODEL_DIR):
            if f.endswith((".pkl", ".pth", ".pt", ".ckpt", ".safetensors")):
                models.append(f)
    if not models:
        models.append(SGM_DEFAULT_MODEL)  # Will trigger auto-download
    return sorted(models)


def download_sgm_model_from_gdrive(folder_id, dest_path):
    """Download SGM-VFI model from Google Drive folder using gdown."""
    try:
        import gdown
    except ImportError:
        raise RuntimeError(
            "gdown is required to auto-download the SGM-VFI model. "
            "Install it with: pip install gdown"
        )
    filename = os.path.basename(dest_path)
    url = f"https://drive.google.com/drive/folders/{folder_id}"
    logger.info(f"Downloading {filename} from Google Drive folder to {dest_path}...")
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    gdown.download_folder(url, output=os.path.dirname(dest_path), quiet=False, remaining_ok=True)
    if not os.path.exists(dest_path):
        raise RuntimeError(
            f"Failed to download {filename}. Please download manually from "
            f"https://drive.google.com/drive/folders/{folder_id} "
            f"and place it in {os.path.dirname(dest_path)}"
        )
    logger.info("Download complete.")


class LoadSGMVFIModel:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": (get_available_sgm_models(), {
                    "default": SGM_DEFAULT_MODEL,
                    "tooltip": "Checkpoint file from models/sgm-vfi/. Auto-downloads on first use if missing. "
                               "Variant (base/small) is auto-detected from filename.",
                }),
                "tta": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Test-time augmentation: flip input and average with unflipped result. "
                               "~2x slower but slightly better quality.",
                }),
                "num_key_points": ("FLOAT", {
                    "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                    "tooltip": "Sparsity of global matching. 0.0 = global matching everywhere (slower, better for large motion). "
                               "Higher = sparser keypoints (faster). Default 0.5 is a good balance.",
                }),
            }
        }

    RETURN_TYPES = ("SGM_VFI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "video/SGM-VFI"

    def load_model(self, model_path, tta, num_key_points):
        full_path = os.path.join(SGM_MODEL_DIR, model_path)

        if not os.path.exists(full_path):
            logger.info(f"Model not found at {full_path}, attempting download...")
            download_sgm_model_from_gdrive(SGM_GDRIVE_FOLDER_ID, full_path)

        wrapper = SGMVFIModel(
            checkpoint_path=full_path,
            variant="auto",
            num_key_points=num_key_points,
            tta=tta,
            device="cpu",
        )

        logger.info(f"SGM-VFI model loaded (variant={wrapper.variant_name}, num_key_points={num_key_points}, tta={tta})")
        return (wrapper,)


class SGMVFIInterpolate:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE", {
                    "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
                }),
                "model": ("SGM_VFI_MODEL", {
                    "tooltip": "SGM-VFI model from the Load SGM-VFI Model node.",
                }),
                "multiplier": ([2, 4, 8], {
                    "default": 2,
                    "tooltip": "Frame rate multiplier. 2x=one interpolation pass, 4x=two recursive passes, 8x=three. Higher = more frames but longer processing.",
                }),
                "clear_cache_after_n_frames": ("INT", {
                    "default": 10, "min": 1, "max": 100, "step": 1,
                    "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
                }),
                "keep_device": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
                }),
                "all_on_gpu": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
                }),
                "batch_size": ("INT", {
                    "default": 1, "min": 1, "max": 64, "step": 1,
                    "tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full.",
                }),
                "chunk_size": ("INT", {
                    "default": 0, "min": 0, "max": 10000, "step": 1,
                    "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "interpolate"
    CATEGORY = "video/SGM-VFI"

    def _interpolate_frames(self, frames, model, num_passes, batch_size,
                            device, storage_device, keep_device, all_on_gpu,
                            clear_cache_after_n_frames, pbar, step_ref):
        """Run all interpolation passes on a chunk of frames."""
        for pass_idx in range(num_passes):
            new_frames = []
            num_pairs = frames.shape[0] - 1
            pairs_since_clear = 0

            for i in range(0, num_pairs, batch_size):
                batch_end = min(i + batch_size, num_pairs)
                actual_batch = batch_end - i

                frames0 = frames[i:batch_end]
                frames1 = frames[i + 1:batch_end + 1]

                if not keep_device:
                    model.to(device)

                mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
                mids = mids.to(storage_device)

                if not keep_device:
                    model.to("cpu")

                for j in range(actual_batch):
                    new_frames.append(frames[i + j:i + j + 1])
                    new_frames.append(mids[j:j+1])

                step_ref[0] += actual_batch
                pbar.update_absolute(step_ref[0])

                pairs_since_clear += actual_batch
                if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
                    clear_sgm_warp_cache()
                    torch.cuda.empty_cache()
                    pairs_since_clear = 0

            new_frames.append(frames[-1:])
            frames = torch.cat(new_frames, dim=0)

            if not all_on_gpu and torch.cuda.is_available():
                clear_sgm_warp_cache()
                torch.cuda.empty_cache()

        return frames

    @staticmethod
    def _count_steps(num_frames, num_passes):
        """Count total interpolation steps for a given input frame count."""
        n = num_frames
        total = 0
        for _ in range(num_passes):
            total += n - 1
            n = 2 * n - 1
        return total

    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                    keep_device, all_on_gpu, batch_size, chunk_size):
        if images.shape[0] < 2:
            return (images,)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_passes = {2: 1, 4: 2, 8: 3}[multiplier]

        if all_on_gpu:
            keep_device = True

        storage_device = device if all_on_gpu else torch.device("cpu")

        # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
        all_frames = images.permute(0, 3, 1, 2).to(storage_device)
        total_input = all_frames.shape[0]

        # Build chunk boundaries (1-frame overlap between consecutive chunks)
        if chunk_size < 2 or chunk_size >= total_input:
            chunks = [(0, total_input)]
        else:
            chunks = []
            start = 0
            while start < total_input - 1:
                end = min(start + chunk_size, total_input)
                chunks.append((start, end))
                start = end - 1  # overlap by 1 frame
                if end == total_input:
                    break

        # Calculate total progress steps across all chunks
        total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
        pbar = ProgressBar(total_steps)
        step_ref = [0]

        if keep_device:
            model.to(device)

        result_chunks = []
        for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
            chunk_frames = all_frames[chunk_start:chunk_end].clone()

            chunk_result = self._interpolate_frames(
                chunk_frames, model, num_passes, batch_size,
                device, storage_device, keep_device, all_on_gpu,
                clear_cache_after_n_frames, pbar, step_ref,
            )

            # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
            if chunk_idx > 0:
                chunk_result = chunk_result[1:]

            # Move completed chunk to CPU to bound memory when chunking
            if len(chunks) > 1:
                chunk_result = chunk_result.cpu()

            result_chunks.append(chunk_result)

        result = torch.cat(result_chunks, dim=0)
        # Convert back to ComfyUI [B, H, W, C], on CPU
        result = result.cpu().permute(0, 2, 3, 1)
        return (result,)


class SGMVFISegmentInterpolate(SGMVFIInterpolate):
    """Process a numbered segment of the input batch for SGM-VFI.

    Chain multiple instances with Save nodes between them to bound peak RAM.

    The model pass-through output forces sequential execution so each segment
    saves and frees from RAM before the next starts.
    """

    @classmethod
    def INPUT_TYPES(cls):
        base = SGMVFIInterpolate.INPUT_TYPES()
        base["required"]["segment_index"] = ("INT", {
            "default": 0, "min": 0, "max": 10000, "step": 1,
            "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
                       "unlike chunk_size, which bounds VRAM but still assembles the full output in RAM. "
                       "Chain the model output to the next Segment Interpolate to force sequential execution.",
        })
        base["required"]["segment_size"] = ("INT", {
            "default": 500, "min": 2, "max": 10000, "step": 1,
            "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
                       "Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
        })
        return base

    RETURN_TYPES = ("IMAGE", "SGM_VFI_MODEL")
    RETURN_NAMES = ("images", "model")
    FUNCTION = "interpolate"
    CATEGORY = "video/SGM-VFI"

    def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
                    keep_device, all_on_gpu, batch_size, chunk_size,
                    segment_index, segment_size):
        total_input = images.shape[0]

        # Compute segment boundaries (1-frame overlap)
        start = segment_index * (segment_size - 1)
        end = min(start + segment_size, total_input)

        if start >= total_input - 1:
            # Past the end: return a single placeholder frame plus the model pass-through
            return (images[:1], model)

        segment_images = images[start:end]
        is_continuation = segment_index > 0

        # Delegate to the parent interpolation logic
        (result,) = super().interpolate(
            segment_images, model, multiplier, clear_cache_after_n_frames,
            keep_device, all_on_gpu, batch_size, chunk_size,
        )

        if is_continuation:
            result = result[1:]  # skip the duplicate boundary frame

        return (result, model)
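The segment boundary arithmetic above can be sketched in isolation (the helper name `segment_bounds` is illustrative, not part of the node):

```python
def segment_bounds(segment_index, segment_size, total_input):
    """Input-frame range [start, end) for one segment, overlapping the
    previous segment by exactly one frame for seamless stitching."""
    start = segment_index * (segment_size - 1)
    end = min(start + segment_size, total_input)
    return start, end

# 10 input frames in segments of 4: each segment starts on the previous
# segment's last frame; segment 3 (start=9 >= 10-1) hits the past-the-end guard.
bounds = [segment_bounds(i, 4, 10) for i in range(4)]
# -> [(0, 4), (3, 7), (6, 10), (9, 10)]
```

After dropping the duplicate first output frame of every continuation segment, the concatenated segments cover each input pair exactly once.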
sgm_vfi_arch/__init__.py (new file, +5)
@@ -0,0 +1,5 @@
from .feature_extractor import feature_extractor
from .flow_estimation import MultiScaleFlow
from .warplayer import clear_warp_cache

__all__ = ['feature_extractor', 'MultiScaleFlow', 'clear_warp_cache']
sgm_vfi_arch/backbone.py (new file, +116)
@@ -0,0 +1,116 @@
import torch.nn as nn

from .trident_conv import MultiScaleTridentConv


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, stride=stride, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               dilation=dilation, padding=dilation, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)
        if stride != 1 or in_planes != planes:
            self.norm3 = norm_layer(planes)

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class CNNEncoder(nn.Module):
    def __init__(self, output_dim=128,
                 norm_layer=nn.InstanceNorm2d,
                 num_output_scales=1,
                 **kwargs,
                 ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False)  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer)  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
                                                      kernel_size=3,
                                                      strides=strides,
                                                      paddings=1,
                                                      num_branch=self.num_branch,
                                                      )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
        layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
sgm_vfi_arch/feature_extractor.py (new file, +459)
@@ -0,0 +1,459 @@
import torch
import torch.nn as nn
import math
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from .position import PositionEmbeddingSine


def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    nwB, N, C = windows.shape
    windows = windows.view(-1, window_size[0], window_size[1], C)
    B = int(nwB / (H * W / window_size[0] / window_size[1]))
    x = windows.view(
        B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
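As a shape sanity check: `window_partition` maps `(B, H, W, C)` to `(num_windows * B, window_area, C)` and `window_reverse` inverts it. The arithmetic can be mirrored without tensors (helper name is illustrative):

```python
def window_partition_shape(B, H, W, C, window_size):
    """Output shape of window_partition. H and W must already be
    multiples of the window size (padding is handled separately)."""
    wh, ww = window_size
    assert H % wh == 0 and W % ww == 0
    num_windows = (H // wh) * (W // ww)
    return (num_windows * B, wh * ww, C)

# Two 64x64 feature maps with 8x8 windows: 64 windows each, 64 tokens per window.
shape = window_partition_shape(2, 64, 64, 96, (8, 8))
# -> (128, 64, 96)
```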
def pad_if_needed(x, size, window_size):
    n, h, w, c = size
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    if pad_h > 0 or pad_w > 0:  # center-pad the feature on H and W axes
        img_mask = torch.zeros((1, h + pad_h, w + pad_w, 1))  # 1 H W 1
        h_slices = (
            slice(0, pad_h // 2),
            slice(pad_h // 2, h + pad_h // 2),
            slice(h + pad_h // 2, None),
        )
        w_slices = (
            slice(0, pad_w // 2),
            slice(pad_w // 2, w + pad_w // 2),
            slice(w + pad_w // 2, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, window_size
        )  # nW, window_size*window_size, 1
        mask_windows = mask_windows.squeeze(-1)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(
            attn_mask != 0, float(-100.0)
        ).masked_fill(attn_mask == 0, float(0.0))
        return nn.functional.pad(
            x,
            (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
        ), attn_mask
    return x, None


def depad_if_needed(x, size, window_size):
    n, h, w, c = size
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    if pad_h > 0 or pad_w > 0:  # remove the center-padding on feature
        return x[:, pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w, :].contiguous()
    return x
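The padding split above is plain ceiling arithmetic; a standalone sketch (illustrative helper, returning the same order `F.pad` receives for the W and H axes):

```python
import math

def center_pad_amounts(h, w, window_size):
    """Slack needed to round H and W up to multiples of the window size,
    split around the feature: (w_left, w_right, h_top, h_bottom)."""
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    return (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

# A 45x70 feature with 8x8 windows is padded to 48x72: 2 columns split 1/1,
# 3 rows split 1 on top and 2 on the bottom (odd slack goes to the far side).
pads = center_pad_amounts(45, 70, (8, 8))
# -> (1, 1, 1, 2)
```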
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class InterFrameAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2, H, W, mask=None):
        B, N, C = x1.shape
        q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            nW = mask.shape[0]  # mask: nW, N, N
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = attn.softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MotionFormerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.window_size = window_size
        if not isinstance(self.window_size, (tuple, list)):
            self.window_size = to_2tuple(window_size)
        self.shift_size = shift_size
        if not isinstance(self.shift_size, (tuple, list)):
            self.shift_size = to_2tuple(shift_size)
        self.bidirectional = bidirectional
        self.norm1 = norm_layer(dim)
        self.attn = InterFrameAttention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # BEGIN: absolute pos_embed, beneficial to local information extraction in our experiments
        self.pos_embed = PositionEmbeddingSine(dim // 2)
        # END: absolute pos_embed
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, B, self_att=False):
        x = x.view(2 * B, H, W, -1)
        x_pad, mask = pad_if_needed(x, x.size(), self.window_size)

        if self.shift_size[0] or self.shift_size[1]:
            _, H_p, W_p, C = x_pad.shape
            x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

            if hasattr(self, 'HW') and self.HW.item() == H_p * W_p:
                shift_mask = self.attn_mask
            else:
                shift_mask = torch.zeros((1, H_p, W_p, 1))  # 1 H W 1
                h_slices = (slice(0, -self.window_size[0]),
                            slice(-self.window_size[0], -self.shift_size[0]),
                            slice(-self.shift_size[0], None))
                w_slices = (slice(0, -self.window_size[1]),
                            slice(-self.window_size[1], -self.shift_size[1]),
                            slice(-self.shift_size[1], None))
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        shift_mask[:, h, w, :] = cnt
                        cnt += 1

                mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1)
                shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                shift_mask = shift_mask.masked_fill(shift_mask != 0,
                                                    float(-100.0)).masked_fill(shift_mask == 0,
                                                                               float(0.0))

                if mask is not None:
                    shift_mask = shift_mask.masked_fill(mask != 0,
                                                        float(-100.0))
                self.register_buffer("attn_mask", shift_mask)
                self.register_buffer("HW", torch.Tensor([H_p * W_p]))
        else:
            shift_mask = mask

        if shift_mask is not None:
            shift_mask = shift_mask.to(x_pad.device)

        _, Hw, Ww, C = x_pad.shape
        x_win = window_partition(x_pad, self.window_size)

        nwB = x_win.shape[0]
        x_norm = self.norm1(x_win)
        # BEGIN: absolute pos embed, beneficial to local information extraction in our experiments
        x_norm = x_norm.view(nwB, self.window_size[0], self.window_size[1], C).permute(0, 3, 1, 2)
        ape = self.pos_embed(x_norm)
        x_norm = x_norm + ape
        x_norm = x_norm.permute(0, 2, 3, 1).view(nwB, self.window_size[0] * self.window_size[1], C)
        # END: absolute pos embed

        if self_att is False:
            x_reverse = torch.cat([x_norm[nwB // 2:], x_norm[:nwB // 2]])
            x_appearence = self.attn(x_norm, x_reverse, H, W, shift_mask)
        else:
            x_appearence = self.attn(x_norm, x_norm, H, W, shift_mask)

        x_norm = x_norm + self.drop_path(x_appearence)

        x_back = x_norm
        x_back_win = window_reverse(x_back, self.window_size, Hw, Ww)

        if self.shift_size[0] or self.shift_size[1]:
            x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2 * B, H * W, -1)

        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
class ConvBlock(nn.Module):
    def __init__(self, in_dim, out_dim, depths=2, act_layer=nn.PReLU):
        super().__init__()
        layers = []
        for i in range(depths):
            if i == 0:
                layers.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
            else:
                layers.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1))
            layers.extend([
                act_layer(out_dim),
            ])
        self.conv = nn.Sequential(*layers)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        return x
class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W
class MotionFormer(nn.Module):
    def __init__(self, in_chans=3, embed_dims=None, num_heads=None,
                 mlp_ratios=None, qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=None, window_sizes=None, **kwarg):
        super().__init__()
        self.depths = depths
        self.num_stages = len(embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        self.conv_stages = self.num_stages - len(num_heads)

        for i in range(self.num_stages):
            if i == 0:
                block = ConvBlock(in_chans, embed_dims[i], depths[i])
            else:
                if i < self.conv_stages:
                    patch_embed = nn.Sequential(
                        nn.Conv2d(embed_dims[i - 1], embed_dims[i], 3, 2, 1),
                        nn.PReLU(embed_dims[i])
                    )
                    block = ConvBlock(embed_dims[i], embed_dims[i], depths[i])
                else:
                    patch_embed = OverlapPatchEmbed(patch_size=3,
                                                    stride=2,
                                                    in_chans=embed_dims[i - 1],
                                                    embed_dim=embed_dims[i])

                    block = nn.ModuleList([MotionFormerBlock(
                        dim=embed_dims[i], num_heads=num_heads[i - self.conv_stages],
                        window_size=window_sizes[i - self.conv_stages],
                        shift_size=0 if (j % 2) == 0 else window_sizes[i - self.conv_stages] // 2,
                        mlp_ratio=mlp_ratios[i - self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
                        for j in range(depths[i])])

                    norm = norm_layer(embed_dims[i])
                    setattr(self, f"norm{i + 1}", norm)
                setattr(self, f"patch_embed{i + 1}", patch_embed)
            cur += depths[i]

            setattr(self, f"block{i + 1}", block)

        self.cor = {}

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def get_cor(self, shape, device):
        k = (str(shape), str(device))
        if k not in self.cor:
            tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
                1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
            tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
                1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
            self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
        return self.cor[k]

    def forward(self, x1, x2):
        B = x1.shape[0]
        x = torch.cat([x1, x2], 0)
        appearence_features = []
        xs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}", None)
            block = getattr(self, f"block{i + 1}", None)
            norm = getattr(self, f"norm{i + 1}", None)
            if i < self.conv_stages:
                if i > 0:
                    x = patch_embed(x)
                x = block(x)
                xs.append(x)
            else:
                x, H, W = patch_embed(x)
                for j in range(len(block)):
                    x = block[j](x, H, W, B, self_att=False)
                xs.append(x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous())
                x = norm(x)
                x = x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                appearence_features.append(x)
        return appearence_features
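The `dpr` line in `__init__` implements the stochastic depth decay rule: `torch.linspace(0, drop_path_rate, sum(depths))` gives each transformer block a linearly increasing DropPath probability. A torch-free equivalent (function name illustrative):

```python
def drop_path_rates(drop_path_rate, depths):
    """Per-block DropPath rates, linearly increasing from 0 to the max;
    the same values torch.linspace(0, rate, sum(depths)) produces."""
    total = sum(depths)
    if total == 1:
        return [0.0]
    return [drop_path_rate * i / (total - 1) for i in range(total)]

# Stages of depths [2, 2, 2, 6] with max rate 0.1: early blocks are almost
# never dropped, the final block is dropped with probability 0.1.
rates = drop_path_rates(0.1, [2, 2, 2, 6])
```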
class DWConv(nn.Module):
    def __init__(self, dim):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
        x = self.dwconv(x)
        x = x.reshape(B, C, -1).transpose(1, 2)

        return x


def feature_extractor(**kargs):
    model = MotionFormer(**kargs)
    return model
sgm_vfi_arch/flow_estimation.py (new file, +208)
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .refine import *
from .matching import MatchingBlock
from .gmflow import GMFlow
from .utils import InputPadder


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64, layers=4, scale=4, in_else=17):
        super(IFBlock, self).__init__()
        self.scale = scale

        self.conv0 = nn.Sequential(
            conv(in_planes + in_else, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )

        self.convblock = nn.Sequential(
            *[conv(c, c) for _ in range(layers)]
        )

        self.lastconv = conv(c, 5)

    def forward(self, x, flow=None, feature=None):
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
                                 align_corners=False) * 1. / self.scale
            x = torch.cat((x, flow), 1)
        if feature is not None:
            x = torch.cat((x, feature), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        tmp = self.lastconv(x)
        flow_s = tmp[:, :4]
        tmp = F.interpolate(tmp, scale_factor=self.scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * self.scale
        mask = tmp[:, 4:5]
        return flow, mask, flow_s
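IFBlock estimates flow at 1/`scale` resolution and upsamples it; `tmp[:, :4] * self.scale` then rescales the displacement vectors, since a flow field resized by some factor must have its pixel offsets multiplied by the same factor. In isolation (helper name illustrative):

```python
def rescale_flow(flow_px, scale):
    """A displacement measured in pixels at 1/scale resolution corresponds
    to scale-times as many pixels at full resolution."""
    return flow_px * scale

# A 2-pixel displacement estimated at 1/4 resolution is an 8-pixel
# displacement at full resolution; downsampling divides likewise.
full = rescale_flow(2.0, 4)        # -> 8.0
coarse = rescale_flow(8.0, 1 / 4)  # -> 2.0
```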
class MultiScaleFlow(nn.Module):
    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        self.scale = [1, 2, 4, 8]
        self.num_key_points = [kargs['num_key_points']]
        self.block = nn.ModuleList(
            [IFBlock(kargs['embed_dims'][-1] * 2, 128, 2, self.scale[-1], in_else=7),    # 1/8
             IFBlock(kargs['embed_dims'][-2] * 2, 128, 2, self.scale[-2], in_else=18)])  # 1/4
        self.contextnet = Contextnet(kargs['c'] * 2)
        self.unet = Unet(kargs['c'] * 2)
        self.gmflow = GMFlow(
            num_scales=1,
            upsample_factor=8,
            feature_channels=128,
            attention_type='swin',
            num_transformer_layers=6,
            ffn_dim_expansion=4,
            num_head=1)

        self.matching_block = nn.ModuleList([
            MatchingBlock(scale=8, dim=kargs['embed_dims'][-1], c=kargs['c'] * 4, num_layers=1, gm=True),
            None
        ])

        self.padding_factor = 16

    def calculate_flow(self, imgs, timestep):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        flow_s = None

        af = self.feature_bone(img0, img1)
        if self.gmflow is not None:
            padder = InputPadder(img0.shape, padding_factor=self.padding_factor)
            img0_p, img1_p = padder.pad(img0, img1)
            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
            matching_feat = results['trans_feat']
            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1])
            matching_feat[0] = padder_8.unpad(matching_feat[0])

        for i in range(2):
            t = (img0[:B, :1].clone() * 0 + 1) * timestep
            af0 = af[-1 - i][:B]
            af1 = af[-1 - i][B:]
            if flow is not None:
                flow_d, mask_d, flow_s_d = self.block[i](
                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
                    flow,
                    torch.cat([af0, af1], 1),
                )
                flow = flow + flow_d
                mask = mask + mask_d
                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
                flow_s = flow_s + flow_s_d
            else:
                flow, mask, flow_s = self.block[i](
                    torch.cat((img0, img1, t), 1),
                    None,
                    torch.cat([af0, af1], 1))
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            if self.matching_block[i] is not None:
                dict = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i],
                                              init_flow=flow, init_flow_s=flow_s, init_mask=mask,
                                              warped_img0=warped_img0, warped_img1=warped_img1,
                                              num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
                                              timestep=timestep)
                flow_t, mask_t = dict['flow_t'], dict['mask_t']
                flow = flow + flow_t
                mask = mask + mask_t

                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
        return flow, mask
    def coraseWarp_and_Refine(self, imgs, flow, mask):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred
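The final blend in `coraseWarp_and_Refine` is per-pixel: a sigmoid occlusion mask weights the two warped frames and the U-Net residual nudges the result before clamping. A scalar sketch (helper name illustrative):

```python
def refine_pixel(warped0, warped1, mask_sigmoid, residual):
    """Occlusion-weighted blend of the two warped frames plus a learned
    residual, clamped to the valid [0, 1] image range."""
    merged = warped0 * mask_sigmoid + warped1 * (1 - mask_sigmoid)
    return max(0.0, min(1.0, merged + residual))

# A mask near 1 trusts the warp from frame 0; near 0 trusts frame 1.
px = refine_pixel(0.8, 0.2, 0.75, -0.05)  # ~0.60
```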
    def forward(self, x, timestep=0.5):
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list, mask_list = [], []
        merged, merged_fine = [], []
        warped_img0, warped_img1 = img0, img1
        flow, mask, flow_s = None, None, None
        flow_matching_list = []
        matching_feat = []
        af = self.feature_bone(img0, img1)
        if self.gmflow is not None:
            padder = InputPadder(img0.shape, padding_factor=self.padding_factor, additional_pad=False)
            img0_p, img1_p = padder.pad(img0, img1)
            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
            matching_feat = results['trans_feat']
            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1], additional_pad=False)
            matching_feat[0] = padder_8.unpad(matching_feat[0])

        for i in range(2):
            af0 = af[-1 - i][:B]
            af1 = af[-1 - i][B:]
            t = (img0[:B, :1].clone() * 0 + 1) * timestep
            if flow is not None:
                flow_d, mask_d, flow_s_d = self.block[i](
                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
                    flow,
                    torch.cat([af0, af1], 1),
                )
                flow = flow + flow_d
                mask = mask + mask_d
                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
                flow_s = flow_s + flow_s_d
            else:
                flow, mask, flow_s = self.block[i](
                    torch.cat((img0, img1, t), 1),
                    None,
                    torch.cat([af0, af1], 1))
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
            if self.matching_block[i] is not None:
                dict = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1-i].detach(),
                                              init_flow=flow.detach(), init_flow_s=flow_s.detach(), init_mask=mask.detach(),
                                              warped_img0=warped_img0.detach(), warped_img1=warped_img1.detach(),
                                              num_key_points=self.num_key_points[i], scale_factor=self.scale[-1-i],
                                              timestep=0.5)
                flow_t, mask_t = dict['flow_t'], dict['mask_t']
                flow = flow + flow_t
                mask = mask + mask_t
                mask_list[i] = torch.sigmoid(mask)
                warped_img0_fine = warp(img0, flow[:, 0:2])
                warped_img1_fine = warp(img1, flow[:, 2:4])
                merged_fine.append(warped_img0_fine * mask_list[i] + warped_img1_fine * (1 - mask_list[i]))
                warped_img0, warped_img1 = warped_img0_fine, warped_img1_fine  # NOTE: for next iteration training

        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
|
||||||
|
pred = torch.clamp(merged[-1] + res, 0, 1)
|
||||||
|
merged.extend(merged_fine)
|
||||||
|
return flow_list, mask_list, merged, pred, flow_matching_list
|
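The forward pass above repeatedly upsamples the small-scale flow with `F.interpolate(..., scale_factor=2) * 2`: when a flow field is resized, its displacement values must be scaled by the same factor, because a 1-pixel motion at the coarse resolution spans 2 pixels at the finer one. A minimal numpy sketch of that invariant (nearest-neighbor resize for simplicity; the model uses bilinear, and `upsample_flow_nearest` is an illustrative name, not part of SGM-VFI):

```python
import numpy as np

def upsample_flow_nearest(flow, factor=2):
    """Nearest-neighbor upsample of a flow field [2, H, W], scaling the
    flow *values* by the same factor: a 1-pixel displacement at H x W
    corresponds to `factor` pixels at factor*H x factor*W."""
    up = flow.repeat(factor, axis=1).repeat(factor, axis=2)
    return up * factor

flow = np.ones((2, 4, 4))           # 1-pixel displacement everywhere
up = upsample_flow_nearest(flow, 2)
assert up.shape == (2, 8, 8)
assert np.all(up == 2.0)            # same physical motion, finer pixel grid
```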
sgm_vfi_arch/geometry.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F


def coords_grid(b, h, w, homogeneous=False, device=None):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    assert device is not None

    x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
                           torch.linspace(h_min, h_max, len_h, device=device)],
                          )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2]
    c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
    return (coords - c) / c  # [-1, 1]


def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
    # img: [B, C, H, W]
    # sample_coords: [B, 2, H, W] in image scale
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]

        return img, mask

    return img


def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, grid, padding_mode=padding_mode,
                           return_mask=mask)


def forward_backward_consistency_check(fwd_flow, bwd_flow,
                                       alpha=0.01,
                                       beta=0.5
                                       ):
    # fwd_flow, bwd_flow: [B, 2, H, W]
    # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ
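The consistency check at the end of geometry.py flags a pixel as occluded when the forward flow and the (warped) backward flow fail to cancel out, with the threshold from UnFlow. A minimal numpy sketch of just that criterion, skipping the warping step by passing the backward flow already resampled into the first frame (`consistency_occlusion` is an illustrative name, not the vendored function):

```python
import numpy as np

def consistency_occlusion(fwd, warped_bwd, alpha=0.01, beta=0.5):
    """Per-pixel occlusion mask (UnFlow criterion): occluded where
    |fwd + warped_bwd| exceeds alpha * (|fwd| + |warped_bwd|) + beta."""
    flow_mag = np.linalg.norm(fwd, axis=0) + np.linalg.norm(warped_bwd, axis=0)
    diff = np.linalg.norm(fwd + warped_bwd, axis=0)
    return (diff > alpha * flow_mag + beta).astype(np.float32)

fwd = np.full((2, 4, 4), 3.0)        # uniform motion (+3, +3) on a 4x4 grid
occ = consistency_occlusion(fwd, -fwd)            # perfectly inverse backward flow
assert occ.sum() == 0                             # consistent -> nothing occluded

occ_bad = consistency_occlusion(fwd, np.zeros_like(fwd))  # backward flow missing
assert occ_bad.sum() == 16                        # every pixel flagged occluded
```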
sgm_vfi_arch/gmflow.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import CNNEncoder
from .transformer import FeatureTransformer, FeatureFlowAttention
from .utils import feature_add_position


class GMFlow(nn.Module):
    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

    def extract_feature(self, img0, img1):
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def forward(self, img0, img1,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                pred_bidir_flow=False,
                **kwargs,
                ):

        results_dict = {}
        flow_preds = []
        flow_s_matching = []
        flow_s_prop = []
        transformer_features = []

        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            attn_splits = attn_splits_list[scale_idx]

            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
            transformer_features.append(torch.cat([feature0, feature1], 0))

        results_dict.update({'trans_feat': transformer_features})

        return results_dict
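`extract_feature` runs the shared-weight backbone once on both frames by stacking them along the batch axis, then splits the output back in two with `torch.chunk`. A minimal numpy sketch of the same trick, with a stand-in callable instead of `CNNEncoder` (`extract_pair_features` is an illustrative name, not part of the vendored code):

```python
import numpy as np

def extract_pair_features(img0, img1, backbone):
    """Run a shared-weight backbone once on both frames by concatenating
    them along the batch axis, then split the result back per frame."""
    concat = np.concatenate([img0, img1], axis=0)  # [2B, C, H, W]
    feat = backbone(concat)                        # one forward pass for both frames
    b = img0.shape[0]
    return feat[:b], feat[b:]                      # equivalent of torch.chunk(feat, 2, 0)

backbone = lambda x: x * 2.0                       # stand-in for a real CNN encoder
f0, f1 = extract_pair_features(np.zeros((1, 3, 8, 8)), np.ones((1, 3, 8, 8)), backbone)
assert np.all(f0 == 0.0)                           # frame-0 features recovered
assert np.all(f1 == 2.0)                           # frame-1 features recovered
```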
sgm_vfi_arch/matching.py (new file, 278 lines)
@@ -0,0 +1,278 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp as backwarp
from .softsplat import softsplat
from .geometry import coords_grid


# for random sample ablation
def random_sample(feature, num_points=256):
    rand_ind = torch.randint(low=0, high=feature.shape[1], size=(feature.shape[0], num_points)).unsqueeze(-1).to(
        feature.device)
    kp = torch.gather(feature, dim=1, index=rand_ind.expand(-1, -1, feature.shape[2]))
    return rand_ind, kp


def sample_key_points(importance_map, feature, num_points=256):
    importance_map = importance_map.view(-1, 1, importance_map.shape[2] * importance_map.shape[3]).permute(0, 2, 1)
    _, kp_ind = torch.topk(importance_map, num_points, dim=1)
    kp = torch.gather(feature, dim=1, index=kp_ind.expand(-1, -1, feature.shape[2]))
    return kp_ind, kp


def forward_warp(tenIn, tenFlow, z=None):
    if z is None:
        z = torch.ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]).to(tenIn.device)
    else:
        z = torch.where(z == 0, -20, 1)
    out = softsplat(tenIn, tenFlow, tenMetric=z, strMode='soft')
    return out


def warp_twice(imgA, target, flow_tA, flow_tB):
    It_warp = backwarp(imgA, flow_tA)  # backward warp(I1, Ft1)
    z = torch.ones([imgA.shape[0], 1, imgA.shape[2], imgA.shape[3]]).to(imgA.device)
    IB_warp = softsplat(tenIn=It_warp, tenFlow=flow_tB, tenMetric=z, strMode='soft')
    return IB_warp


def build_map(imgA, imgB, flow_tA, flow_tB):
    # build map for img B
    IB_warp = warp_twice(imgA, imgB, flow_tA, flow_tB)
    difference_map = IB_warp - imgB  # [B, 3, H, W], difference map on IB
    difference_map = torch.sum(torch.abs(difference_map), dim=1, keepdim=True)  # B, 1, H, W
    return difference_map


def build_hole_mask(img_template, flow_tA, flow_tB):
    # build hole mask
    with torch.no_grad():
        ones = torch.ones(img_template.shape[0], 1, img_template.shape[2], img_template.shape[3]).to(
            img_template.device)
        out = warp_twice(ones, ones, flow_tA, flow_tB)
        hole_mask = torch.where(out == 0, 0, 1)
    return hole_mask


def gen_importance_map(img0, img1, flow):
    I1_dmap = build_map(img0, img1, flow[:, 0:2], flow[:, 2:4])
    I0_dmap = build_map(img1, img0, flow[:, 2:4], flow[:, 0:2])

    I1_hole_mask = build_hole_mask(img0, flow[:, 0:2], flow[:, 2:4])
    I0_hole_mask = build_hole_mask(img1, flow[:, 2:4], flow[:, 0:2])

    I1_dmap = I1_dmap * I1_hole_mask
    I0_dmap = I0_dmap * I0_hole_mask

    I0_prob = warp_twice(I1_dmap, I1_dmap, flow[:, 2:4], flow[:, 0:2])
    I1_prob = warp_twice(I0_dmap, I0_dmap, flow[:, 0:2], flow[:, 2:4])

    importance_map = torch.cat([I0_prob, I1_prob], dim=0)  # 2B, 1, H, W
    return importance_map


def global_matching(key_feature, global_feature, key_index, H, W):
    b, n, c = global_feature.shape
    query = key_feature
    key = global_feature
    correlation = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, k, H*W]

    prob = F.softmax(correlation, dim=-1)
    init_grid = coords_grid(b, H, W, homogeneous=False, device=global_feature.device)
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]
    out = torch.matmul(prob, grid)  # B, k, 2
    if key_index is not None:
        flow_fix = torch.zeros_like(grid)
        # key_index: [B, K, 1], out: [B, K, 2], flow_fix: [B, H*W, 2]
        flow_fix = torch.scatter(flow_fix, dim=1, index=key_index.expand(-1, -1, 2), src=out)
        flow_fix = flow_fix.view(b, H, W, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

        # for grid, points in grid and not in key_index, set to 0
        grid_new = torch.zeros_like(grid)
        key_pos = torch.ones_like(out)
        grid_new = torch.scatter(grid_new, dim=1, index=key_index.expand(-1, -1, 2), src=key_pos)
        grid = (grid * grid_new).reshape(b, H, W, 2).permute(0, 3, 1, 2)
        flow_fix = flow_fix - grid
    else:
        flow_fix = out.view(b, H, W, 2).permute(0, 3, 1, 2)
        flow_fix = flow_fix - init_grid
    return flow_fix, prob


def extract_topk(foo, k):
    b, _, h, w = foo.shape
    foo = foo.view(b, 1, h * w).permute(0, 2, 1)
    kp, kp_ind = torch.topk(foo, k, dim=1)
    grid = torch.zeros(b, h * w, 1).to(foo.device)
    out = torch.scatter(grid, dim=1, index=kp_ind, src=kp)
    out = out.permute(0, 2, 1).reshape(b, 1, h, w)
    return out


def flow_shift(flow_fix, timestep, num_key_points=None, select_topk=False):
    B = flow_fix.shape[0] // 2
    z = torch.where(flow_fix == 0, 0, 1).detach().sum(1, keepdim=True) / 2
    zt0, zt1 = z[B:], z[:B]
    flow_fix_t0 = forward_warp(flow_fix[B:] * timestep, flow_fix[B:] * (1 - timestep), z=zt0)
    flow_fix_t1 = forward_warp(flow_fix[:B] * (1 - timestep), flow_fix[:B] * timestep, z=zt1)
    flow_fix_t = torch.cat([flow_fix_t0, flow_fix_t1], 0)
    if select_topk and num_key_points != -1:
        warp_map_t0 = softsplat(zt0, flow_fix[B:] * (1 - timestep), None, 'sum')
        warp_map_t1 = softsplat(zt1, flow_fix[:B] * timestep, None, 'sum')

        warp_map = torch.cat([warp_map_t0, warp_map_t1], 0)
        warp_map_topk = extract_topk(warp_map, num_key_points)
        warp_map_topk = torch.where(warp_map_topk != 0, 1, 0)
        flow_fix_t = flow_fix_t * warp_map_topk
    return flow_fix_t


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def deconv(in_planes=64, out_planes=64, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                           padding=padding, bias=True),
        nn.PReLU(out_planes)
    )


class FlowRefine(nn.Module):
    def __init__(self, in_planes, scale=4, c=64, n_layers=8):
        super(FlowRefine, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )
        self.convblock = nn.Sequential(
            *[conv(c, c) for _ in range(n_layers)]
        )
        self.lastconv = conv(c, 5)
        self.scale = scale

    def forward(self, x, flow_s, flow):
        if self.scale != 1:
            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
                                 align_corners=False) * 1. / self.scale
            x = torch.cat((x, flow), 1)
        if flow_s is not None:
            x = torch.cat((x, flow_s), 1)
        x = self.conv0(x)
        x = self.convblock(x) + x
        x = self.lastconv(x)
        tmp = F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False)
        flow = tmp[:, :4] * self.scale
        mask = tmp[:, 4:5]
        return flow, mask


class MergingBlock(nn.Module):
    def __init__(self, radius=3, input_dim=256, hidden_dim=256):
        super(MergingBlock, self).__init__()
        self.r = radius
        self.rf = radius ** 2
        self.conv = nn.Sequential(nn.Conv2d(8 + 2 * input_dim, hidden_dim, 3, 1, 1),
                                  nn.PReLU(hidden_dim),
                                  nn.Conv2d(hidden_dim, 2 * 2 * self.rf, 1, 1, 0))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(0.1 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, feature, init_flow, flow_fix):
        """
        :param feature: [B, C, H, W] -> (local feature) or (local feature + matching feature)
        :param init_flow: [B, 2, H, W] -> (local init flow)
        :param flow_fix: [B, 2, H, W] -> (matching output, flow_fix (after patching, no hollows))
        """
        b, flow_channel, h, w = init_flow.shape
        concat = torch.cat((init_flow, flow_fix, feature), dim=1)
        mask = self.conv(concat)
        assert init_flow.shape == flow_fix.shape, "different flow shape not implemented yet"
        mask = mask.view(b, 1, 2 * 2 * self.rf, h, w)
        mask0 = mask[:, :, :2 * self.rf, :, :]
        mask1 = mask[:, :, 2 * self.rf:, :, :]
        mask = torch.cat([mask0, mask1], dim=0)
        mask = torch.softmax(mask, dim=2)

        init_flow_all = torch.cat([init_flow[:, 0:2], init_flow[:, 2:4]], dim=0)
        flow_fix_all = torch.cat([flow_fix[:, 0:2], flow_fix[:, 2:4]], dim=0)

        init_flow_grid = F.unfold(init_flow_all, [self.r, self.r], padding=self.r // 2)
        init_flow_grid = init_flow_grid.view(2 * b, 2, self.rf, h, w)  # [B, 2, 9, H, W]
        flow_fix_grid = F.unfold(flow_fix_all, [self.r, self.r], padding=self.r // 2)
        flow_fix_grid = flow_fix_grid.view(2 * b, 2, self.rf, h, w)  # [B, 2, 9, H, W]

        flow_grid = torch.cat([init_flow_grid, flow_fix_grid], dim=2)  # [B, 2, 2*9, H, W]

        merge_flow = torch.sum(mask * flow_grid, dim=2)  # [B, 2, H, W]
        return merge_flow


class MatchingBlock(nn.Module):
    def __init__(self, scale, c, dim, num_layers=2, gm=True):
        super(MatchingBlock, self).__init__()
        self.gm = gm
        self.dim = dim
        self.scale = scale
        self.merge = MergingBlock(radius=3, input_dim=dim + 128, hidden_dim=256)
        self.refine_block = FlowRefine(27, scale, c, num_layers)

    def forward(self, img0, img1, x, main_x, init_flow, init_flow_s, init_mask,
                warped_img0, warped_img1, num_key_points, scale_factor, timestep=0.5):
        result_dict = {}

        _, c, h, w = x.shape
        B = main_x.shape[0] // 2
        # NOTE:
        # 1. we stop sparsely selecting points when the image resolution becomes
        #    too small (1/8 feature map resolution <= 32, i.e., h <= 256)
        #    (see `random_rescale` in train_x4k.py)
        # 2. this limitation should be removed when evaluating on low-resolution images (<= 256x256)
        if num_key_points != -1 and h > 32:
            num_key_points = int(num_key_points * (h * w))
        else:
            num_key_points = -1  # -1 stands for global matching

        feature = x.permute(0, 2, 3, 1).reshape(2 * B, h * w, c)
        feature_reverse = torch.cat([feature[B:], feature[:B]], 0)

        if num_key_points == -1:
            flow_fix_norm, _ = global_matching(feature, feature_reverse, None, h, w)
        else:
            imap = gen_importance_map(img0, img1, init_flow)
            imap_s = F.interpolate(imap, size=(h, w), mode="bilinear", align_corners=False)
            kp_ind, kp_feature = sample_key_points(imap_s, feature, num_key_points)
            flow_fix_norm, _ = global_matching(kp_feature, feature_reverse, kp_ind, h, w)

        flow_fix = flow_shift(flow_fix_norm, timestep, num_key_points, select_topk=True)
        flow_fix = torch.cat([flow_fix[:B], flow_fix[B:]], 1)
        flow_r = torch.where(flow_fix == 0, init_flow_s, flow_fix)
        flow_merge = self.merge(torch.cat([x[:B], x[B:], main_x[:B], main_x[B:]], dim=1), init_flow_s, flow_r)
        flow_merge = torch.cat([flow_merge[:B], flow_merge[B:]], dim=1)
        img0_s = F.interpolate(img0, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
        img1_s = F.interpolate(img1, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
        warped_img0_fine_s_m = backwarp(img0_s, flow_merge[:, 0:2])
        warped_img1_fine_s_m = backwarp(img1_s, flow_merge[:, 2:4])

        flow_t, mask_t = self.refine_block(torch.cat((img0, img1, warped_img0, warped_img1, init_mask), 1),
                                           torch.cat([warped_img0_fine_s_m, warped_img1_fine_s_m, flow_merge], 1),
                                           init_flow)

        result_dict.update({'flow_t': flow_t})
        result_dict.update({'mask_t': mask_t})
        return result_dict
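`global_matching` above is a differentiable soft argmax: correlate each query feature against every candidate, softmax the scores, and take the probability-weighted average of candidate coordinates. A minimal numpy sketch of that core on a 1-D token row, leaving out the sparse key-point bookkeeping (`global_match_1d` and `softmax` here are illustrative helpers, not part of the vendored code):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def global_match_1d(query, keys, positions):
    """Soft argmax matching: correlation scores -> softmax ->
    probability-weighted average of candidate positions."""
    corr = keys @ query              # [N] correlation of query vs. every key
    prob = softmax(corr)             # attention over candidates
    return float(prob @ positions)   # expected matched position

keys = np.eye(5)                     # 5 orthogonal unit features
positions = np.arange(5, dtype=np.float64)
query = keys[3] * 50.0               # strongly correlates with token 3
matched = global_match_1d(query, keys, positions)
assert abs(matched - 3.0) < 1e-6     # soft argmax lands on the true match
```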
sgm_vfi_arch/position.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py

import torch
import torch.nn as nn
import math


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x = tensor_list.tensors  # [B, C, H, W]
        # mask = tensor_list.mask  # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
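The embedding above assigns each coordinate a vector of sines and cosines at geometrically increasing wavelengths: even channels get sin, odd channels cos, with the denominator `temperature ** (2 * (i // 2) / num_pos_feats)`. A minimal 1-D numpy sketch of that construction (`sine_embed_1d` is an illustrative name, not the vendored class):

```python
import numpy as np

def sine_embed_1d(pos, num_feats=8, temperature=10000.0):
    """Sinusoidal embedding of scalar positions: even channels sin,
    odd channels cos, wavelengths growing geometrically with channel index."""
    dim_t = np.arange(num_feats, dtype=np.float64)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    ang = pos[:, None] / dim_t           # [N, num_feats] phase angles
    out = np.empty_like(ang)
    out[:, 0::2] = np.sin(ang[:, 0::2])
    out[:, 1::2] = np.cos(ang[:, 1::2])
    return out

emb = sine_embed_1d(np.array([0.0, 1.0]), num_feats=8)
assert emb.shape == (2, 8)
assert np.allclose(emb[0, 0::2], 0.0)    # sin(0) = 0 on even channels
assert np.allclose(emb[0, 1::2], 1.0)    # cos(0) = 1 on odd channels
```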
sgm_vfi_arch/refine.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from timm.models.layers import trunc_normal_
from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size,
                           stride=stride, padding=padding, bias=True),
        nn.PReLU(out_planes)
    )


class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class Contextnet(nn.Module):
    def __init__(self, c=16):
        super(Contextnet, self).__init__()
        self.conv1 = Conv2(3, c)
        self.conv2 = Conv2(c, 2 * c)
        self.conv3 = Conv2(2 * c, 4 * c)
        self.conv4 = Conv2(4 * c, 8 * c)

    def forward(self, x, flow):
        x = self.conv1(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
                             recompute_scale_factor=False) * 0.5
        f1 = warp(x, flow)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
                             recompute_scale_factor=False) * 0.5
        f2 = warp(x, flow)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
                             recompute_scale_factor=False) * 0.5
        f3 = warp(x, flow)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
                             recompute_scale_factor=False) * 0.5
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]


class Unet(nn.Module):
    def __init__(self, c=16, out=3):
        super(Unet, self).__init__()
        self.down0 = Conv2(17, 2 * c)
        self.down1 = Conv2(4 * c, 4 * c)
        self.down2 = Conv2(8 * c, 8 * c)
        self.down3 = Conv2(16 * c, 16 * c)
        self.up0 = deconv(32 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)
        self.conv = nn.Conv2d(c, out, 3, 1, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return torch.sigmoid(x)
sgm_vfi_arch/softsplat.py (new file, 530 lines)
@@ -0,0 +1,530 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import cupy
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################
|
||||||
|
|
||||||
|
|
||||||
|
objCudacache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_int32(intIn:int):
|
||||||
|
return cupy.int32(intIn)
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_float32(fltIn:float):
|
||||||
|
return cupy.float32(fltIn)
|
||||||
|
# end
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
|
||||||
|
if 'device' not in objCudacache:
|
||||||
|
objCudacache['device'] = torch.cuda.get_device_name()
|
||||||
|
# end
|
||||||
|
|
||||||
|
strKey = strFunction
|
||||||
|
|
||||||
|
for strVariable in objVariables:
|
||||||
|
objValue = objVariables[strVariable]
|
||||||
|
|
||||||
|
strKey += strVariable
|
||||||
|
|
||||||
|
if objValue is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif type(objValue) == int:
|
||||||
|
strKey += str(objValue)
|
||||||
|
|
||||||
|
elif type(objValue) == float:
|
||||||
|
strKey += str(objValue)
|
||||||
|
|
||||||
|
elif type(objValue) == bool:
|
||||||
|
strKey += str(objValue)
|
||||||
|
|
||||||
|
elif type(objValue) == str:
|
||||||
|
strKey += objValue
|
||||||
|
|
||||||
|
elif type(objValue) == torch.Tensor:
|
||||||
|
strKey += str(objValue.dtype)
|
||||||
|
strKey += str(objValue.shape)
|
||||||
|
strKey += str(objValue.stride())
|
||||||
|
|
||||||
|
elif True:
|
||||||
|
print(strVariable, type(objValue))
|
||||||
|
assert(False)
|
||||||
|
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
|
||||||
|
strKey += objCudacache['device']
|
||||||
|
|
||||||
|
    if strKey not in objCudacache:
        for strVariable in objVariables:
            objValue = objVariables[strVariable]

            if objValue is None:
                continue

            elif type(objValue) == int:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == float:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == bool:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace('{{type}}', 'unsigned char')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace('{{type}}', 'half')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace('{{type}}', 'float')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace('{{type}}', 'double')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace('{{type}}', 'int')

            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace('{{type}}', 'long')

            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)
                assert(False)

            elif True:
                print(strVariable, type(objValue))
                assert(False)

            # end
        # end

        while True:
            objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
        # end

        while True:
            objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
        # end

        while True:
            objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)

            if objMatch is None:
                break
            # end

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == '(' else 0
                intParentheses -= 1 if strKernel[intStop] == ')' else 0

                if intParentheses == 0:
                    break
                # end

                intStop += 1
            # end

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(',')

            assert(intArgs == len(strArgs) - 1)

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = []

            for intArg in range(intArgs):
                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
            # end

            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
        # end

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end


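To make the macro expansion above concrete, here is a deliberately simplified, self-contained sketch of how an `OFFSET_N(...)` occurrence is rewritten into a strided flat-index expression. It handles only a single flat occurrence (no nested parentheses or `{}` translation, unlike the real `cuda_kernel`) and the strides are hypothetical example values:

```python
import re

# Simplified version of the OFFSET_ expansion in cuda_kernel(): replace
# OFFSET_2(ten, intY, intX) with a flat-index expression built from strides.
def expand_offset(kernel: str, strides) -> str:
    match = re.search(r'OFFSET_([0-4])\(', kernel)
    inner = kernel[match.end():kernel.index(')', match.end())]
    args = inner.split(',')
    idx = args[1:]  # args[0] is the tensor name
    terms = ['((' + a.strip() + ')*' + str(s) + ')' for a, s in zip(idx, strides)]
    return kernel.replace('OFFSET_' + match.group(1) + '(' + inner + ')',
                          '(' + '+'.join(terms) + ')')

src = 'out[OFFSET_2(ten, intY, intX)] = 0;'
print(expand_offset(src, (4, 1)))  # → out[(((intY)*4)+((intX)*1)] = 0; with strides (4, 1)
```

The real implementation additionally walks matching parentheses so nested calls inside the index arguments expand correctly.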
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey: str):
    if 'CUDA_HOME' not in os.environ:
        os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
    # end

    return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],
                          options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
# end


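The launch calls later in this file size the grid with `int((nelement + 512 - 1) / 512)`, i.e. ceiling division by the 512-thread block so every element is covered. A quick check of that arithmetic:

```python
def grid_size(n, block=512):
    # Ceiling division, matching int((nelement + 512 - 1) / 512) in the launch calls.
    return (n + block - 1) // block

assert grid_size(1) == 1    # one partial block
assert grid_size(512) == 1  # exactly one full block
assert grid_size(513) == 2  # one full block plus one element
```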
##########################################################


def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str):
    assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])

    if strMode == 'sum': assert(tenMetric is None)
    if strMode == 'avg': assert(tenMetric is None)
    if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
    if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)

    if strMode == 'avg':
        tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)

    elif strMode.split('-')[0] == 'linear':
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)

    elif strMode.split('-')[0] == 'soft':
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    # end

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
        tenNormalize = tenOut[:, -1:, :, :]

        if len(strMode.split('-')) == 1:
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split('-')[1] == 'addeps':
            tenNormalize = tenNormalize + 0.0000001

        elif strMode.split('-')[1] == 'zeroeps':
            tenNormalize[tenNormalize == 0.0] = 1.0

        elif strMode.split('-')[1] == 'clipeps':
            tenNormalize = tenNormalize.clip(0.0000001, None)

        # end

        tenOut = tenOut[:, :-1, :, :] / tenNormalize
    # end

    return tenOut
# end


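The forward-splatting kernel in `softsplat_func` scatters each source pixel to the four integer neighbours of its flow-displaced position using bilinear weights. A pure-Python sketch of that weighting for a single pixel (illustrative only; the real computation runs on the GPU with `atomicAdd`):

```python
import math

def splat_one(canvas, x, y, value):
    # Scatter `value` at continuous position (x, y) onto `canvas` (list of rows)
    # using the same four bilinear weights as the softsplat_out CUDA kernel.
    nw_x, nw_y = math.floor(x), math.floor(y)
    for px, py, w in [
        (nw_x,     nw_y,     (nw_x + 1 - x) * (nw_y + 1 - y)),  # northwest
        (nw_x + 1, nw_y,     (x - nw_x)     * (nw_y + 1 - y)),  # northeast
        (nw_x,     nw_y + 1, (nw_x + 1 - x) * (y - nw_y)),      # southwest
        (nw_x + 1, nw_y + 1, (x - nw_x)     * (y - nw_y)),      # southeast
    ]:
        if 0 <= py < len(canvas) and 0 <= px < len(canvas[0]):
            canvas[py][px] += value * w  # atomicAdd in the kernel

canvas = [[0.0] * 4 for _ in range(4)]
splat_one(canvas, 1.25, 2.5, 1.0)
assert abs(sum(sum(row) for row in canvas) - 1.0) < 1e-9  # in-bounds weights sum to 1
```

Because many source pixels may land on the same target pixel, the modes in `softsplat()` above append a metric channel so the accumulated output can be normalized afterwards.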
class softsplat_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])

        if tenIn.is_cuda == True:
            cuda_launch(cuda_kernel('softsplat_out', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_out(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    {{type}}* __restrict__ tenOut
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                    const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
                    const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                    const int intX = ( intIndex ) % SIZE_3(tenOut);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
                    }
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOut': tenOut
            }))(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )

        elif tenIn.is_cuda != True:
            assert(False)

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut
    # end

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        tenIn, tenFlow = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

        tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
        tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None

        if tenIngrad is not None:
            cuda_launch(cuda_kernel('softsplat_ingrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
                    const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
                    const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
                    const int intX = ( intIndex ) % SIZE_3(tenIngrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltIngrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
                    }

                    tenIngrad[intIndex] = fltIngrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        if tenFlowgrad is not None:
            cuda_launch(cuda_kernel('softsplat_flowgrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
                    const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
                    const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
                    const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltFlowgrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = 0.0f;
                    {{type}} fltNortheast = 0.0f;
                    {{type}} fltSouthwest = 0.0f;
                    {{type}} fltSoutheast = 0.0f;

                    if (intC == 0) {
                        fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                        fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                        fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                        fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));

                    } else if (intC == 1) {
                        fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                        fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                        fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                        fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));

                    }

                    for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                        {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);

                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                        }

                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                        }

                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                        }

                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                        }
                    }

                    tenFlowgrad[intIndex] = fltFlowgrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        return tenIngrad, tenFlowgrad
    # end
# end
450	sgm_vfi_arch/transformer.py	Normal file
@@ -0,0 +1,450 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
                               ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0, f'h: {h}, w: {w}, num_splits: {num_splits}'

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
                               ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]

    return feature


def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
        merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
            new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
        merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
            new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]

    return merge


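`split_feature` and `merge_splits` form an exact reshape-permute round trip: splitting an image into K×K windows and merging them back recovers the original tensor. A NumPy analogue of the channel-first path (illustrative only, not part of the node pack) makes that easy to check:

```python
import numpy as np

def split_np(x, k):
    # [B, C, H, W] -> [B*k*k, C, H/k, W/k], same permute order as split_feature
    b, c, h, w = x.shape
    x = x.reshape(b, c, k, h // k, k, w // k)
    return x.transpose(0, 2, 4, 1, 3, 5).reshape(b * k * k, c, h // k, w // k)

def merge_np(x, k):
    # inverse of split_np, same permute order as merge_splits
    bkk, c, h, w = x.shape
    b = bkk // (k * k)
    x = x.reshape(b, k, k, c, h, w)
    return x.transpose(0, 3, 1, 4, 2, 5).reshape(b, c, k * h, k * w)

x = np.arange(2 * 3 * 8 * 8).reshape(2, 3, 8, 8).astype(np.float32)
assert np.array_equal(merge_np(split_np(x, 2), 2), x)  # lossless round trip
```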
def single_head_full_attention(q, k, v):
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5)  # [B, L, L]
    attn = torch.softmax(scores, dim=2)  # [B, L, L]
    out = torch.matmul(attn, v)  # [B, L, C]

    return out


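The function above is standard scaled dot-product attention. A NumPy sketch of the same math (illustrative only; the softmax here is written out with the usual max-subtraction for numerical stability):

```python
import numpy as np

def full_attention_np(q, k, v):
    # q, k, v: [B, L, C]; same computation as single_head_full_attention
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[2])  # [B, L, L]
    scores = scores - scores.max(axis=2, keepdims=True)      # stable softmax
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=2, keepdims=True)
    return attn @ v                                          # [B, L, C]

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 4, 8))
k = rng.standard_normal((1, 4, 8))
v = rng.standard_normal((1, 4, 8))
out = full_attention_np(q, k, v)
assert out.shape == (1, 4, 8)
```

Full attention over L = H·W tokens costs O(L²), which is why the split-window variant below is used at higher resolutions.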
def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
                                    shift_size_h, shift_size_w, device=None):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # calculate attention mask for SW-MSA
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
    h_slices = (slice(0, -window_size_h),
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
    w_slices = (slice(0, -window_size_w),
                slice(-window_size_w, -shift_size_w),
                slice(-shift_size_w, None))
    cnt = 0
    for h_slice in h_slices:  # loop variables renamed to avoid shadowing h, w above
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = cnt
            cnt += 1

    mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask


def single_head_split_window_attention(q, k, v,
                                       num_splits=1,
                                       with_shift=False,
                                       h=None,
                                       w=None,
                                       attn_mask=None,
                                       ):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


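The cyclic shift used for shifted-window attention is just a roll by half a window, and it is exactly undone by the opposite roll after merging, so no pixels are lost at window borders. A NumPy sketch of that round trip (illustrative; window and shift sizes are arbitrary example values):

```python
import numpy as np

h = w = 8
window = 4
shift = window // 2  # half-window shift, as in the function above

x = np.arange(h * w).reshape(h, w)
shifted = np.roll(x, shift=(-shift, -shift), axis=(0, 1))   # shift before attention
restored = np.roll(shifted, shift=(shift, shift), axis=(0, 1))  # shift back after

assert np.array_equal(restored, x)  # the round trip is lossless
```

The precomputed `attn_mask` exists because, after the shift, some windows contain pixels that wrapped around from the opposite edge; the mask's -100 entries suppress attention between those wrapped regions.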
class TransformerLayer(nn.Module):
    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(query, key, value,
                                                             num_splits=attn_num_splits,
                                                             with_shift=self.with_shift,
                                                             h=height,
                                                             w=width,
                                                             attn_mask=shifted_window_attn_mask,
                                                             )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super(TransformerBlock, self).__init__()

        self.self_attn = TransformerLayer(d_model=d_model,
                                          nhead=nhead,
                                          attention_type=attention_type,
                                          no_ffn=True,
                                          ffn_dim_expansion=ffn_dim_expansion,
                                          with_shift=with_shift,
                                          )

        self.cross_attn_ffn = TransformerLayer(d_model=d_model,
                                               nhead=nhead,
                                               attention_type=attention_type,
                                               ffn_dim_expansion=ffn_dim_expansion,
                                               with_shift=with_shift,
                                               )

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(source, source,
                                height=height,
                                width=width,
                                shifted_window_attn_mask=shifted_window_attn_mask,
                                attn_num_splits=attn_num_splits,
                                )

        # cross attention and ffn
        source = self.cross_attn_ffn(source, target,
                                     height=height,
                                     width=width,
                                     shifted_window_attn_mask=shifted_window_attn_mask,
                                     attn_num_splits=attn_num_splits,
                                     )

        return source


class FeatureTransformer(nn.Module):
    def __init__(self,
                 num_layers=6,
                 d_model=128,
                 nhead=1,
                 attention_type='swin',
                 ffn_dim_expansion=4,
                 **kwargs,
                 ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model,
                             nhead=nhead,
                             attention_type=attention_type,
                             ffn_dim_expansion=ffn_dim_expansion,
                             with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
                             )
            for i in range(num_layers)])

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feature0, feature1,
                attn_num_splits=None,
                **kwargs,
                ):

        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == 'swin' and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(concat0, concat1,
                            height=h,
                            width=w,
                            shifted_window_attn_mask=shifted_window_attn_mask,
                            attn_num_splits=attn_num_splits,
                            )

        # update feature1
|
||||||
|
concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
|
||||||
|
|
||||||
|
feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
|
||||||
|
|
||||||
|
# reshape back
|
||||||
|
feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
|
||||||
|
feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
|
||||||
|
|
||||||
|
return feature0, feature1
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFlowAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
flow propagation with self-attention on feature
|
||||||
|
query: feature0, key: feature0, value: flow
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super(FeatureFlowAttention, self).__init__()
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(in_channels, in_channels)
|
||||||
|
self.k_proj = nn.Linear(in_channels, in_channels)
|
||||||
|
|
||||||
|
for p in self.parameters():
|
||||||
|
if p.dim() > 1:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
|
def forward(self, feature0, flow,
|
||||||
|
local_window_attn=False,
|
||||||
|
local_window_radius=1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
|
||||||
|
if local_window_attn:
|
||||||
|
return self.forward_local_window_attn(feature0, flow,
|
||||||
|
local_window_radius=local_window_radius)
|
||||||
|
|
||||||
|
b, c, h, w = feature0.size()
|
||||||
|
|
||||||
|
query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
|
||||||
|
|
||||||
|
query = self.q_proj(query) # [B, H*W, C]
|
||||||
|
key = self.k_proj(query) # [B, H*W, C]
|
||||||
|
|
||||||
|
value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
|
||||||
|
|
||||||
|
scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
|
||||||
|
prob = torch.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
out = torch.matmul(prob, value) # [B, H*W, 2]
|
||||||
|
out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward_local_window_attn(self, feature0, flow,
|
||||||
|
local_window_radius=1,
|
||||||
|
):
|
||||||
|
assert flow.size(1) == 2
|
||||||
|
assert local_window_radius > 0
|
||||||
|
|
||||||
|
b, c, h, w = feature0.size()
|
||||||
|
|
||||||
|
feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
|
||||||
|
).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
|
||||||
|
|
||||||
|
kernel_size = 2 * local_window_radius + 1
|
||||||
|
|
||||||
|
feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
|
||||||
|
|
||||||
|
feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
|
||||||
|
padding=local_window_radius) # [B, C*(2R+1)^2), H*W]
|
||||||
|
|
||||||
|
feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
|
||||||
|
0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
|
||||||
|
|
||||||
|
flow_window = F.unfold(flow, kernel_size=kernel_size,
|
||||||
|
padding=local_window_radius) # [B, 2*(2R+1)^2), H*W]
|
||||||
|
|
||||||
|
flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
|
||||||
|
0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2]
|
||||||
|
|
||||||
|
scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
|
||||||
|
|
||||||
|
prob = torch.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
|
||||||
|
|
||||||
|
return out
|
||||||
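FeatureTransformer.forward avoids running two separate attention passes by stacking (feature0, feature1) and (feature1, feature0) along the batch dimension, then swapping the two halves of concat0 after every layer so each direction's cross-attention target reflects the latest update. A minimal pure-Python sketch of that chunk-and-swap bookkeeping (lists stand in for tensors, and the identity layer is a stub — not the real TransformerBlock):

```python
def chunk2(xs):
    """Split a list into two halves (stand-in for tensor.chunk(2, dim=0))."""
    mid = len(xs) // 2
    return xs[:mid], xs[mid:]


def bidirectional_pass(feature0, feature1, layers):
    # Stack both directions along the "batch": sources first, then targets.
    concat0 = feature0 + feature1   # sources: [f0..., f1...]
    concat1 = feature1 + feature0   # targets: [f1..., f0...]
    for layer in layers:
        concat0 = layer(concat0, concat1)
        # Swap halves so each direction attends to the updated other half.
        a, b = chunk2(concat0)
        concat1 = b + a
    return chunk2(concat0)


# With identity "layers", the two directions round-trip unchanged.
f0, f1 = bidirectional_pass([1, 2], [3, 4], [lambda s, t: s] * 2)
```

The same trick halves the number of transformer invocations at the cost of doubling the effective batch size.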
90	sgm_vfi_arch/trident_conv.py	Normal file
@@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair


class MultiScaleTridentConv(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            strides=1,
            paddings=0,
            dilations=1,
            dilation=1,
            groups=1,
            num_branch=1,
            test_branch_idx=-1,
            bias=False,
            norm=None,
            activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            outputs = [
                F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
                    self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
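MultiScaleTridentConv shares a single weight tensor across all branches while letting each branch use its own stride and padding; scalar arguments are broadcast into per-branch lists inside `__init__`. A small pure-Python sketch of that normalization step (the helper name `broadcast_branch_args` and the example values are illustrative, not part of the vendored code):

```python
def broadcast_branch_args(num_branch, strides=1, paddings=0):
    """Expand scalar stride/padding into one (h, w) pair per branch."""
    if isinstance(strides, int):
        strides = [strides] * num_branch
    if isinstance(paddings, int):
        paddings = [paddings] * num_branch
    # _pair-style expansion: an int becomes a symmetric (h, w) tuple.
    pair = lambda v: (v, v) if isinstance(v, int) else tuple(v)
    return [pair(s) for s in strides], [pair(p) for p in paddings]


# One shared stride, but a distinct padding (hence dilation rate) per branch.
strides, paddings = broadcast_branch_args(3, strides=1, paddings=[0, 1, 2])
```

After this step the `assert len({...}) == 1` in the constructor just checks that every per-branch list really has `num_branch` entries.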
98	sgm_vfi_arch/utils.py	Normal file
@@ -0,0 +1,98 @@
import torch
import torch.nn.functional as F
from .position import PositionEmbeddingSine
from .geometry import coords_grid, generate_window_grid, normalize_coords


def split_feature(feature,
                  num_splits=2,
                  channel_last=False,
                  ):
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
                               ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % num_splits == 0 and w % num_splits == 0

        b_new = b * num_splits * num_splits
        h_new = h // num_splits
        w_new = w // num_splits

        feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
                               ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]

    return feature


def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
        merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
            new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
        merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
            new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]

    return merge


def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits > 1:  # add position within each split window
        feature0_splits = split_feature(feature0, num_splits=attn_splits)
        feature1_splits = split_feature(feature1, num_splits=attn_splits)

        position = pos_enc(feature0_splits)

        feature0_splits = feature0_splits + position
        feature1_splits = feature1_splits + position

        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
    else:
        position = pos_enc(feature0)

        feature0 = feature0 + position
        feature1 = feature1 + position

    return feature0, feature1


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel', padding_factor=8, additional_pad=False):
        self.ht, self.wd = dims[-2:]
        add_pad = padding_factor * 2 if additional_pad else 0
        pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor + add_pad
        pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor + add_pad
        if mode == 'sintel':
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]
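InputPadder computes how much replicate-padding is needed to bring each spatial dimension up to a multiple of `padding_factor`, splitting it front/back in 'sintel' mode. The arithmetic alone, as a plain-Python sketch (Sintel's 436x1024 frame size is used purely as a worked example):

```python
def sintel_pad(ht, wd, padding_factor=8):
    """Return [left, right, top, bottom] padding as InputPadder computes it."""
    pad_ht = (((ht // padding_factor) + 1) * padding_factor - ht) % padding_factor
    pad_wd = (((wd // padding_factor) + 1) * padding_factor - wd) % padding_factor
    # 'sintel' mode centers the padding on both axes.
    return [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]


# 436 needs 4 extra rows to reach 440; 1024 is already divisible by 8.
print(sintel_pad(436, 1024))  # → [0, 0, 2, 2]
```

The trailing `% padding_factor` is what makes already-divisible dimensions pad by zero rather than by a full extra `padding_factor`.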
25	sgm_vfi_arch/warplayer.py	Normal file
@@ -0,0 +1,25 @@
import torch

backwarp_tenGrid = {}


def clear_warp_cache():
    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
    backwarp_tenGrid.clear()


def warp(tenInput, tenFlow):
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(tenFlow.device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
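warp converts pixel-space flow into the [-1, 1] coordinates that grid_sample expects with align_corners=True: a pixel at column x samples from normalized position grid_x + flow_x / ((W - 1) / 2). A scalar pure-Python sketch of that per-pixel mapping (the function names here are illustrative stand-ins for the tensor math above):

```python
def normalized_sample_x(x, flow_x, width):
    """Where grid_sample reads for pixel column x, in [-1, 1] coordinates."""
    grid_x = -1.0 + 2.0 * x / (width - 1)         # torch.linspace(-1, 1, W)[x]
    return grid_x + flow_x / ((width - 1) / 2.0)  # flow rescaled to grid units


def to_pixel(coord, width):
    """Map a normalized coordinate back to a pixel column (align_corners=True)."""
    return (coord + 1.0) / 2.0 * (width - 1)


# A flow of +1 px at column 2 of a 5-pixel-wide row samples from column 3.
src = to_pixel(normalized_sample_x(2, 1.0, 5), 5)  # → 3.0
```

The same normalization is applied per-axis in `warp`, which is why the flow's x and y channels are divided by (W-1)/2 and (H-1)/2 respectively.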