Add SGM-VFI (CVPR 2024) frame interpolation support

SGM-VFI combines local flow estimation with sparse global matching
(GMFlow) to handle large motion and occlusion-heavy scenes. Adds 3 new
nodes: Load SGM-VFI Model, SGM-VFI Interpolate, SGM-VFI Segment
Interpolate. Architecture files vendored from MCG-NJU/SGM-VFI with
device-awareness fixes (no hardcoded .cuda()), relative imports, and
debug code removed. README updated with model comparison table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 23:02:48 +01:00
parent 1de086569c
commit 42ebdd8b96
18 changed files with 3132 additions and 7 deletions


@@ -1,6 +1,21 @@
# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI
ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), and [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM.
## Which model should I use?
| | BIM-VFI | EMA-VFI | SGM-VFI |
|---|---------|---------|---------|
| **Best for** | General-purpose, non-uniform motion | Fast inference, light VRAM | Large motion, occlusion-heavy scenes |
| **Quality** | Highest overall | Good | Best on large motion |
| **Speed** | Moderate | Fastest | Slowest |
| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair |
| **Params** | ~17M | ~14.5M (small) / ~65.7M (base) | ~15M + GMFlow |
| **Arbitrary timestep** | Yes | Yes (with `_t` checkpoint) | No (fixed 0.5) |
| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 |
| **License** | Research only | Apache 2.0 | Apache 2.0 |
**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with.
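The per-pair VRAM figures above can serve as a rough starting point for the `batch_size` input. A minimal sketch (`suggest_batch_size` is a hypothetical helper for illustration, not one of the nodes, and the per-pair numbers are approximate and scale with resolution):

```python
# Approximate per-pair VRAM from the comparison table above.
PER_PAIR_GB = {"bim-vfi": 2.0, "ema-vfi": 1.5, "sgm-vfi": 3.0}

def suggest_batch_size(model: str, free_vram_gb: float) -> int:
    """How many frame pairs roughly fit in the given free VRAM (at least 1)."""
    return max(1, int(free_vram_gb // PER_PAIR_GB[model]))

print(suggest_batch_size("sgm-vfi", 12.0))  # ~4 pairs fit in 12 GB
```

Treat the result as a starting value and adjust while watching actual VRAM usage, as the nodes' tooltips suggest.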
## Nodes
@@ -66,7 +81,32 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
Same as EMA-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
### SGM-VFI
#### Load SGM-VFI Model
Loads an SGM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/sgm-vfi/`. Variant (base/small) is auto-detected from the filename (default is small).
| Input | Description |
|-------|-------------|
| **model_path** | Checkpoint file from `models/sgm-vfi/` |
| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) |
| **num_key_points** | Sparsity of global matching (0.0 = global everywhere, 0.5 = default balance, higher = faster) |
Available checkpoints:
| Checkpoint | Variant | Params |
|-----------|---------|--------|
| `ours-1-2-points.pth` | Small | ~15M + GMFlow |
#### SGM-VFI Interpolate
Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate.
#### SGM-VFI Segment Interpolate
Same as SGM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate.
**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7
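These counts follow from the recursive midpoint scheme: each 2x pass inserts one frame per adjacent pair, turning N frames into 2N-1. A minimal sketch of the arithmetic:

```python
# Each pass maps n frames -> 2n - 1 (one midpoint per adjacent pair).
# 2x = one pass, 4x = two recursive passes, 8x = three.
def output_frame_count(n: int, multiplier: int) -> int:
    passes = {2: 1, 4: 2, 8: 3}[multiplier]
    for _ in range(passes):
        n = 2 * n - 1
    return n

print(output_frame_count(100, 8))  # 8*100 - 7 = 793
```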
## Installation
@@ -94,8 +134,8 @@ python install.py
### Requirements
- PyTorch with CUDA
- `cupy` (matching your CUDA version, for BIM-VFI and SGM-VFI)
- `timm` (for EMA-VFI and SGM-VFI)
- `gdown` (for model auto-download)
## VRAM Guide
@@ -109,7 +149,7 @@ python install.py
## Acknowledgments
This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab) and the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) and [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementations by MCG-NJU. Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, and `sgm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths).
**BiM-VFI:**
> Wonyong Seo, Jihyong Oh, and Munchurl Kim.
@@ -141,8 +181,25 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF
}
```
**SGM-VFI:**
> Chunxu Liu, Guozhen Zhang, Rui Zhao, and Limin Wang.
> "Sparse Global Matching for Video Frame Interpolation with Large Motion."
> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2024.
> [[arXiv]](https://arxiv.org/abs/2404.06913) [[GitHub]](https://github.com/MCG-NJU/SGM-VFI)
```bibtex
@inproceedings{zhang2024sgmvfi,
title={Sparse Global Matching for Video Frame Interpolation with Large Motion},
author={Liu, Chunxu and Zhang, Guozhen and Zhao, Rui and Wang, Limin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2024}
}
```
## License
The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details.
The EMA-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/EMA-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/EMA-VFI) for details.
The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details.


@@ -40,6 +40,7 @@ _auto_install_deps()
from .nodes import (
LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, BIMVFIConcatVideos,
LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate,
LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate,
)
NODE_CLASS_MAPPINGS = {
@@ -50,6 +51,9 @@ NODE_CLASS_MAPPINGS = {
"LoadEMAVFIModel": LoadEMAVFIModel,
"EMAVFIInterpolate": EMAVFIInterpolate,
"EMAVFISegmentInterpolate": EMAVFISegmentInterpolate,
"LoadSGMVFIModel": LoadSGMVFIModel,
"SGMVFIInterpolate": SGMVFIInterpolate,
"SGMVFISegmentInterpolate": SGMVFISegmentInterpolate,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -60,4 +64,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadEMAVFIModel": "Load EMA-VFI Model",
"EMAVFIInterpolate": "EMA-VFI Interpolate",
"EMAVFISegmentInterpolate": "EMA-VFI Segment Interpolate",
"LoadSGMVFIModel": "Load SGM-VFI Model",
"SGMVFIInterpolate": "SGM-VFI Interpolate",
"SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate",
}


@@ -7,6 +7,8 @@ import torch.nn as nn
from .bim_vfi_arch import BiMVFI
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
from .ema_vfi_arch import MultiScaleFlow as EMAMultiScaleFlow
from .sgm_vfi_arch import feature_extractor as sgm_feature_extractor
from .sgm_vfi_arch import MultiScaleFlow as SGMMultiScaleFlow
from .utils.padder import InputPadder
logger = logging.getLogger("BIM-VFI")
@@ -282,3 +284,160 @@ class EMAVFIModel:
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)
# ---------------------------------------------------------------------------
# SGM-VFI model wrapper
# ---------------------------------------------------------------------------
def _sgm_init_model_config(F=16, W=7, depth=[2, 2, 2, 4], num_key_points=0.5):
"""Build SGM-VFI model config dicts (backbone + multiscale)."""
return {
'embed_dims': [F, 2*F, 4*F, 8*F],
'num_heads': [8*F//32],
'mlp_ratios': [4],
'qkv_bias': True,
'norm_layer': partial(nn.LayerNorm, eps=1e-6),
'depths': depth,
'window_sizes': [W]
}, {
'embed_dims': [F, 2*F, 4*F, 8*F],
'motion_dims': [0, 0, 0, 8*F//depth[-1]],
'depths': depth,
'scales': [8],
'hidden_dims': [4*F],
'c': F,
'num_key_points': num_key_points,
}
def _sgm_detect_variant(filename):
"""Auto-detect SGM-VFI model variant from filename.
Returns (F, depth).
Default is small (F=16) since the primary checkpoint (ours-1-2-points)
is a small model. Only detect base when "base" is in the filename.
"""
name = filename.lower()
is_base = "base" in name
if is_base:
return 32, [2, 2, 2, 6]
else:
return 16, [2, 2, 2, 4]
class SGMVFIModel:
"""Clean inference wrapper around SGM-VFI for ComfyUI integration."""
def __init__(self, checkpoint_path, variant="auto", num_key_points=0.5, tta=False, device="cpu"):
import os
filename = os.path.basename(checkpoint_path)
if variant == "auto":
F_dim, depth = _sgm_detect_variant(filename)
elif variant == "small":
F_dim, depth = 16, [2, 2, 2, 4]
else: # base
F_dim, depth = 32, [2, 2, 2, 6]
self.tta = tta
self.device = device
self.variant_name = "small" if F_dim == 16 else "base"
backbone_cfg, multiscale_cfg = _sgm_init_model_config(
F=F_dim, depth=depth, num_key_points=num_key_points)
backbone = sgm_feature_extractor(**backbone_cfg)
self.model = SGMMultiScaleFlow(backbone, **multiscale_cfg)
self._load_checkpoint(checkpoint_path)
self.model.eval()
self.model.to(device)
def _load_checkpoint(self, checkpoint_path):
"""Load checkpoint with module prefix stripping and buffer filtering."""
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
# Handle wrapped checkpoint formats
if isinstance(state_dict, dict):
if "model" in state_dict:
state_dict = state_dict["model"]
elif "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
# Strip "module." prefix and filter out attn_mask/HW buffers
cleaned = {}
for k, v in state_dict.items():
if "attn_mask" in k or k.endswith(".HW"):
continue
key = k
if key.startswith("module."):
key = key[len("module."):]
cleaned[key] = v
self.model.load_state_dict(cleaned, strict=False)
def to(self, device):
"""Move model to device (returns self for chaining)."""
self.device = device
self.model.to(device)
return self
@torch.no_grad()
def _inference(self, img0, img1, timestep=0.5):
"""Run single inference pass. Inputs already padded, on device."""
B = img0.shape[0]
imgs = torch.cat((img0, img1), 1)
if self.tta:
imgs_ = imgs.flip(2).flip(3)
input_batch = torch.cat((imgs, imgs_), 0)
_, _, _, preds, _ = self.model(input_batch, timestep=timestep)
return (preds[:B] + preds[B:].flip(2).flip(3)) / 2.
else:
_, _, _, pred, _ = self.model(imgs, timestep=timestep)
return pred
@torch.no_grad()
def interpolate_pair(self, frame0, frame1, time_step=0.5):
"""Interpolate a single frame between two input frames.
Args:
frame0: [1, C, H, W] tensor, float32, range [0, 1]
frame1: [1, C, H, W] tensor, float32, range [0, 1]
time_step: float in (0, 1)
Returns:
Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
"""
device = next(self.model.parameters()).device
img0 = frame0.to(device)
img1 = frame1.to(device)
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
img0, img1 = padder.pad(img0, img1)
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)
@torch.no_grad()
def interpolate_batch(self, frames0, frames1, time_step=0.5):
"""Interpolate multiple frame pairs at once.
Args:
frames0: [B, C, H, W] tensor, float32, range [0, 1]
frames1: [B, C, H, W] tensor, float32, range [0, 1]
time_step: float in (0, 1)
Returns:
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
"""
device = next(self.model.parameters()).device
img0 = frames0.to(device)
img1 = frames1.to(device)
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
img0, img1 = padder.pad(img0, img1)
pred = self._inference(img0, img1, timestep=time_step)
pred = padder.unpad(pred)
return torch.clamp(pred, 0, 1)

nodes.py

@@ -8,9 +8,10 @@ import torch
import folder_paths
from comfy.utils import ProgressBar
from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel
from .bim_vfi_arch import clear_backwarp_cache
from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache
from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache
logger = logging.getLogger("BIM-VFI")
@@ -31,6 +32,14 @@ EMA_MODEL_DIR = os.path.join(folder_paths.models_dir, "ema-vfi")
if not os.path.exists(EMA_MODEL_DIR):
os.makedirs(EMA_MODEL_DIR, exist_ok=True)
# Google Drive folder ID for SGM-VFI pretrained models
SGM_GDRIVE_FOLDER_ID = "1S5O6W0a7XQDHgBtP9HnmoxYEzWBIzSJq"
SGM_DEFAULT_MODEL = "ours-1-2-points.pth"
SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi")
if not os.path.exists(SGM_MODEL_DIR):
os.makedirs(SGM_MODEL_DIR, exist_ok=True)
def get_available_models():
"""List available checkpoint files in the bim-vfi model directory."""
@@ -767,3 +776,310 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate):
result = result[1:] # skip duplicate boundary frame
return (result, model)
# ---------------------------------------------------------------------------
# SGM-VFI nodes
# ---------------------------------------------------------------------------
def get_available_sgm_models():
"""List available checkpoint files in the sgm-vfi model directory."""
models = []
if os.path.isdir(SGM_MODEL_DIR):
for f in os.listdir(SGM_MODEL_DIR):
if f.endswith((".pkl", ".pth", ".pt", ".ckpt", ".safetensors")):
models.append(f)
if not models:
models.append(SGM_DEFAULT_MODEL) # Will trigger auto-download
return sorted(models)
def download_sgm_model_from_gdrive(folder_id, dest_path):
"""Download SGM-VFI model from Google Drive folder using gdown."""
try:
import gdown
except ImportError:
raise RuntimeError(
"gdown is required to auto-download the SGM-VFI model. "
"Install it with: pip install gdown"
)
filename = os.path.basename(dest_path)
url = f"https://drive.google.com/drive/folders/{folder_id}"
logger.info(f"Downloading {filename} from Google Drive folder to {dest_path}...")
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
gdown.download_folder(url, output=os.path.dirname(dest_path), quiet=False, remaining_ok=True)
if not os.path.exists(dest_path):
raise RuntimeError(
f"Failed to download {filename}. Please download manually from "
f"https://drive.google.com/drive/folders/{folder_id} "
f"and place it in {os.path.dirname(dest_path)}"
)
logger.info("Download complete.")
class LoadSGMVFIModel:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_path": (get_available_sgm_models(), {
"default": SGM_DEFAULT_MODEL,
"tooltip": "Checkpoint file from models/sgm-vfi/. Auto-downloads on first use if missing. "
"Variant (base/small) is auto-detected from filename.",
}),
"tta": ("BOOLEAN", {
"default": False,
"tooltip": "Test-time augmentation: flip input and average with unflipped result. "
"~2x slower but slightly better quality.",
}),
"num_key_points": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Sparsity of global matching. 0.0 = global matching everywhere (slower, better for large motion). "
"Higher = sparser keypoints (faster). Default 0.5 is a good balance.",
}),
}
}
RETURN_TYPES = ("SGM_VFI_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = "video/SGM-VFI"
def load_model(self, model_path, tta, num_key_points):
full_path = os.path.join(SGM_MODEL_DIR, model_path)
if not os.path.exists(full_path):
logger.info(f"Model not found at {full_path}, attempting download...")
download_sgm_model_from_gdrive(SGM_GDRIVE_FOLDER_ID, full_path)
wrapper = SGMVFIModel(
checkpoint_path=full_path,
variant="auto",
num_key_points=num_key_points,
tta=tta,
device="cpu",
)
logger.info(f"SGM-VFI model loaded (variant={wrapper.variant_name}, num_key_points={num_key_points}, tta={tta})")
return (wrapper,)
class SGMVFIInterpolate:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {
"tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).",
}),
"model": ("SGM_VFI_MODEL", {
"tooltip": "SGM-VFI model from the Load SGM-VFI Model node.",
}),
"multiplier": ([2, 4, 8], {
"default": 2,
"tooltip": "Frame rate multiplier. 2x=one interpolation pass, 4x=two recursive passes, 8x=three. Higher = more frames but longer processing.",
}),
"clear_cache_after_n_frames": ("INT", {
"default": 10, "min": 1, "max": 100, "step": 1,
"tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.",
}),
"keep_device": ("BOOLEAN", {
"default": True,
"tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).",
}),
"all_on_gpu": ("BOOLEAN", {
"default": False,
"tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. Recommended for 48GB+ cards.",
}),
"batch_size": ("INT", {
"default": 1, "min": 1, "max": 64, "step": 1,
"tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full.",
}),
"chunk_size": ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.",
}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "interpolate"
CATEGORY = "video/SGM-VFI"
def _interpolate_frames(self, frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref):
"""Run all interpolation passes on a chunk of frames."""
for pass_idx in range(num_passes):
new_frames = []
num_pairs = frames.shape[0] - 1
pairs_since_clear = 0
for i in range(0, num_pairs, batch_size):
batch_end = min(i + batch_size, num_pairs)
actual_batch = batch_end - i
frames0 = frames[i:batch_end]
frames1 = frames[i + 1:batch_end + 1]
if not keep_device:
model.to(device)
mids = model.interpolate_batch(frames0, frames1, time_step=0.5)
mids = mids.to(storage_device)
if not keep_device:
model.to("cpu")
for j in range(actual_batch):
new_frames.append(frames[i + j:i + j + 1])
new_frames.append(mids[j:j+1])
step_ref[0] += actual_batch
pbar.update_absolute(step_ref[0])
pairs_since_clear += actual_batch
if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available():
clear_sgm_warp_cache()
torch.cuda.empty_cache()
pairs_since_clear = 0
new_frames.append(frames[-1:])
frames = torch.cat(new_frames, dim=0)
if not all_on_gpu and torch.cuda.is_available():
clear_sgm_warp_cache()
torch.cuda.empty_cache()
return frames
@staticmethod
def _count_steps(num_frames, num_passes):
"""Count total interpolation steps for a given input frame count."""
n = num_frames
total = 0
for _ in range(num_passes):
total += n - 1
n = 2 * n - 1
return total
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size):
if images.shape[0] < 2:
return (images,)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_passes = {2: 1, 4: 2, 8: 3}[multiplier]
if all_on_gpu:
keep_device = True
storage_device = device if all_on_gpu else torch.device("cpu")
# Convert from ComfyUI [B, H, W, C] to model [B, C, H, W]
all_frames = images.permute(0, 3, 1, 2).to(storage_device)
total_input = all_frames.shape[0]
# Build chunk boundaries (1-frame overlap between consecutive chunks)
if chunk_size < 2 or chunk_size >= total_input:
chunks = [(0, total_input)]
else:
chunks = []
start = 0
while start < total_input - 1:
end = min(start + chunk_size, total_input)
chunks.append((start, end))
start = end - 1 # overlap by 1 frame
if end == total_input:
break
# Calculate total progress steps across all chunks
total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks)
pbar = ProgressBar(total_steps)
step_ref = [0]
if keep_device:
model.to(device)
result_chunks = []
for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks):
chunk_frames = all_frames[chunk_start:chunk_end].clone()
chunk_result = self._interpolate_frames(
chunk_frames, model, num_passes, batch_size,
device, storage_device, keep_device, all_on_gpu,
clear_cache_after_n_frames, pbar, step_ref,
)
# Skip first frame of subsequent chunks (duplicate of previous chunk's last frame)
if chunk_idx > 0:
chunk_result = chunk_result[1:]
# Move completed chunk to CPU to bound memory when chunking
if len(chunks) > 1:
chunk_result = chunk_result.cpu()
result_chunks.append(chunk_result)
result = torch.cat(result_chunks, dim=0)
# Convert back to ComfyUI [B, H, W, C], on CPU
result = result.cpu().permute(0, 2, 3, 1)
return (result,)
class SGMVFISegmentInterpolate(SGMVFIInterpolate):
"""Process a numbered segment of the input batch for SGM-VFI.
Chain multiple instances with Save nodes between them to bound peak RAM.
The model pass-through output forces sequential execution so each segment
saves and frees from RAM before the next starts.
"""
@classmethod
def INPUT_TYPES(cls):
base = SGMVFIInterpolate.INPUT_TYPES()
base["required"]["segment_index"] = ("INT", {
"default": 0, "min": 0, "max": 10000, "step": 1,
"tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, "
"unlike chunk_size which bounds VRAM but still assembles the full output in RAM. "
"Chain the model output to the next Segment Interpolate to force sequential execution.",
})
base["required"]["segment_size"] = ("INT", {
"default": 500, "min": 2, "max": 10000, "step": 1,
"tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. "
"Smaller = less peak RAM per segment. Save each segment's output to disk before the next runs.",
})
return base
RETURN_TYPES = ("IMAGE", "SGM_VFI_MODEL")
RETURN_NAMES = ("images", "model")
FUNCTION = "interpolate"
CATEGORY = "video/SGM-VFI"
def interpolate(self, images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
segment_index, segment_size):
total_input = images.shape[0]
# Compute segment boundaries (1-frame overlap)
start = segment_index * (segment_size - 1)
end = min(start + segment_size, total_input)
if start >= total_input - 1:
# Past the end: nothing to process, so return a single frame + model pass-through
return (images[:1], model)
segment_images = images[start:end]
is_continuation = segment_index > 0
# Delegate to the parent interpolation logic
(result,) = super().interpolate(
segment_images, model, multiplier, clear_cache_after_n_frames,
keep_device, all_on_gpu, batch_size, chunk_size,
)
if is_continuation:
result = result[1:] # skip duplicate boundary frame
return (result, model)

sgm_vfi_arch/__init__.py

@@ -0,0 +1,5 @@
from .feature_extractor import feature_extractor
from .flow_estimation import MultiScaleFlow
from .warplayer import clear_warp_cache
__all__ = ['feature_extractor', 'MultiScaleFlow', 'clear_warp_cache']

sgm_vfi_arch/backbone.py

@@ -0,0 +1,116 @@
import torch.nn as nn
from .trident_conv import MultiScaleTridentConv
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
dilation=dilation, padding=dilation, stride=stride, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
dilation=dilation, padding=dilation, bias=False)
self.relu = nn.ReLU(inplace=True)
self.norm1 = norm_layer(planes)
self.norm2 = norm_layer(planes)
if not stride == 1 or in_planes != planes:
self.norm3 = norm_layer(planes)
if stride == 1 and in_planes == planes:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class CNNEncoder(nn.Module):
def __init__(self, output_dim=128,
norm_layer=nn.InstanceNorm2d,
num_output_scales=1,
**kwargs,
):
super(CNNEncoder, self).__init__()
self.num_branch = num_output_scales
feature_dims = [64, 96, 128]
self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
self.norm1 = norm_layer(feature_dims[0])
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = feature_dims[0]
self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
# highest resolution 1/4 or 1/8
stride = 2 if num_output_scales == 1 else 1
self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer,
) # 1/4 or 1/8
self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
if self.num_branch > 1:
if self.num_branch == 4:
strides = (1, 2, 4, 8)
elif self.num_branch == 3:
strides = (1, 2, 4)
elif self.num_branch == 2:
strides = (1, 2)
else:
raise ValueError
self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
kernel_size=3,
strides=strides,
paddings=1,
num_branch=self.num_branch,
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x) # 1/2
x = self.layer2(x) # 1/4
x = self.layer3(x) # 1/8 or 1/4
x = self.conv2(x)
if self.num_branch > 1:
out = self.trident_conv([x] * self.num_branch) # high to low res
else:
out = [x]
return out


@@ -0,0 +1,459 @@
import torch
import torch.nn as nn
import math
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from .position import PositionEmbeddingSine
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], C)
)
return windows
def window_reverse(windows, window_size, H, W):
nwB, N, C = windows.shape
windows = windows.view(-1, window_size[0], window_size[1], C)
B = int(nwB / (H * W / window_size[0] / window_size[1]))
x = windows.view(
B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def pad_if_needed(x, size, window_size):
n, h, w, c = size
pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes
img_mask = torch.zeros((1, h + pad_h, w + pad_w, 1), device=x.device) # 1 H W 1, on x's device (no hardcoded .cuda())
h_slices = (
slice(0, pad_h // 2),
slice(pad_h // 2, h + pad_h // 2),
slice(h + pad_h // 2, None),
)
w_slices = (
slice(0, pad_w // 2),
slice(pad_w // 2, w + pad_w // 2),
slice(w + pad_w // 2, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, window_size
) # nW, window_size*window_size, 1
mask_windows = mask_windows.squeeze(-1)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(
attn_mask != 0, float(-100.0)
).masked_fill(attn_mask == 0, float(0.0))
return nn.functional.pad(
x,
(0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
), attn_mask
return x, None
def depad_if_needed(x, size, window_size):
n, h, w, c = size
pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
if pad_h > 0 or pad_w > 0: # remove the center-padding on feature
return x[:, pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w, :].contiguous()
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.relu = nn.ReLU(inplace=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InterFrameAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x1, x2, H, W, mask=None):
B, N, C = x1.shape
q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
if mask is not None:
nW = mask.shape[0] # mask: nW, N, N
attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = attn.softmax(dim=-1)
else:
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MotionFormerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True,
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ):
super().__init__()
self.window_size = window_size
if not isinstance(self.window_size, (tuple, list)):
self.window_size = to_2tuple(window_size)
self.shift_size = shift_size
if not isinstance(self.shift_size, (tuple, list)):
self.shift_size = to_2tuple(shift_size)
self.bidirectional = bidirectional
self.norm1 = norm_layer(dim)
self.attn = InterFrameAttention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
# BEGIN: absolute pos_embed, beneficial to local information extraction in our experiments
self.pos_embed = PositionEmbeddingSine(dim // 2)
# END: absolute pos_embed
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W, B, self_att=False):
x = x.view(2 * B, H, W, -1)
x_pad, mask = pad_if_needed(x, x.size(), self.window_size)
if self.shift_size[0] or self.shift_size[1]:
_, H_p, W_p, C = x_pad.shape
x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
if hasattr(self, 'HW') and self.HW.item() == H_p * W_p:
shift_mask = self.attn_mask
else:
shift_mask = torch.zeros((1, H_p, W_p, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None))
w_slices = (slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None))
cnt = 0
for h_slice in h_slices:
    for w_slice in w_slices:
        shift_mask[:, h_slice, w_slice, :] = cnt
cnt += 1
mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1)
shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
shift_mask = shift_mask.masked_fill(shift_mask != 0,
float(-100.0)).masked_fill(shift_mask == 0,
float(0.0))
if mask is not None:
shift_mask = shift_mask.masked_fill(mask != 0,
float(-100.0))
self.register_buffer("attn_mask", shift_mask)
self.register_buffer("HW", torch.Tensor([H_p * W_p]))
else:
shift_mask = mask
if shift_mask is not None:
shift_mask = shift_mask.to(x_pad.device)
_, Hw, Ww, C = x_pad.shape
x_win = window_partition(x_pad, self.window_size)
nwB = x_win.shape[0]
x_norm = self.norm1(x_win)
# BEGIN: absolute pos embed, beneficial to local information extraction in our experiments
x_norm = x_norm.view(nwB, self.window_size[0], self.window_size[1], C).permute(0, 3, 1, 2)
ape = self.pos_embed(x_norm)
x_norm = x_norm + ape
x_norm = x_norm.permute(0, 2, 3, 1).view(nwB, self.window_size[0] * self.window_size[1], C)
# END: absolute pos embed
if self_att is False:
x_reverse = torch.cat([x_norm[nwB // 2:], x_norm[:nwB // 2]])
x_appearence = self.attn(x_norm, x_reverse, H, W, shift_mask)
else:
x_appearence = self.attn(x_norm, x_norm, H, W, shift_mask)
x_norm = x_norm + self.drop_path(x_appearence)
x_back = x_norm
x_back_win = window_reverse(x_back, self.window_size, Hw, Ww)
if self.shift_size[0] or self.shift_size[1]:
x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2 * B, H * W, -1)
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class ConvBlock(nn.Module):
def __init__(self, in_dim, out_dim, depths=2, act_layer=nn.PReLU):
super().__init__()
layers = []
for i in range(depths):
if i == 0:
layers.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
else:
layers.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1))
layers.extend([
act_layer(out_dim),
])
self.conv = nn.Sequential(*layers)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.conv(x)
return x
class OverlapPatchEmbed(nn.Module):
def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MotionFormer(nn.Module):
def __init__(self, in_chans=3, embed_dims=None, num_heads=None,
mlp_ratios=None, qkv_bias=True, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=None, window_sizes=None, **kwarg):
super().__init__()
self.depths = depths
self.num_stages = len(embed_dims)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.conv_stages = self.num_stages - len(num_heads)
for i in range(self.num_stages):
if i == 0:
block = ConvBlock(in_chans, embed_dims[i], depths[i])
else:
if i < self.conv_stages:
patch_embed = nn.Sequential(
nn.Conv2d(embed_dims[i - 1], embed_dims[i], 3, 2, 1),
nn.PReLU(embed_dims[i])
)
block = ConvBlock(embed_dims[i], embed_dims[i], depths[i])
else:
patch_embed = OverlapPatchEmbed(patch_size=3,
stride=2,
in_chans=embed_dims[i - 1],
embed_dim=embed_dims[i])
block = nn.ModuleList([MotionFormerBlock(
dim=embed_dims[i], num_heads=num_heads[i - self.conv_stages],
window_size=window_sizes[i - self.conv_stages],
shift_size=0 if (j % 2) == 0 else window_sizes[i - self.conv_stages] // 2,
mlp_ratio=mlp_ratios[i - self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
for j in range(depths[i])])
norm = norm_layer(embed_dims[i])
setattr(self, f"norm{i + 1}", norm)
setattr(self, f"patch_embed{i + 1}", patch_embed)
cur += depths[i]
setattr(self, f"block{i + 1}", block)
self.cor = {}
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def get_cor(self, shape, device):
k = (str(shape), str(device))
if k not in self.cor:
tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
return self.cor[k]
def forward(self, x1, x2):
B = x1.shape[0]
x = torch.cat([x1, x2], 0)
appearence_features = []
xs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f"patch_embed{i + 1}", None)
block = getattr(self, f"block{i + 1}", None)
norm = getattr(self, f"norm{i + 1}", None)
if i < self.conv_stages:
if i > 0:
x = patch_embed(x)
x = block(x)
xs.append(x)
else:
x, H, W = patch_embed(x)
for j in range(len(block)):
x = block[j](x, H, W, B, self_att=False)
xs.append(x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous())
x = norm(x)
x = x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous()
appearence_features.append(x)
return appearence_features
class DWConv(nn.Module):
def __init__(self, dim):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
x = self.dwconv(x)
x = x.reshape(B, C, -1).transpose(1, 2)
return x
def feature_extractor(**kargs):
model = MotionFormer(**kargs)
return model
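The shifted-window masking above (the `cnt` loop in `pad_if_needed` and `MotionFormerBlock`) labels each padded pixel with a region id and then forbids attention between tokens whose ids differ, by adding -100 to those logits before the softmax. A toy pure-Python sketch of that rule (hypothetical helper, not part of the vendored file):

```python
def region_mask(labels):
    # attn_mask[i][j] = 0.0 if tokens i and j share a region id, else -100.0,
    # mirroring the masked_fill pair in pad_if_needed
    n = len(labels)
    return [[0.0 if labels[i] == labels[j] else -100.0 for j in range(n)]
            for i in range(n)]

labels = [0, 0, 1, 1, 2]       # three regions, as produced by the cnt loop
mask = region_mask(labels)
assert mask[0][1] == 0.0       # same region: attention allowed
assert mask[1][2] == -100.0    # different regions: attention suppressed
```

After the softmax, a -100 logit offset drives the corresponding attention weight to effectively zero.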


@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .refine import Contextnet, Unet, warp
from .matching import MatchingBlock
from .gmflow import GMFlow
from .utils import InputPadder
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64, layers=4, scale=4, in_else=17):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes + in_else, c, 3, 1, 1),
conv(c, c, 3, 1, 1),
)
self.convblock = nn.Sequential(
*[conv(c, c) for _ in range(layers)]
)
self.lastconv = conv(c, 5)
def forward(self, x, flow=None, feature=None):
if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
align_corners=False) * 1. / self.scale
x = torch.cat((x, flow), 1)
if feature is not None:
x = torch.cat((x, feature), 1)
x = self.conv0(x)
x = self.convblock(x) + x
tmp = self.lastconv(x)
flow_s = tmp[:, :4]
tmp = F.interpolate(tmp, scale_factor=self.scale, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * self.scale
mask = tmp[:, 4:5]
return flow, mask, flow_s
class MultiScaleFlow(nn.Module):
def __init__(self, backbone, **kargs):
super(MultiScaleFlow, self).__init__()
self.flow_num_stage = len(kargs['hidden_dims'])
self.feature_bone = backbone
self.scale = [1, 2, 4, 8]
self.num_key_points = [kargs['num_key_points']]
self.block = nn.ModuleList(
[IFBlock(kargs['embed_dims'][-1] * 2, 128, 2, self.scale[-1], in_else=7), # 1/8
IFBlock(kargs['embed_dims'][-2] * 2, 128, 2, self.scale[-2], in_else=18)]) # 1/4
self.contextnet = Contextnet(kargs['c'] * 2)
self.unet = Unet(kargs['c'] * 2)
self.gmflow = GMFlow(
num_scales=1,
upsample_factor=8,
feature_channels=128,
attention_type='swin',
num_transformer_layers=6,
ffn_dim_expansion=4,
num_head=1)
self.matching_block = nn.ModuleList([
MatchingBlock(scale=8, dim=kargs['embed_dims'][-1], c=kargs['c'] * 4, num_layers=1, gm=True),
None
])
self.padding_factor = 16
def calculate_flow(self, imgs, timestep):
img0, img1 = imgs[:, :3], imgs[:, 3:6]
B = img0.size(0)
flow, mask = None, None
flow_s = None
af = self.feature_bone(img0, img1)
if self.gmflow is not None:
padder = InputPadder(img0.shape, padding_factor=self.padding_factor)
img0_p, img1_p = padder.pad(img0, img1)
results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
matching_feat = results['trans_feat']
padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1])
matching_feat[0] = padder_8.unpad(matching_feat[0])
for i in range(2):
t = (img0[:B, :1].clone() * 0 + 1) * timestep
af0 = af[-1 - i][:B]
af1 = af[-1 - i][B:]
if flow is not None:
flow_d, mask_d, flow_s_d = self.block[i](
torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
flow,
torch.cat([af0, af1], 1),
)
flow = flow + flow_d
mask = mask + mask_d
flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
flow_s = flow_s + flow_s_d
else:
flow, mask, flow_s = self.block[i](
torch.cat((img0, img1, t), 1),
None,
torch.cat([af0, af1], 1))
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if self.matching_block[i] is not None:
match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i],
init_flow=flow, init_flow_s=flow_s, init_mask=mask,
warped_img0=warped_img0, warped_img1=warped_img1,
num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
timestep=timestep)
flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
flow = flow + flow_t
mask = mask + mask_t
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
return flow, mask
def coraseWarp_and_Refine(self, imgs, flow, mask):
img0, img1 = imgs[:, :3], imgs[:, 3:6]
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, :3] * 2 - 1
mask_ = torch.sigmoid(mask)
merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
pred = torch.clamp(merged + res, 0, 1)
return pred
def forward(self, x, timestep=0.5):
img0, img1 = x[:, :3], x[:, 3:6]
B = x.size(0)
flow_list, mask_list = [], []
merged, merged_fine = [], []
warped_img0, warped_img1 = img0, img1
flow, mask, flow_s = None, None, None
flow_matching_list = []
matching_feat = []
af = self.feature_bone(img0, img1)
if self.gmflow is not None:
padder = InputPadder(img0.shape, padding_factor=self.padding_factor, additional_pad=False)
img0_p, img1_p = padder.pad(img0, img1)
results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
matching_feat = results['trans_feat']
padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1], additional_pad=False)
matching_feat[0] = padder_8.unpad(matching_feat[0])
for i in range(2):
af0 = af[-1 - i][:B]
af1 = af[-1 - i][B:]
t = (img0[:B, :1].clone() * 0 + 1) * timestep
if flow is not None:
flow_d, mask_d, flow_s_d = self.block[i](
torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
flow,
torch.cat([af0, af1], 1),
)
flow = flow + flow_d
mask = mask + mask_d
flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
flow_s = flow_s + flow_s_d
else:
flow, mask, flow_s = self.block[i](
torch.cat((img0, img1, t), 1),
None,
torch.cat([af0, af1], 1))
mask_list.append(torch.sigmoid(mask))
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
if self.matching_block[i] is not None:
match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1-i].detach(),
init_flow=flow.detach(), init_flow_s=flow_s.detach(), init_mask=mask.detach(),
warped_img0=warped_img0.detach(), warped_img1=warped_img1.detach(),
num_key_points=self.num_key_points[i], scale_factor=self.scale[-1-i],
timestep=0.5)
flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
flow = flow + flow_t
mask = mask + mask_t
mask_list[i] = torch.sigmoid(mask)
warped_img0_fine = warp(img0, flow[:, 0:2])
warped_img1_fine = warp(img1, flow[:, 2:4])
merged_fine.append(warped_img0_fine * mask_list[i] + warped_img1_fine * (1 - mask_list[i]))
warped_img0, warped_img1 = warped_img0_fine, warped_img1_fine # NOTE: for next iteration training
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, :3] * 2 - 1
pred = torch.clamp(merged[-1] + res, 0, 1)
merged.extend(merged_fine)
return flow_list, mask_list, merged, pred, flow_matching_list
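`IFBlock` predicts flow at a reduced resolution and rescales it on the way out: resizing an image by a factor also scales displacement vectors by the same factor, which is why the code multiplies by `1. / self.scale` after downsampling and by `self.scale` after upsampling. A quick sketch of that bookkeeping (hypothetical values, not from the vendored file):

```python
def rescale_flow(flow_xy, factor):
    # Resizing an image by `factor` scales pixel displacements by the same factor.
    return tuple(v * factor for v in flow_xy)

full_res = (8.0, -4.0)                      # displacement in pixels at full resolution
quarter = rescale_flow(full_res, 1 / 4)     # what an IFBlock with scale=4 works with
assert quarter == (2.0, -1.0)
assert rescale_flow(quarter, 4) == full_res # upsampling restores the magnitudes
```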

sgm_vfi_arch/geometry.py Normal file

@@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F
def coords_grid(b, h, w, homogeneous=False, device=None):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')  # [H, W]
stacks = [x, y]
if homogeneous:
ones = torch.ones_like(x) # [H, W]
stacks.append(ones)
grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
if device is not None:
grid = grid.to(device)
return grid
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
assert device is not None
x, y = torch.meshgrid(torch.linspace(w_min, w_max, len_w, device=device),
torch.linspace(h_min, h_max, len_h, device=device),
indexing='ij')
grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
return grid
def normalize_coords(coords, h, w):
# coords: [B, H, W, 2]
c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
return (coords - c) / c # [-1, 1]
def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
# img: [B, C, H, W]
# sample_coords: [B, 2, H, W] in image scale
if sample_coords.size(1) != 2: # [B, H, W, 2]
sample_coords = sample_coords.permute(0, 3, 1, 2)
b, _, h, w = sample_coords.shape
# Normalize to [-1, 1]
x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
if return_mask:
mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
return img, mask
return img
def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
b, c, h, w = feature.size()
assert flow.size(1) == 2
grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
return bilinear_sample(feature, grid, padding_mode=padding_mode,
return_mask=mask)
def forward_backward_consistency_check(fwd_flow, bwd_flow,
alpha=0.01,
beta=0.5
):
# fwd_flow, bwd_flow: [B, 2, H, W]
# alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
threshold = alpha * flow_mag + beta
fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
bwd_occ = (diff_bwd > threshold).float()
return fwd_occ, bwd_occ
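`forward_backward_consistency_check` marks a pixel as occluded when the forward flow and the back-warped backward flow fail to cancel out by more than `alpha * |flow| + beta` (the UnFlow criterion). The per-pixel decision boils down to this (hypothetical numbers, scalar sketch of the tensor version above):

```python
def is_occluded(fwd, warped_bwd, alpha=0.01, beta=0.5):
    # fwd, warped_bwd: 2-D flow vectors at a single pixel
    mag = lambda v: (v[0] ** 2 + v[1] ** 2) ** 0.5
    # Consistent flows are opposite, so their sum should be near zero.
    diff = mag((fwd[0] + warped_bwd[0], fwd[1] + warped_bwd[1]))
    threshold = alpha * (mag(fwd) + mag(warped_bwd)) + beta
    return diff > threshold

assert not is_occluded((3.0, 0.0), (-3.0, 0.0))  # flows cancel: consistent
assert is_occluded((3.0, 0.0), (2.0, 0.0))       # flows agree in sign: occluded
```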

sgm_vfi_arch/gmflow.py Normal file

@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import CNNEncoder
from .transformer import FeatureTransformer, FeatureFlowAttention
from .utils import feature_add_position
class GMFlow(nn.Module):
def __init__(self,
num_scales=1,
upsample_factor=8,
feature_channels=128,
attention_type='swin',
num_transformer_layers=6,
ffn_dim_expansion=4,
num_head=1,
**kwargs,
):
super(GMFlow, self).__init__()
self.num_scales = num_scales
self.feature_channels = feature_channels
self.upsample_factor = upsample_factor
self.attention_type = attention_type
self.num_transformer_layers = num_transformer_layers
# CNN backbone
self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
# Transformer
self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
d_model=feature_channels,
nhead=num_head,
attention_type=attention_type,
ffn_dim_expansion=ffn_dim_expansion,
)
def extract_feature(self, img0, img1):
concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
# reverse: resolution from low to high
features = features[::-1]
feature0, feature1 = [], []
for i in range(len(features)):
feature = features[i]
chunks = torch.chunk(feature, 2, 0) # tuple
feature0.append(chunks[0])
feature1.append(chunks[1])
return feature0, feature1
def forward(self, img0, img1,
attn_splits_list=None,
corr_radius_list=None,
prop_radius_list=None,
pred_bidir_flow=False,
**kwargs,
):
results_dict = {}
transformer_features = []
feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
for scale_idx in range(self.num_scales):
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
attn_splits = attn_splits_list[scale_idx]
feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
# Transformer
feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)
transformer_features.append(torch.cat([feature0, feature1], 0))
results_dict.update({'trans_feat': transformer_features})
return results_dict
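`extract_feature` pushes both frames through the backbone in a single pass by concatenating them along the batch axis and chunking the output in two, halving the number of forward passes. The same trick in miniature (toy stand-ins for `torch.cat`/`torch.chunk`, not the vendored API):

```python
def batched_extract(frames_a, frames_b, backbone):
    # Concatenate along the batch axis, run once, split back per frame.
    combined = frames_a + frames_b            # stand-in for torch.cat(dim=0)
    features = [backbone(x) for x in combined]
    n = len(frames_a)
    return features[:n], features[n:]         # stand-in for torch.chunk(f, 2, 0)

double = lambda x: x * 2                      # toy "backbone"
f0, f1 = batched_extract([1, 2], [10, 20], double)
assert f0 == [2, 4] and f1 == [20, 40]
```

This only works because the backbone treats batch entries independently, which holds for the CNN encoder here.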

sgm_vfi_arch/matching.py Normal file

@@ -0,0 +1,278 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp as backwarp
from .softsplat import softsplat
from .geometry import coords_grid
# for random sample ablation
def random_sample(feature, num_points=256):
rand_ind = torch.randint(low=0, high=feature.shape[1], size=(feature.shape[0], num_points)).unsqueeze(-1).to(
feature.device)
kp = torch.gather(feature, dim=1, index=rand_ind.expand(-1, -1, feature.shape[2]))
return rand_ind, kp
def sample_key_points(importance_map, feature, num_points=256):
importance_map = importance_map.view(-1, 1, importance_map.shape[2] * importance_map.shape[3]).permute(0, 2, 1)
_, kp_ind = torch.topk(importance_map, num_points, dim=1)
kp = torch.gather(feature, dim=1, index=kp_ind.expand(-1, -1, feature.shape[2]))
return kp_ind, kp
def forward_warp(tenIn, tenFlow, z=None):
if z is None:
z = torch.ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]).to(tenIn.device)
else:
z = torch.where(z == 0, -20, 1)
out = softsplat(tenIn, tenFlow, tenMetric=z, strMode='soft')
return out
def warp_twice(imgA, target, flow_tA, flow_tB):
It_warp = backwarp(imgA, flow_tA) # backward warp(I1,Ft1)
z = torch.ones([imgA.shape[0], 1, imgA.shape[2], imgA.shape[3]]).to(imgA.device)
IB_warp = softsplat(tenIn=It_warp, tenFlow=flow_tB, tenMetric=z, strMode='soft')
return IB_warp
def build_map(imgA, imgB, flow_tA, flow_tB):
# build map for img B
IB_warp = warp_twice(imgA, imgB, flow_tA, flow_tB)
difference_map = IB_warp - imgB # [B, 3, H, W], difference map on IB
difference_map = torch.sum(torch.abs(difference_map), dim=1, keepdim=True) # B, 1, H, W
return difference_map
def build_hole_mask(img_template, flow_tA, flow_tB):
# build hole mask
with torch.no_grad():
ones = torch.ones(img_template.shape[0], 1, img_template.shape[2], img_template.shape[3]).to(
img_template.device)
out = warp_twice(ones, ones, flow_tA, flow_tB)
hole_mask = torch.where(out == 0, 0, 1)
return hole_mask
def gen_importance_map(img0, img1, flow):
I1_dmap = build_map(img0, img1, flow[:, 0:2], flow[:, 2:4])
I0_dmap = build_map(img1, img0, flow[:, 2:4], flow[:, 0:2])
I1_hole_mask = build_hole_mask(img0, flow[:, 0:2], flow[:, 2:4])
I0_hole_mask = build_hole_mask(img1, flow[:, 2:4], flow[:, 0:2])
I1_dmap = I1_dmap * I1_hole_mask
I0_dmap = I0_dmap * I0_hole_mask
I0_prob = warp_twice(I1_dmap, I1_dmap, flow[:, 2:4], flow[:, 0:2])
I1_prob = warp_twice(I0_dmap, I0_dmap, flow[:, 0:2], flow[:, 2:4])
importance_map = torch.cat([I0_prob, I1_prob], dim=0) # 2B, 1, H, W
return importance_map
def global_matching(key_feature, global_feature, key_index, H, W):
b, n, c = global_feature.shape
query = key_feature
key = global_feature
correlation = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, k, H*W]
prob = F.softmax(correlation, dim=-1)
init_grid = coords_grid(b, H, W, homogeneous=False, device=global_feature.device)
grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
out = torch.matmul(prob, grid) # B, k, 2
if key_index is not None:
flow_fix = torch.zeros_like(grid)
# key_index: [B, K, 1], out: [B, K, 2], flow_fix: [B, H*W, 2]
flow_fix = torch.scatter(flow_fix, dim=1, index=key_index.expand(-1, -1, 2), src=out)
flow_fix = flow_fix.view(b, H, W, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
# for grid, points in grid and not in key_index, set to 0
grid_new = torch.zeros_like(grid)
key_pos = torch.ones_like(out)
grid_new = torch.scatter(grid_new, dim=1, index=key_index.expand(-1, -1, 2), src=key_pos)
grid = (grid * grid_new).reshape(b, H, W, 2).permute(0, 3, 1, 2)
flow_fix = flow_fix - grid
else:
flow_fix = out.view(b, H, W, 2).permute(0, 3, 1, 2)
flow_fix = flow_fix - init_grid
return flow_fix, prob
def extract_topk(score_map, k):
b, _, h, w = score_map.shape
score_map = score_map.view(b, 1, h * w).permute(0, 2, 1)
kp, kp_ind = torch.topk(score_map, k, dim=1)
grid = torch.zeros(b, h * w, 1).to(score_map.device)
out = torch.scatter(grid, dim=1, index=kp_ind, src=kp)
out = out.permute(0, 2, 1).reshape(b, 1, h, w)
return out
def flow_shift(flow_fix, timestep, num_key_points=None, select_topk=False):
B = flow_fix.shape[0] // 2
z = torch.where(flow_fix == 0, 0, 1).detach().sum(1, keepdim=True) / 2
zt0, zt1 = z[B:], z[:B]
flow_fix_t0 = forward_warp(flow_fix[B:] * timestep, flow_fix[B:] * (1 - timestep), z=zt0)
flow_fix_t1 = forward_warp(flow_fix[:B] * (1 - timestep), flow_fix[:B] * timestep, z=zt1)
flow_fix_t = torch.cat([flow_fix_t0, flow_fix_t1], 0)
if select_topk and num_key_points != -1:
warp_map_t0 = softsplat(zt0, flow_fix[B:] * (1 - timestep), None, 'sum')
warp_map_t1 = softsplat(zt1, flow_fix[:B] * timestep, None, 'sum')
warp_map = torch.cat([warp_map_t0, warp_map_t1], 0)
warp_map_topk = extract_topk(warp_map, num_key_points)
warp_map_topk = torch.where(warp_map_topk != 0, 1, 0)
flow_fix_t = flow_fix_t * warp_map_topk
return flow_fix_t
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes=64, out_planes=64, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, bias=True),
nn.PReLU(out_planes)
)
class FlowRefine(nn.Module):
def __init__(self, in_planes, scale=4, c=64, n_layers=8):
super(FlowRefine, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 1, 1),
conv(c, c, 3, 1, 1),
)
self.convblock = nn.Sequential(
*[conv(c, c) for _ in range(n_layers)]
)
self.lastconv = conv(c, 5)
self.scale = scale
def forward(self, x, flow_s, flow):
if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
align_corners=False) * 1. / self.scale
x = torch.cat((x, flow), 1)
if flow_s is not None:
x = torch.cat((x, flow_s), 1)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.lastconv(x)
tmp = F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * self.scale
mask = tmp[:, 4:5]
return flow, mask
class MergingBlock(nn.Module):
def __init__(self, radius=3, input_dim=256, hidden_dim=256):
super(MergingBlock, self).__init__()
self.r = radius
self.rf = radius ** 2
self.conv = nn.Sequential(nn.Conv2d(8 + 2*input_dim, hidden_dim, 3, 1, 1),
nn.PReLU(hidden_dim),
nn.Conv2d(hidden_dim, 2*2*self.rf, 1, 1, 0))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(0.1 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, feature, init_flow, flow_fix):
"""
:param feature: [B, C, H, W] -> (local feature) or (local feature + matching feature)
:param init_flow: [B, 2, H, W] -> (local init flow)
:param flow_fix: [B, 2, H, W] -> (matching output flow_fix, after hole patching)
"""
b, flow_channel, h, w = init_flow.shape
concat = torch.cat((init_flow, flow_fix, feature), dim=1)
mask = self.conv(concat)
assert init_flow.shape == flow_fix.shape, "different flow shapes are not supported"
mask = mask.view(b, 1, 2 * 2 * self.rf, h, w)
mask0 = mask[:, :, :2 * self.rf, :, :]
mask1 = mask[:, :, 2 * self.rf:, :, :]
mask = torch.cat([mask0, mask1], dim=0)
mask = torch.softmax(mask, dim=2)
init_flow_all = torch.cat([init_flow[:, 0:2], init_flow[:, 2:4]], dim=0)
flow_fix_all = torch.cat([flow_fix[:, 0:2], flow_fix[:, 2:4]], dim=0)
init_flow_grid = F.unfold(init_flow_all, [self.r, self.r], padding=self.r//2)
init_flow_grid = init_flow_grid.view(2*b, 2, self.rf, h, w) # [B, 2, 9, H, W]
flow_fix_grid = F.unfold(flow_fix_all, [self.r, self.r], padding=self.r//2)
flow_fix_grid = flow_fix_grid.view(2*b, 2, self.rf, h, w) # [B, 2, 9, H, W]
flow_grid = torch.cat([init_flow_grid, flow_fix_grid], dim=2) # [B, 2, 2*9, H, W]
merge_flow = torch.sum(mask * flow_grid, dim=2) # [B, 2, H, W]
return merge_flow
class MatchingBlock(nn.Module):
def __init__(self, scale, c, dim, num_layers=2, gm=True):
super(MatchingBlock, self).__init__()
self.gm = gm
self.dim = dim
self.scale = scale
self.merge = MergingBlock(radius=3, input_dim=dim+128, hidden_dim=256)
self.refine_block = FlowRefine(27, scale, c, num_layers)
def forward(self, img0, img1, x, main_x, init_flow, init_flow_s, init_mask,
warped_img0, warped_img1, num_key_points, scale_factor, timestep=0.5):
result_dict = {}
_, c, h, w = x.shape
B = main_x.shape[0] // 2
# NOTE:
# 1. we stop sparsely selecting key points when the image resolution
# becomes too small (1/8 feature map resolution <= 32, i.e., input h <= 256)
# (see `random_rescale` in train_x4k.py)
# 2. this limitation should be removed when evaluating on low-resolution images (<= 256x256)
if num_key_points != -1 and h > 32:
num_key_points = int(num_key_points * (h * w))
else:
num_key_points = -1 # -1 stands for global matching
feature = x.permute(0, 2, 3, 1).reshape(2 * B, h*w, c)
feature_reverse = torch.cat([feature[B:], feature[:B]], 0)
if num_key_points == -1:
flow_fix_norm, _ = global_matching(feature, feature_reverse, None, h, w)
else:
imap = gen_importance_map(img0, img1, init_flow)
imap_s = F.interpolate(imap, size=(h, w), mode="bilinear", align_corners=False)
kp_ind, kp_feature = sample_key_points(imap_s, feature, num_key_points)
flow_fix_norm, _ = global_matching(kp_feature, feature_reverse, kp_ind, h, w)
flow_fix = flow_shift(flow_fix_norm, timestep, num_key_points, select_topk=True)
flow_fix = torch.cat([flow_fix[:B], flow_fix[B:]], 1)
flow_r = torch.where(flow_fix == 0, init_flow_s, flow_fix)
flow_merge = self.merge(torch.cat([x[:B], x[B:], main_x[:B], main_x[B:]], dim=1), init_flow_s, flow_r)
flow_merge = torch.cat([flow_merge[:B], flow_merge[B:]], dim=1)
img0_s = F.interpolate(img0, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
img1_s = F.interpolate(img1, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False)
warped_img0_fine_s_m = backwarp(img0_s, flow_merge[:, 0:2])
warped_img1_fine_s_m = backwarp(img1_s, flow_merge[:, 2:4])
flow_t, mask_t = self.refine_block(torch.cat((img0, img1, warped_img0, warped_img1, init_mask), 1),
torch.cat([warped_img0_fine_s_m, warped_img1_fine_s_m, flow_merge], 1),
init_flow)
result_dict.update({'flow_t': flow_t})
result_dict.update({'mask_t': mask_t})
return result_dict
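The gate at the top of `forward` converts the configured key-point ratio into an absolute budget, falling back to dense global matching (`-1`) at low resolutions. A small sketch of that logic (the function name is ours, not from the repo):

```python
def key_point_budget(ratio, h, w):
    # mirrors the gate in MatchingBlock.forward: at 1/8-scale height <= 32
    # (i.e. input height <= 256) fall back to dense global matching
    if ratio != -1 and h > 32:
        return int(ratio * h * w)
    return -1  # -1 means global matching over all positions

assert key_point_budget(0.25, 64, 128) == 2048   # 25% of the 64x128 grid
assert key_point_budget(0.25, 32, 64) == -1      # too small: go dense
assert key_point_budget(-1, 64, 128) == -1       # ratio -1: always dense
```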

sgm_vfi_arch/position.py Normal file
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
import torch
import torch.nn as nn
import math
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x):
# no padding is used here, so every position is valid (all-ones mask)
b, c, h, w = x.size()
mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
y_embed = mask.cumsum(1, dtype=torch.float32)
x_embed = mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
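For a single coordinate, the embedding computed above reduces to interleaved sin/cos at geometrically spaced frequencies; the full `pos` then stacks the y- and x-embeddings into `2 * num_pos_feats` channels. A stdlib sketch of the per-coordinate math (the helper name is ours, for illustration only):

```python
import math

def sine_embed(coord, num_pos_feats=64, temperature=10000):
    # one coordinate -> num_pos_feats values: sin on even indices, cos on
    # odd, with consecutive pairs sharing a frequency, as dim_t does above
    emb = []
    for i in range(num_pos_feats):
        freq = temperature ** (2 * (i // 2) / num_pos_feats)
        emb.append(math.sin(coord / freq) if i % 2 == 0 else math.cos(coord / freq))
    return emb

e = sine_embed(3.0)
assert len(e) == 64
# first pair uses frequency divisor 1: sin(coord), cos(coord)
assert abs(e[0] - math.sin(3.0)) < 1e-9
assert abs(e[1] - math.cos(3.0)) < 1e-9
```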

sgm_vfi_arch/refine.py Normal file
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from timm.models.layers import trunc_normal_
from .warplayer import warp
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
nn.PReLU(out_planes)
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class Contextnet(nn.Module):
def __init__(self, c=16):
super(Contextnet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2 * c)
self.conv3 = Conv2(2 * c, 4 * c)
self.conv4 = Conv2(4 * c, 8 * c)
def forward(self, x, flow):
x = self.conv1(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
recompute_scale_factor=False) * 0.5
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
recompute_scale_factor=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
recompute_scale_factor=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False,
recompute_scale_factor=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class Unet(nn.Module):
def __init__(self, c=16, out=3):
super(Unet, self).__init__()
self.down0 = Conv2(17, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, out, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return torch.sigmoid(x)
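The channel widths in `Unet` are dictated by the skip connections: each `down` level concatenates the previous scale with both frames' Contextnet pyramids (c, 2c, 4c, 8c per frame), and each `up` level concatenates with the matching encoder scale. A quick stdlib check of that bookkeeping:

```python
c = 16
ctx = [c, 2 * c, 4 * c, 8 * c]              # Contextnet pyramid widths per frame
down_out = [2 * c, 4 * c, 8 * c, 16 * c]    # s0..s3 output widths
up_out = [8 * c, 4 * c, 2 * c, c]           # decoder output widths

# down0 input: img0 + img1 + warped0 + warped1 (4x3) + mask (1) + flow (4) = 17
assert 4 * 3 + 1 + 4 == 17
# down1..down3: previous scale + both frames' context features, matching Conv2(in, in)
for i in range(3):
    assert down_out[i] + 2 * ctx[i] == 2 * down_out[i]
# up0 input: s3 + deepest context from both frames = 32c
assert down_out[3] + 2 * ctx[3] == 32 * c
# up1..up3: previous decoder output + matching encoder skip
for i in range(3):
    assert up_out[i] + down_out[2 - i] == 2 * up_out[i]
```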

sgm_vfi_arch/softsplat.py Normal file
@@ -0,0 +1,530 @@
#!/usr/bin/env python
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
objCudacache = {}
def cuda_int32(intIn:int):
return cupy.int32(intIn)
# end
def cuda_float32(fltIn:float):
return cupy.float32(fltIn)
# end
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
if 'device' not in objCudacache:
objCudacache['device'] = torch.cuda.get_device_name()
# end
strKey = strFunction
for strVariable in objVariables:
objValue = objVariables[strVariable]
strKey += strVariable
if objValue is None:
continue
elif type(objValue) == int:
strKey += str(objValue)
elif type(objValue) == float:
strKey += str(objValue)
elif type(objValue) == bool:
strKey += str(objValue)
elif type(objValue) == str:
strKey += objValue
elif type(objValue) == torch.Tensor:
strKey += str(objValue.dtype)
strKey += str(objValue.shape)
strKey += str(objValue.stride())
elif True:
print(strVariable, type(objValue))
assert(False)
# end
# end
strKey += objCudacache['device']
if strKey not in objCudacache:
for strVariable in objVariables:
objValue = objVariables[strVariable]
if objValue is None:
continue
elif type(objValue) == int:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == float:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == bool:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == str:
strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
strKernel = strKernel.replace('{{type}}', 'unsigned char')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
strKernel = strKernel.replace('{{type}}', 'half')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
strKernel = strKernel.replace('{{type}}', 'float')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
strKernel = strKernel.replace('{{type}}', 'double')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
strKernel = strKernel.replace('{{type}}', 'int')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
strKernel = strKernel.replace('{{type}}', 'long')
elif type(objValue) == torch.Tensor:
print(strVariable, objValue.dtype)
assert(False)
elif True:
print(strVariable, type(objValue))
assert(False)
# end
# end
while True:
objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
if objMatch is None:
break
# end
intArg = int(objMatch.group(2))
strTensor = objMatch.group(4)
intSizes = objVariables[strTensor].size()
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
# end
while True:
objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)
if objMatch is None:
break
# end
intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1
while True:
intParentheses += 1 if strKernel[intStop] == '(' else 0
intParentheses -= 1 if strKernel[intStop] == ')' else 0
if intParentheses == 0:
break
# end
intStop += 1
# end
intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(',')
assert(intArgs == len(strArgs) - 1)
strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()
strIndex = []
for intArg in range(intArgs):
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
# end
strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
# end
while True:
objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)
if objMatch is None:
break
# end
intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1
while True:
intParentheses += 1 if strKernel[intStop] == '(' else 0
intParentheses -= 1 if strKernel[intStop] == ')' else 0
if intParentheses == 0:
break
# end
intStop += 1
# end
intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(',')
assert(intArgs == len(strArgs) - 1)
strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()
strIndex = []
for intArg in range(intArgs):
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
# end
strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
# end
objCudacache[strKey] = {
'strFunction': strFunction,
'strKernel': strKernel
}
# end
return strKey
# end
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
if 'CUDA_HOME' not in os.environ:
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
# end
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'],
options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
# end
##########################################################
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
if strMode == 'sum': assert(tenMetric is None)
if strMode == 'avg': assert(tenMetric is None)
if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
if strMode == 'avg':
tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
elif strMode.split('-')[0] == 'linear':
tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
elif strMode.split('-')[0] == 'soft':
tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
# end
tenOut = softsplat_func.apply(tenIn, tenFlow)
if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
tenNormalize = tenOut[:, -1:, :, :]
if len(strMode.split('-')) == 1:
tenNormalize = tenNormalize + 0.0000001
elif strMode.split('-')[1] == 'addeps':
tenNormalize = tenNormalize + 0.0000001
elif strMode.split('-')[1] == 'zeroeps':
tenNormalize[tenNormalize == 0.0] = 1.0
elif strMode.split('-')[1] == 'clipeps':
tenNormalize = tenNormalize.clip(0.0000001, None)
# end
tenOut = tenOut[:, :-1, :, :] / tenNormalize
# end
return tenOut
# end
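When several source pixels land on the same target pixel, 'soft' mode resolves the collision as an exp(metric)-weighted average of their values; the eps suffixes only differ in how a near-zero denominator is handled. A stdlib sketch of that arithmetic (values illustrative):

```python
import math

# two source pixels splatting onto the same target location
vals = [1.0, 3.0]       # input values
metrics = [0.0, 2.0]    # per-pixel importance metric

num = sum(v * math.exp(m) for v, m in zip(vals, metrics))
den = sum(math.exp(m) for m in metrics) + 1e-7   # the 'addeps' variant
out = num / den

# the collision resolves toward the pixel with the larger metric
assert 2.7 < out < 2.8
```

This is why 'soft' mode is used for occlusion handling: whichever pixel has the larger metric dominates the disputed location instead of the contributions simply summing.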
class softsplat_func(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(self, tenIn, tenFlow):
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
if tenIn.is_cuda == True:
cuda_launch(cuda_kernel('softsplat_out', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
const int intX = ( intIndex ) % SIZE_3(tenOut);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
{{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
}
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOut': tenOut
}))(
grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
elif tenIn.is_cuda != True:
assert(False)
# end
self.save_for_backward(tenIn, tenFlow)
return tenOut
# end
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(self, tenOutgrad):
tenIn, tenFlow = self.saved_tensors
tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
if tenIngrad is not None:
cuda_launch(cuda_kernel('softsplat_ingrad', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
const int intX = ( intIndex ) % SIZE_3(tenIngrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltIngrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
}
tenIngrad[intIndex] = fltIngrad;
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOutgrad': tenOutgrad,
'tenIngrad': tenIngrad,
'tenFlowgrad': tenFlowgrad
}))(
grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
# end
if tenFlowgrad is not None:
cuda_launch(cuda_kernel('softsplat_flowgrad', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltFlowgrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = 0.0f;
{{type}} fltNortheast = 0.0f;
{{type}} fltSouthwest = 0.0f;
{{type}} fltSoutheast = 0.0f;
if (intC == 0) {
fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
} else if (intC == 1) {
fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
}
for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
{{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
}
}
tenFlowgrad[intIndex] = fltFlowgrad;
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOutgrad': tenOutgrad,
'tenIngrad': tenIngrad,
'tenFlowgrad': tenFlowgrad
}))(
grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
# end
return tenIngrad, tenFlowgrad
# end
# end
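Inside each kernel, a source pixel at fractional target position (fltX, fltY) is distributed over its four integer neighbors with bilinear weights; for any in-bounds position the four weights form a partition of unity, so 'avg'/'soft' normalization recovers a proper average. A stdlib sketch of the weight formulas used above (the helper name is ours):

```python
import math

def splat_weights(x, y):
    # same bilinear corner weights as the CUDA kernels (NW/NE/SW/SE)
    nwx, nwy = math.floor(x), math.floor(y)
    nw = (nwx + 1 - x) * (nwy + 1 - y)
    ne = (x - nwx) * (nwy + 1 - y)
    sw = (nwx + 1 - x) * (y - nwy)
    se = (x - nwx) * (y - nwy)
    return nw, ne, sw, se

w = splat_weights(2.3, 5.75)
assert abs(sum(w) - 1.0) < 1e-9            # weights always sum to 1
assert all(0.0 <= x <= 1.0 for x in w)
```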

sgm_vfi_arch/transformer.py Normal file
@@ -0,0 +1,450 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def split_feature(feature,
num_splits=2,
channel_last=False,
):
if channel_last: # [B, H, W, C]
b, h, w, c = feature.size()
assert h % num_splits == 0 and w % num_splits == 0
b_new = b * num_splits * num_splits
h_new = h // num_splits
w_new = w // num_splits
feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
else: # [B, C, H, W]
b, c, h, w = feature.size()
assert h % num_splits == 0 and w % num_splits == 0, f'h: {h}, w: {w}, num_splits: {num_splits}'
b_new = b * num_splits * num_splits
h_new = h // num_splits
w_new = w // num_splits
feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
return feature
def merge_splits(splits,
num_splits=2,
channel_last=False,
):
if channel_last: # [B*K*K, H/K, W/K, C]
b, h, w, c = splits.size()
new_b = b // num_splits // num_splits
splits = splits.view(new_b, num_splits, num_splits, h, w, c)
merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
else: # [B*K*K, C, H/K, W/K]
b, c, h, w = splits.size()
new_b = b // num_splits // num_splits
splits = splits.view(new_b, num_splits, num_splits, c, h, w)
merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
return merge
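`split_feature` tiles the map into K×K windows stacked along the batch dimension and `merge_splits` inverts it exactly. A 2-D stdlib sketch of that round trip (helper names are ours, for illustration only):

```python
def split2d(grid, k):
    # tile an HxW grid into k*k tiles, stacked like the batch dimension
    h, w = len(grid), len(grid[0])
    th, tw = h // k, w // k
    return [[[grid[i * th + r][j * tw + c] for c in range(tw)] for r in range(th)]
            for i in range(k) for j in range(k)]

def merge2d(tiles, k):
    # exact inverse of split2d
    th, tw = len(tiles[0]), len(tiles[0][0])
    return [[tiles[(r // th) * k + (c // tw)][r % th][c % tw]
             for c in range(k * tw)] for r in range(k * th)]

grid = [[r * 8 + c for c in range(8)] for r in range(4)]
assert merge2d(split2d(grid, 2), 2) == grid   # lossless round trip
```

The round trip is why the `h % num_splits == 0` asserts matter: any remainder would be silently dropped by the integer division and the merge would no longer be an inverse.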
def single_head_full_attention(q, k, v):
# q, k, v: [B, L, C]
assert q.dim() == k.dim() == v.dim() == 3
scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
attn = torch.softmax(scores, dim=2) # [B, L, L]
out = torch.matmul(attn, v) # [B, L, C]
return out
def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
shift_size_h, shift_size_w, device=None):
# Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
# calculate attention mask for SW-MSA
h, w = input_resolution
img_mask = torch.zeros((1, h, w, 1), device=device) # 1 H W 1
h_slices = (slice(0, -window_size_h),
slice(-window_size_h, -shift_size_h),
slice(-shift_size_h, None))
w_slices = (slice(0, -window_size_w),
slice(-window_size_w, -shift_size_w),
slice(-shift_size_w, None))
cnt = 0
for h_slice in h_slices:
for w_slice in w_slices:
img_mask[:, h_slice, w_slice, :] = cnt
cnt += 1
mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
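The builder above labels the (rolled) feature map with 3×3 = 9 region ids via the slice pairs; token pairs with differing ids then receive the -100 bias and are effectively masked out of the softmax. The labeling step in stdlib form:

```python
h = w = 8
window, shift = 4, 2          # window size and window size // 2

img = [[0] * w for _ in range(h)]
cnt = 0
for hs in (slice(0, -window), slice(-window, -shift), slice(-shift, None)):
    for ws in (slice(0, -window), slice(-window, -shift), slice(-shift, None)):
        for r in range(*hs.indices(h)):
            for c in range(*ws.indices(w)):
                img[r][c] = cnt
        cnt += 1

assert cnt == 9                            # 3 x 3 regions
assert img[0][0] == 0 and img[7][7] == 8   # first and last region labels
```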
def single_head_split_window_attention(q, k, v,
num_splits=1,
with_shift=False,
h=None,
w=None,
attn_mask=None,
):
# Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
# q, k, v: [B, L, C]
assert q.dim() == k.dim() == v.dim() == 3
assert h is not None and w is not None
assert q.size(1) == h * w
b, _, c = q.size()
b_new = b * num_splits * num_splits
window_size_h = h // num_splits
window_size_w = w // num_splits
q = q.view(b, h, w, c) # [B, H, W, C]
k = k.view(b, h, w, c)
v = v.view(b, h, w, c)
scale_factor = c ** 0.5
if with_shift:
assert attn_mask is not None # compute once
shift_size_h = window_size_h // 2
shift_size_w = window_size_w // 2
q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
k = split_feature(k, num_splits=num_splits, channel_last=True)
v = split_feature(v, num_splits=num_splits, channel_last=True)
scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
if with_shift:
scores += attn_mask.repeat(b, 1, 1)
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
num_splits=num_splits, channel_last=True) # [B, H, W, C]
# shift back
if with_shift:
out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
out = out.view(b, -1, c)
return out
class TransformerLayer(nn.Module):
def __init__(self,
d_model=256,
nhead=1,
attention_type='swin',
no_ffn=False,
ffn_dim_expansion=4,
with_shift=False,
**kwargs,
):
super(TransformerLayer, self).__init__()
self.dim = d_model
self.nhead = nhead
self.attention_type = attention_type
self.no_ffn = no_ffn
self.with_shift = with_shift
# multi-head attention
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.merge = nn.Linear(d_model, d_model, bias=False)
self.norm1 = nn.LayerNorm(d_model)
# no ffn after self-attn, with ffn after cross-attn
if not self.no_ffn:
in_channels = d_model * 2
self.mlp = nn.Sequential(
nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
nn.GELU(),
nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, source, target,
height=None,
width=None,
shifted_window_attn_mask=None,
attn_num_splits=None,
**kwargs,
):
# source, target: [B, L, C]
query, key, value = source, target, target
# single-head attention
query = self.q_proj(query) # [B, L, C]
key = self.k_proj(key) # [B, L, C]
value = self.v_proj(value) # [B, L, C]
if self.attention_type == 'swin' and attn_num_splits > 1:
if self.nhead > 1:
# we observed that multi-head attention slows inference and increases memory
# consumption without clear quality gains, so that implementation was removed
raise NotImplementedError
else:
message = single_head_split_window_attention(query, key, value,
num_splits=attn_num_splits,
with_shift=self.with_shift,
h=height,
w=width,
attn_mask=shifted_window_attn_mask,
)
else:
message = single_head_full_attention(query, key, value) # [B, L, C]
message = self.merge(message) # [B, L, C]
message = self.norm1(message)
if not self.no_ffn:
message = self.mlp(torch.cat([source, message], dim=-1))
message = self.norm2(message)
return source + message
class TransformerBlock(nn.Module):
"""self attention + cross attention + FFN"""
def __init__(self,
d_model=256,
nhead=1,
attention_type='swin',
ffn_dim_expansion=4,
with_shift=False,
**kwargs,
):
super(TransformerBlock, self).__init__()
self.self_attn = TransformerLayer(d_model=d_model,
nhead=nhead,
attention_type=attention_type,
no_ffn=True,
ffn_dim_expansion=ffn_dim_expansion,
with_shift=with_shift,
)
self.cross_attn_ffn = TransformerLayer(d_model=d_model,
nhead=nhead,
attention_type=attention_type,
ffn_dim_expansion=ffn_dim_expansion,
with_shift=with_shift,
)
def forward(self, source, target,
height=None,
width=None,
shifted_window_attn_mask=None,
attn_num_splits=None,
**kwargs,
):
# source, target: [B, L, C]
# self attention
source = self.self_attn(source, source,
height=height,
width=width,
shifted_window_attn_mask=shifted_window_attn_mask,
attn_num_splits=attn_num_splits,
)
# cross attention and ffn
source = self.cross_attn_ffn(source, target,
height=height,
width=width,
shifted_window_attn_mask=shifted_window_attn_mask,
attn_num_splits=attn_num_splits,
)
return source
class FeatureTransformer(nn.Module):
def __init__(self,
num_layers=6,
d_model=128,
nhead=1,
attention_type='swin',
ffn_dim_expansion=4,
**kwargs,
):
super(FeatureTransformer, self).__init__()
self.attention_type = attention_type
self.d_model = d_model
self.nhead = nhead
self.layers = nn.ModuleList([
TransformerBlock(d_model=d_model,
nhead=nhead,
attention_type=attention_type,
ffn_dim_expansion=ffn_dim_expansion,
with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
)
for i in range(num_layers)])
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feature0, feature1,
attn_num_splits=None,
**kwargs,
):
b, c, h, w = feature0.shape
assert self.d_model == c
feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
if self.attention_type == 'swin' and attn_num_splits > 1:
# global and refine use different number of splits
window_size_h = h // attn_num_splits
window_size_w = w // attn_num_splits
# compute attn mask once
shifted_window_attn_mask = generate_shift_window_attn_mask(
input_resolution=(h, w),
window_size_h=window_size_h,
window_size_w=window_size_w,
shift_size_h=window_size_h // 2,
shift_size_w=window_size_w // 2,
device=feature0.device,
) # [K*K, H/K*W/K, H/K*W/K]
else:
shifted_window_attn_mask = None
# concat feature0 and feature1 in batch dimension to compute in parallel
concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
for layer in self.layers:
concat0 = layer(concat0, concat1,
height=h,
width=w,
shifted_window_attn_mask=shifted_window_attn_mask,
attn_num_splits=attn_num_splits,
)
# update feature1
concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
# reshape back
feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
return feature0, feature1
class FeatureFlowAttention(nn.Module):
"""
flow propagation with self-attention on feature
query: feature0, key: feature0, value: flow
"""
def __init__(self, in_channels,
**kwargs,
):
super(FeatureFlowAttention, self).__init__()
self.q_proj = nn.Linear(in_channels, in_channels)
self.k_proj = nn.Linear(in_channels, in_channels)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feature0, flow,
local_window_attn=False,
local_window_radius=1,
**kwargs,
):
# q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
if local_window_attn:
return self.forward_local_window_attn(feature0, flow,
local_window_radius=local_window_radius)
b, c, h, w = feature0.size()
query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
query = self.q_proj(query) # [B, H*W, C]
key = self.k_proj(query) # [B, H*W, C] (note: k_proj is applied to the already-projected query, matching upstream GMFlow)
value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
prob = torch.softmax(scores, dim=-1)
out = torch.matmul(prob, value) # [B, H*W, 2]
out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
return out
def forward_local_window_attn(self, feature0, flow,
local_window_radius=1,
):
assert flow.size(1) == 2
assert local_window_radius > 0
b, c, h, w = feature0.size()
feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
kernel_size = 2 * local_window_radius + 1
feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
padding=local_window_radius) # [B, C*(2R+1)^2, H*W]
feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
flow_window = F.unfold(flow, kernel_size=kernel_size,
padding=local_window_radius) # [B, 2*(2R+1)^2, H*W]
flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2]
scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
prob = torch.softmax(scores, dim=-1)
out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
return out
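The core of `FeatureFlowAttention` above is attention in which the feature map supplies queries and keys while the flow field supplies the values, so flow vectors are propagated to feature-similar locations. A minimal standalone sketch of that computation (projection layers omitted; names and shapes are illustrative, not the module's API):

```python
import torch

def propagate_flow(feature, flow):
    """softmax(Q K^T / sqrt(C)) applied with flow as the attention value.

    feature: [B, C, H, W], used for both query and key in this sketch.
    flow:    [B, 2, H, W], treated as the value.
    """
    b, c, h, w = feature.shape
    q = feature.view(b, c, h * w).permute(0, 2, 1)          # [B, H*W, C]
    k = q                                                   # key = query here
    v = flow.view(b, 2, h * w).permute(0, 2, 1)             # [B, H*W, 2]
    scores = torch.matmul(q, k.transpose(1, 2)) / c ** 0.5  # [B, H*W, H*W]
    prob = torch.softmax(scores, dim=-1)
    out = torch.matmul(prob, v)                             # [B, H*W, 2]
    return out.view(b, h, w, 2).permute(0, 3, 1, 2)         # [B, 2, H, W]

# Sanity check: a constant flow field is a fixed point, since every
# attention-weighted average of identical flow vectors is that vector.
feat = torch.randn(1, 8, 4, 4)
const_flow = torch.full((1, 2, 4, 4), 3.0)
out = propagate_flow(feat, const_flow)
assert torch.allclose(out, const_flow, atol=1e-5)
```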

sgm_vfi_arch/trident_conv.py
# Copyright (c) Facebook, Inc. and its affiliates.
# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
class MultiScaleTridentConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
strides=1,
paddings=0,
dilations=1,
dilation=1,
groups=1,
num_branch=1,
test_branch_idx=-1,
bias=False,
norm=None,
activation=None,
):
super(MultiScaleTridentConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.num_branch = num_branch
self.stride = _pair(stride)
self.groups = groups
self.with_bias = bias
self.dilation = dilation
if isinstance(paddings, int):
paddings = [paddings] * self.num_branch
if isinstance(dilations, int):
dilations = [dilations] * self.num_branch
if isinstance(strides, int):
strides = [strides] * self.num_branch
self.paddings = [_pair(padding) for padding in paddings]
self.dilations = [_pair(dilation) for dilation in dilations]
self.strides = [_pair(stride) for stride in strides]
self.test_branch_idx = test_branch_idx
self.norm = norm
self.activation = activation
assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
)
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
if self.bias is not None:
nn.init.constant_(self.bias, 0)
def forward(self, inputs):
num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
assert len(inputs) == num_branch
if self.training or self.test_branch_idx == -1:
outputs = [
F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
for input, stride, padding in zip(inputs, self.strides, self.paddings)
]
else:
# Inference with a fixed branch: self.test_branch_idx != -1 in this path, so
# the conditionals below always select the last branch's stride/padding
# (kept as-is from the upstream detectron2 TridentConv code linked above).
outputs = [
F.conv2d(
inputs[0],
self.weight,
self.bias,
self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
self.dilation,
self.groups,
)
]
if self.norm is not None:
outputs = [self.norm(x) for x in outputs]
if self.activation is not None:
outputs = [self.activation(x) for x in outputs]
return outputs
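The key idea in `MultiScaleTridentConv` is one shared convolution weight applied across branches that differ only in stride/padding/dilation. A standalone sketch with hypothetical branch settings (not the repo's actual config), showing that `padding == dilation` preserves spatial size for a 3x3 kernel:

```python
import torch
import torch.nn.functional as F

# One shared 3x3 weight, applied with per-branch padding/dilation as in
# MultiScaleTridentConv. Branch settings below are illustrative only.
weight = torch.randn(16, 8, 3, 3)
x = torch.randn(2, 8, 32, 32)

branches = [(1, 1), (2, 2), (3, 3)]  # (padding, dilation) per branch
outputs = [F.conv2d(x, weight, None, 1, p, d) for p, d in branches]

# out_size = in + 2p - d*(k-1) for stride 1, so p == d keeps 32x32 when k == 3.
for o in outputs:
    assert o.shape == (2, 16, 32, 32)
```

Sharing `weight` across branches is what lets the module see multiple receptive-field scales without multiplying parameter count.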

sgm_vfi_arch/utils.py
import torch
import torch.nn.functional as F
from .position import PositionEmbeddingSine
from .geometry import coords_grid, generate_window_grid, normalize_coords
def split_feature(feature,
num_splits=2,
channel_last=False,
):
if channel_last: # [B, H, W, C]
b, h, w, c = feature.size()
assert h % num_splits == 0 and w % num_splits == 0
b_new = b * num_splits * num_splits
h_new = h // num_splits
w_new = w // num_splits
feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
else: # [B, C, H, W]
b, c, h, w = feature.size()
assert h % num_splits == 0 and w % num_splits == 0
b_new = b * num_splits * num_splits
h_new = h // num_splits
w_new = w // num_splits
feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
return feature
def merge_splits(splits,
num_splits=2,
channel_last=False,
):
if channel_last: # [B*K*K, H/K, W/K, C]
b, h, w, c = splits.size()
new_b = b // num_splits // num_splits
splits = splits.view(new_b, num_splits, num_splits, h, w, c)
merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
else: # [B*K*K, C, H/K, W/K]
b, c, h, w = splits.size()
new_b = b // num_splits // num_splits
splits = splits.view(new_b, num_splits, num_splits, c, h, w)
merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
return merge
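`split_feature` tiles a `[B, C, H, W]` map into K x K non-overlapping windows stacked along the batch dimension, and `merge_splits` inverts it. A condensed standalone copy of the two (channel-first path only) verifying the round trip:

```python
import torch

def split_feature(feature, num_splits=2):
    # [B, C, H, W] -> [B*K*K, C, H/K, W/K]
    b, c, h, w = feature.size()
    assert h % num_splits == 0 and w % num_splits == 0
    return feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
                        ).permute(0, 2, 4, 1, 3, 5).reshape(
        b * num_splits ** 2, c, h // num_splits, w // num_splits)

def merge_splits(splits, num_splits=2):
    # [B*K*K, C, H/K, W/K] -> [B, C, H, W]
    b, c, h, w = splits.size()
    new_b = b // num_splits ** 2
    splits = splits.view(new_b, num_splits, num_splits, c, h, w)
    return splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
        new_b, c, num_splits * h, num_splits * w)

x = torch.randn(2, 4, 8, 8)
windows = split_feature(x, num_splits=2)
assert windows.shape == (8, 4, 4, 4)   # 2*2*2 windows of 4x4
roundtrip = merge_splits(windows, num_splits=2)
assert torch.equal(roundtrip, x)       # split/merge is lossless
```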
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
if attn_splits > 1: # add position embedding within each split window
feature0_splits = split_feature(feature0, num_splits=attn_splits)
feature1_splits = split_feature(feature1, num_splits=attn_splits)
position = pos_enc(feature0_splits)
feature0_splits = feature0_splits + position
feature1_splits = feature1_splits + position
feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
else:
position = pos_enc(feature0)
feature0 = feature0 + position
feature1 = feature1 + position
return feature0, feature1
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel', padding_factor=8, additional_pad=False):
self.ht, self.wd = dims[-2:]
add_pad = padding_factor*2 if additional_pad else 0
pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor + add_pad
pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor + add_pad
if mode == 'sintel':
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
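The pad amount `InputPadder` computes per dimension is just the distance from the current size up to the next multiple of `padding_factor` (zero when already divisible); `mode='sintel'` then splits that amount evenly between both sides, while the other mode pads only the bottom. The arithmetic in isolation:

```python
def pad_amount(size, padding_factor=8):
    # Same expression InputPadder uses: rows/cols needed to reach the next
    # multiple of padding_factor; 0 when size is already divisible.
    return (((size // padding_factor) + 1) * padding_factor - size) % padding_factor

assert pad_amount(720) == 0    # 720 = 90 * 8, nothing to add
assert pad_amount(721) == 7    # next multiple is 728
assert pad_amount(1283) == 5   # next multiple is 1288
```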

sgm_vfi_arch/warplayer.py
import torch
backwarp_tenGrid = {}
def clear_warp_cache():
"""Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
backwarp_tenGrid.clear()
def warp(tenInput, tenFlow):
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat(
[tenHorizontal, tenVertical], 1).to(tenFlow.device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
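`warp` backward-warps `tenInput` by a pixel-space flow field: it builds a normalized sampling grid in `[-1, 1]`, rescales the flow into that range, and samples with `grid_sample`. A standalone copy without the module-level grid cache, plus the basic sanity check that zero flow returns the input unchanged:

```python
import torch
import torch.nn.functional as F

def backwarp(ten_input, ten_flow):
    # Standalone version of `warp` above (no backwarp_tenGrid cache).
    b, _, h, w = ten_flow.shape
    gx = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(b, -1, h, -1)
    gy = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(b, -1, -1, w)
    grid = torch.cat([gx, gy], dim=1)                    # base grid in [-1, 1]
    # Flow is in pixels; convert to the normalized range grid_sample expects.
    flow = torch.cat([ten_flow[:, 0:1] / ((w - 1.0) / 2.0),
                      ten_flow[:, 1:2] / ((h - 1.0) / 2.0)], dim=1)
    g = (grid + flow).permute(0, 2, 3, 1)                # [B, H, W, 2]
    return F.grid_sample(ten_input, g, mode='bilinear',
                         padding_mode='border', align_corners=True)

img = torch.rand(1, 3, 6, 6)
same = backwarp(img, torch.zeros(1, 2, 6, 6))
assert torch.allclose(same, img, atol=1e-5)  # zero flow -> identity warp
```

With `align_corners=True`, the normalized grid points land exactly on pixel centers, which is why the zero-flow case reproduces the input bit-for-bit up to float rounding.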