diff --git a/README.md b/README.md index f0c60c4..05d9bdc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,21 @@ -# ComfyUI BIM-VFI + EMA-VFI +# ComfyUI BIM-VFI + EMA-VFI + SGM-VFI -ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025) and [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023). Designed for long videos with thousands of frames — processes them without running out of VRAM. +ComfyUI custom nodes for video frame interpolation using [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) (CVPR 2025), [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) (CVPR 2023), and [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) (CVPR 2024). Designed for long videos with thousands of frames — processes them without running out of VRAM. + +## Which model should I use? + +| | BIM-VFI | EMA-VFI | SGM-VFI | +|---|---------|---------|---------| +| **Best for** | General-purpose, non-uniform motion | Fast inference, light VRAM | Large motion, occlusion-heavy scenes | +| **Quality** | Highest overall | Good | Best on large motion | +| **Speed** | Moderate | Fastest | Slowest | +| **VRAM** | ~2 GB/pair | ~1.5 GB/pair | ~3 GB/pair | +| **Params** | ~17M | ~14–65M | ~15M + GMFlow | +| **Arbitrary timestep** | Yes | Yes (with `_t` checkpoint) | No (fixed 0.5) | +| **Paper** | CVPR 2025 | CVPR 2023 | CVPR 2024 | +| **License** | Research only | Apache 2.0 | Apache 2.0 | + +**TL;DR:** Start with **BIM-VFI** for best quality. Use **EMA-VFI** if you need speed or lower VRAM. Use **SGM-VFI** if your video has large camera motion or fast-moving objects that the others struggle with. ## Nodes @@ -66,7 +81,32 @@ Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate. Same as EMA-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate. 
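The output frame count rule stated in the README (2x = 2N-1, 4x = 4N-3, 8x = 8N-7) follows from each interpolation pass inserting one frame between every adjacent pair, mapping N frames to 2N - 1; the higher multipliers just apply that recurrence two or three times. A quick standalone sketch (the helper name here is illustrative, not part of the nodes):

```python
def output_frame_count(n_input, multiplier):
    """Frames produced for an N-frame input.

    Each interpolation pass inserts one frame between every adjacent
    pair, so N frames become 2N - 1; 4x and 8x repeat the pass.
    """
    passes = {2: 1, 4: 2, 8: 3}[multiplier]
    n = n_input
    for _ in range(passes):
        n = 2 * n - 1
    return n

# 2x = 2N-1, 4x = 4N-3, 8x = 8N-7:
assert output_frame_count(10, 2) == 19  # 2*10 - 1
assert output_frame_count(10, 4) == 37  # 4*10 - 3
assert output_frame_count(10, 8) == 73  # 8*10 - 7
```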
-**Output frame count (both models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 +### SGM-VFI + +#### Load SGM-VFI Model + +Loads an SGM-VFI checkpoint. Auto-downloads from Google Drive on first use to `ComfyUI/models/sgm-vfi/`. Variant (base/small) is auto-detected from the filename (default is small). + +| Input | Description | +|-------|-------------| +| **model_path** | Checkpoint file from `models/sgm-vfi/` | +| **tta** | Test-time augmentation: flip input and average with unflipped result (~2x slower, slightly better quality) | +| **num_key_points** | Sparsity of global matching (0.0 = global everywhere, 0.5 = default balance, higher = faster) | + +Available checkpoints: +| Checkpoint | Variant | Params | +|-----------|---------|--------| +| `ours-1-2-points.pth` | Small | ~15M + GMFlow | + +#### SGM-VFI Interpolate + +Interpolates frames from an image batch. Same controls as BIM-VFI Interpolate. + +#### SGM-VFI Segment Interpolate + +Same as SGM-VFI Interpolate but processes a single segment. Same pattern as BIM-VFI Segment Interpolate. + +**Output frame count (all models):** 2x = 2N-1, 4x = 4N-3, 8x = 8N-7 ## Installation @@ -94,8 +134,8 @@ python install.py ### Requirements - PyTorch with CUDA -- `cupy` (matching your CUDA version, for BIM-VFI) -- `timm` (for EMA-VFI) +- `cupy` (matching your CUDA version, for BIM-VFI and SGM-VFI) +- `timm` (for EMA-VFI and SGM-VFI) - `gdown` (for model auto-download) ## VRAM Guide @@ -109,7 +149,7 @@ python install.py ## Acknowledgments -This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab) and the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU. Architecture files in `bim_vfi_arch/` and `ema_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths). 
+This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VFI) implementation by the [KAIST VIC Lab](https://github.com/KAIST-VICLab), the official [EMA-VFI](https://github.com/MCG-NJU/EMA-VFI) implementation by MCG-NJU, and the official [SGM-VFI](https://github.com/MCG-NJU/SGM-VFI) implementation by MCG-NJU. Architecture files in `bim_vfi_arch/`, `ema_vfi_arch/`, and `sgm_vfi_arch/` are vendored from their respective repositories with minimal modifications (relative imports, device-awareness fixes, inference-only paths). **BiM-VFI:** > Wonyong Seo, Jihyong Oh, and Munchurl Kim. @@ -141,8 +181,25 @@ This project wraps the official [BiM-VFI](https://github.com/KAIST-VICLab/BiM-VF } ``` +**SGM-VFI:** +> Guozhen Zhang, Yuhan Zhu, Evan Zheran Liu, Haonan Wang, Mingzhen Sun, Gangshan Wu, and Limin Wang. +> "Sparse Global Matching for Video Frame Interpolation with Large Motion." +> *IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2024. +> [[arXiv]](https://arxiv.org/abs/2404.06913) [[GitHub]](https://github.com/MCG-NJU/SGM-VFI) + +```bibtex +@inproceedings{zhang2024sgmvfi, + title={Sparse Global Matching for Video Frame Interpolation with Large Motion}, + author={Zhang, Guozhen and Zhu, Yuhan and Liu, Evan Zheran and Wang, Haonan and Sun, Mingzhen and Wu, Gangshan and Wang, Limin}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2024} +} +``` + ## License The BiM-VFI model weights and architecture code are provided by KAIST VIC Lab for **research and education purposes only**. Commercial use requires permission from the principal investigator (Prof. Munchurl Kim, mkimee@kaist.ac.kr). See the [original repository](https://github.com/KAIST-VICLab/BiM-VFI) for details. The EMA-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/EMA-VFI/blob/main/LICENSE). 
See the [original repository](https://github.com/MCG-NJU/EMA-VFI) for details. + +The SGM-VFI model weights and architecture code are released under the [Apache 2.0 License](https://github.com/MCG-NJU/SGM-VFI/blob/main/LICENSE). See the [original repository](https://github.com/MCG-NJU/SGM-VFI) for details. diff --git a/__init__.py b/__init__.py index c0e0d4a..af2d6f1 100644 --- a/__init__.py +++ b/__init__.py @@ -40,6 +40,7 @@ _auto_install_deps() from .nodes import ( LoadBIMVFIModel, BIMVFIInterpolate, BIMVFISegmentInterpolate, BIMVFIConcatVideos, LoadEMAVFIModel, EMAVFIInterpolate, EMAVFISegmentInterpolate, + LoadSGMVFIModel, SGMVFIInterpolate, SGMVFISegmentInterpolate, ) NODE_CLASS_MAPPINGS = { @@ -50,6 +51,9 @@ NODE_CLASS_MAPPINGS = { "LoadEMAVFIModel": LoadEMAVFIModel, "EMAVFIInterpolate": EMAVFIInterpolate, "EMAVFISegmentInterpolate": EMAVFISegmentInterpolate, + "LoadSGMVFIModel": LoadSGMVFIModel, + "SGMVFIInterpolate": SGMVFIInterpolate, + "SGMVFISegmentInterpolate": SGMVFISegmentInterpolate, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -60,4 +64,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadEMAVFIModel": "Load EMA-VFI Model", "EMAVFIInterpolate": "EMA-VFI Interpolate", "EMAVFISegmentInterpolate": "EMA-VFI Segment Interpolate", + "LoadSGMVFIModel": "Load SGM-VFI Model", + "SGMVFIInterpolate": "SGM-VFI Interpolate", + "SGMVFISegmentInterpolate": "SGM-VFI Segment Interpolate", } diff --git a/inference.py b/inference.py index 175c8b7..01203c4 100644 --- a/inference.py +++ b/inference.py @@ -7,6 +7,8 @@ import torch.nn as nn from .bim_vfi_arch import BiMVFI from .ema_vfi_arch import feature_extractor as ema_feature_extractor from .ema_vfi_arch import MultiScaleFlow as EMAMultiScaleFlow +from .sgm_vfi_arch import feature_extractor as sgm_feature_extractor +from .sgm_vfi_arch import MultiScaleFlow as SGMMultiScaleFlow from .utils.padder import InputPadder logger = logging.getLogger("BIM-VFI") @@ -282,3 +284,160 @@ class EMAVFIModel: pred = self._inference(img0, img1, 
timestep=time_step) pred = padder.unpad(pred) return torch.clamp(pred, 0, 1) + + +# --------------------------------------------------------------------------- +# SGM-VFI model wrapper +# --------------------------------------------------------------------------- + +def _sgm_init_model_config(F=16, W=7, depth=[2, 2, 2, 4], num_key_points=0.5): + """Build SGM-VFI model config dicts (backbone + multiscale).""" + return { + 'embed_dims': [F, 2*F, 4*F, 8*F], + 'num_heads': [8*F//32], + 'mlp_ratios': [4], + 'qkv_bias': True, + 'norm_layer': partial(nn.LayerNorm, eps=1e-6), + 'depths': depth, + 'window_sizes': [W] + }, { + 'embed_dims': [F, 2*F, 4*F, 8*F], + 'motion_dims': [0, 0, 0, 8*F//depth[-1]], + 'depths': depth, + 'scales': [8], + 'hidden_dims': [4*F], + 'c': F, + 'num_key_points': num_key_points, + } + + +def _sgm_detect_variant(filename): + """Auto-detect SGM-VFI model variant from filename. + + Returns (F, depth). + Default is small (F=16) since the primary checkpoint (ours-1-2-points) + is a small model. Only detect base when "base" is in the filename. 
+ """ + name = filename.lower() + is_base = "base" in name + if is_base: + return 32, [2, 2, 2, 6] + else: + return 16, [2, 2, 2, 4] + + +class SGMVFIModel: + """Clean inference wrapper around SGM-VFI for ComfyUI integration.""" + + def __init__(self, checkpoint_path, variant="auto", num_key_points=0.5, tta=False, device="cpu"): + import os + filename = os.path.basename(checkpoint_path) + + if variant == "auto": + F_dim, depth = _sgm_detect_variant(filename) + elif variant == "small": + F_dim, depth = 16, [2, 2, 2, 4] + else: # base + F_dim, depth = 32, [2, 2, 2, 6] + + self.tta = tta + self.device = device + self.variant_name = "small" if F_dim == 16 else "base" + + backbone_cfg, multiscale_cfg = _sgm_init_model_config( + F=F_dim, depth=depth, num_key_points=num_key_points) + backbone = sgm_feature_extractor(**backbone_cfg) + self.model = SGMMultiScaleFlow(backbone, **multiscale_cfg) + self._load_checkpoint(checkpoint_path) + self.model.eval() + self.model.to(device) + + def _load_checkpoint(self, checkpoint_path): + """Load checkpoint with module prefix stripping and buffer filtering.""" + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + # Handle wrapped checkpoint formats + if isinstance(state_dict, dict): + if "model" in state_dict: + state_dict = state_dict["model"] + elif "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + # Strip "module." prefix and filter out attn_mask/HW buffers + cleaned = {} + for k, v in state_dict.items(): + if "attn_mask" in k or k.endswith(".HW"): + continue + key = k + if key.startswith("module."): + key = key[len("module."):] + cleaned[key] = v + + self.model.load_state_dict(cleaned, strict=False) + + def to(self, device): + """Move model to device (returns self for chaining).""" + self.device = device + self.model.to(device) + return self + + @torch.no_grad() + def _inference(self, img0, img1, timestep=0.5): + """Run single inference pass. 
Inputs already padded, on device.""" + B = img0.shape[0] + imgs = torch.cat((img0, img1), 1) + + if self.tta: + imgs_ = imgs.flip(2).flip(3) + input_batch = torch.cat((imgs, imgs_), 0) + _, _, _, preds, _ = self.model(input_batch, timestep=timestep) + return (preds[:B] + preds[B:].flip(2).flip(3)) / 2. + else: + _, _, _, pred, _ = self.model(imgs, timestep=timestep) + return pred + + @torch.no_grad() + def interpolate_pair(self, frame0, frame1, time_step=0.5): + """Interpolate a single frame between two input frames. + + Args: + frame0: [1, C, H, W] tensor, float32, range [0, 1] + frame1: [1, C, H, W] tensor, float32, range [0, 1] + time_step: float in (0, 1) + + Returns: + Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1] + """ + device = next(self.model.parameters()).device + img0 = frame0.to(device) + img1 = frame1.to(device) + + padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True) + img0, img1 = padder.pad(img0, img1) + + pred = self._inference(img0, img1, timestep=time_step) + pred = padder.unpad(pred) + return torch.clamp(pred, 0, 1) + + @torch.no_grad() + def interpolate_batch(self, frames0, frames1, time_step=0.5): + """Interpolate multiple frame pairs at once. 
+ + Args: + frames0: [B, C, H, W] tensor, float32, range [0, 1] + frames1: [B, C, H, W] tensor, float32, range [0, 1] + time_step: float in (0, 1) + + Returns: + Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1] + """ + device = next(self.model.parameters()).device + img0 = frames0.to(device) + img1 = frames1.to(device) + + padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True) + img0, img1 = padder.pad(img0, img1) + + pred = self._inference(img0, img1, timestep=time_step) + pred = padder.unpad(pred) + return torch.clamp(pred, 0, 1) diff --git a/nodes.py b/nodes.py index 4037250..2358a17 100644 --- a/nodes.py +++ b/nodes.py @@ -8,9 +8,10 @@ import torch import folder_paths from comfy.utils import ProgressBar -from .inference import BiMVFIModel, EMAVFIModel +from .inference import BiMVFIModel, EMAVFIModel, SGMVFIModel from .bim_vfi_arch import clear_backwarp_cache from .ema_vfi_arch import clear_warp_cache as clear_ema_warp_cache +from .sgm_vfi_arch import clear_warp_cache as clear_sgm_warp_cache logger = logging.getLogger("BIM-VFI") @@ -31,6 +32,14 @@ EMA_MODEL_DIR = os.path.join(folder_paths.models_dir, "ema-vfi") if not os.path.exists(EMA_MODEL_DIR): os.makedirs(EMA_MODEL_DIR, exist_ok=True) +# Google Drive folder ID for SGM-VFI pretrained models +SGM_GDRIVE_FOLDER_ID = "1S5O6W0a7XQDHgBtP9HnmoxYEzWBIzSJq" +SGM_DEFAULT_MODEL = "ours-1-2-points.pth" + +SGM_MODEL_DIR = os.path.join(folder_paths.models_dir, "sgm-vfi") +if not os.path.exists(SGM_MODEL_DIR): + os.makedirs(SGM_MODEL_DIR, exist_ok=True) + def get_available_models(): """List available checkpoint files in the bim-vfi model directory.""" @@ -767,3 +776,310 @@ class EMAVFISegmentInterpolate(EMAVFIInterpolate): result = result[1:] # skip duplicate boundary frame return (result, model) + + +# --------------------------------------------------------------------------- +# SGM-VFI nodes +# --------------------------------------------------------------------------- + +def 
get_available_sgm_models(): +    """List available checkpoint files in the sgm-vfi model directory.""" +    models = [] +    if os.path.isdir(SGM_MODEL_DIR): +        for f in os.listdir(SGM_MODEL_DIR): +            if f.endswith((".pkl", ".pth", ".pt", ".ckpt", ".safetensors")): +                models.append(f) +    if not models: +        models.append(SGM_DEFAULT_MODEL)  # Will trigger auto-download +    return sorted(models) + + +def download_sgm_model_from_gdrive(folder_id, dest_path): +    """Download SGM-VFI model from Google Drive folder using gdown.""" +    try: +        import gdown +    except ImportError: +        raise RuntimeError( +            "gdown is required to auto-download the SGM-VFI model. " +            "Install it with: pip install gdown" +        ) +    filename = os.path.basename(dest_path) +    url = f"https://drive.google.com/drive/folders/{folder_id}" +    logger.info(f"Downloading {filename} from Google Drive folder to {dest_path}...") +    os.makedirs(os.path.dirname(dest_path), exist_ok=True) +    gdown.download_folder(url, output=os.path.dirname(dest_path), quiet=False, remaining_ok=True) +    if not os.path.exists(dest_path): +        raise RuntimeError( +            f"Failed to download {filename}. Please download manually from " +            f"https://drive.google.com/drive/folders/{folder_id} " +            f"and place it in {os.path.dirname(dest_path)}" +        ) +    logger.info("Download complete.") + + +class LoadSGMVFIModel: +    @classmethod +    def INPUT_TYPES(cls): +        return { +            "required": { +                "model_path": (get_available_sgm_models(), { +                    "default": SGM_DEFAULT_MODEL, +                    "tooltip": "Checkpoint file from models/sgm-vfi/. Auto-downloads on first use if missing. " +                               "Variant (base/small) is auto-detected from filename.", +                }), +                "tta": ("BOOLEAN", { +                    "default": False, +                    "tooltip": "Test-time augmentation: flip input and average with unflipped result. " +                               "~2x slower but slightly better quality.", +                }), +                "num_key_points": ("FLOAT", { +                    "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05, +                    "tooltip": "Sparsity of global matching. 0.0 = global matching everywhere (slower, better for large motion). 
" + "Higher = sparser keypoints (faster). Default 0.5 is a good balance.", + }), + } + } + + RETURN_TYPES = ("SGM_VFI_MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_model" + CATEGORY = "video/SGM-VFI" + + def load_model(self, model_path, tta, num_key_points): + full_path = os.path.join(SGM_MODEL_DIR, model_path) + + if not os.path.exists(full_path): + logger.info(f"Model not found at {full_path}, attempting download...") + download_sgm_model_from_gdrive(SGM_GDRIVE_FOLDER_ID, full_path) + + wrapper = SGMVFIModel( + checkpoint_path=full_path, + variant="auto", + num_key_points=num_key_points, + tta=tta, + device="cpu", + ) + + logger.info(f"SGM-VFI model loaded (variant={wrapper.variant_name}, num_key_points={num_key_points}, tta={tta})") + return (wrapper,) + + +class SGMVFIInterpolate: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE", { + "tooltip": "Input image batch. Output frame count: 2x=(2N-1), 4x=(4N-3), 8x=(8N-7).", + }), + "model": ("SGM_VFI_MODEL", { + "tooltip": "SGM-VFI model from the Load SGM-VFI Model node.", + }), + "multiplier": ([2, 4, 8], { + "default": 2, + "tooltip": "Frame rate multiplier. 2x=one interpolation pass, 4x=two recursive passes, 8x=three. Higher = more frames but longer processing.", + }), + "clear_cache_after_n_frames": ("INT", { + "default": 10, "min": 1, "max": 100, "step": 1, + "tooltip": "Clear CUDA cache every N frame pairs to prevent VRAM buildup. Lower = less VRAM but slower. Ignored when all_on_gpu is enabled.", + }), + "keep_device": ("BOOLEAN", { + "default": True, + "tooltip": "Keep model on GPU between frame pairs. Faster but uses more VRAM constantly. Disable to free VRAM between pairs (slower due to CPU-GPU transfers).", + }), + "all_on_gpu": ("BOOLEAN", { + "default": False, + "tooltip": "Store all intermediate frames on GPU instead of CPU. Much faster (no transfers) but requires enough VRAM for all frames. 
Recommended for 48GB+ cards.", + }), + "batch_size": ("INT", { + "default": 1, "min": 1, "max": 64, "step": 1, + "tooltip": "Number of frame pairs to process simultaneously. Higher = faster but uses more VRAM. Start with 1, increase until VRAM is full.", + }), + "chunk_size": ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Process input frames in chunks of this size (0=disabled). Bounds VRAM usage during processing but the full output is still assembled in RAM. To bound RAM, use the Segment Interpolate node instead.", + }), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) + FUNCTION = "interpolate" + CATEGORY = "video/SGM-VFI" + + def _interpolate_frames(self, frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref): + """Run all interpolation passes on a chunk of frames.""" + for pass_idx in range(num_passes): + new_frames = [] + num_pairs = frames.shape[0] - 1 + pairs_since_clear = 0 + + for i in range(0, num_pairs, batch_size): + batch_end = min(i + batch_size, num_pairs) + actual_batch = batch_end - i + + frames0 = frames[i:batch_end] + frames1 = frames[i + 1:batch_end + 1] + + if not keep_device: + model.to(device) + + mids = model.interpolate_batch(frames0, frames1, time_step=0.5) + mids = mids.to(storage_device) + + if not keep_device: + model.to("cpu") + + for j in range(actual_batch): + new_frames.append(frames[i + j:i + j + 1]) + new_frames.append(mids[j:j+1]) + + step_ref[0] += actual_batch + pbar.update_absolute(step_ref[0]) + + pairs_since_clear += actual_batch + if not all_on_gpu and pairs_since_clear >= clear_cache_after_n_frames and torch.cuda.is_available(): + clear_sgm_warp_cache() + torch.cuda.empty_cache() + pairs_since_clear = 0 + + new_frames.append(frames[-1:]) + frames = torch.cat(new_frames, dim=0) + + if not all_on_gpu and torch.cuda.is_available(): + clear_sgm_warp_cache() + torch.cuda.empty_cache() + + return frames 
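The chunk splitting performed in `interpolate` below can be sanity-checked in isolation. Consecutive chunks share one boundary frame, and that duplicate first frame is dropped from every chunk after the first before concatenation, so the stitched output is seamless. A hypothetical standalone sketch mirroring that loop (the `build_chunks` name is illustrative only):

```python
def build_chunks(total_input, chunk_size):
    """Chunk boundaries with a 1-frame overlap, mirroring the loop in
    SGMVFIInterpolate.interpolate (standalone sketch for illustration)."""
    if chunk_size < 2 or chunk_size >= total_input:
        return [(0, total_input)]  # chunking disabled
    chunks = []
    start = 0
    while start < total_input - 1:
        end = min(start + chunk_size, total_input)
        chunks.append((start, end))
        start = end - 1  # overlap by 1 frame
        if end == total_input:
            break
    return chunks

# 10 input frames, chunks of 4: each chunk starts on the previous chunk's
# last frame, so dropping the first frame of chunks 1+ stitches seamlessly.
assert build_chunks(10, 4) == [(0, 4), (3, 7), (6, 10)]
assert build_chunks(10, 12) == [(0, 10)]
```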
+ + @staticmethod + def _count_steps(num_frames, num_passes): + """Count total interpolation steps for a given input frame count.""" + n = num_frames + total = 0 + for _ in range(num_passes): + total += n - 1 + n = 2 * n - 1 + return total + + def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size): + if images.shape[0] < 2: + return (images,) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + num_passes = {2: 1, 4: 2, 8: 3}[multiplier] + + if all_on_gpu: + keep_device = True + + storage_device = device if all_on_gpu else torch.device("cpu") + + # Convert from ComfyUI [B, H, W, C] to model [B, C, H, W] + all_frames = images.permute(0, 3, 1, 2).to(storage_device) + total_input = all_frames.shape[0] + + # Build chunk boundaries (1-frame overlap between consecutive chunks) + if chunk_size < 2 or chunk_size >= total_input: + chunks = [(0, total_input)] + else: + chunks = [] + start = 0 + while start < total_input - 1: + end = min(start + chunk_size, total_input) + chunks.append((start, end)) + start = end - 1 # overlap by 1 frame + if end == total_input: + break + + # Calculate total progress steps across all chunks + total_steps = sum(self._count_steps(ce - cs, num_passes) for cs, ce in chunks) + pbar = ProgressBar(total_steps) + step_ref = [0] + + if keep_device: + model.to(device) + + result_chunks = [] + for chunk_idx, (chunk_start, chunk_end) in enumerate(chunks): + chunk_frames = all_frames[chunk_start:chunk_end].clone() + + chunk_result = self._interpolate_frames( + chunk_frames, model, num_passes, batch_size, + device, storage_device, keep_device, all_on_gpu, + clear_cache_after_n_frames, pbar, step_ref, + ) + + # Skip first frame of subsequent chunks (duplicate of previous chunk's last frame) + if chunk_idx > 0: + chunk_result = chunk_result[1:] + + # Move completed chunk to CPU to bound memory when chunking + if len(chunks) > 1: + chunk_result = chunk_result.cpu() 
+ + result_chunks.append(chunk_result) + + result = torch.cat(result_chunks, dim=0) + # Convert back to ComfyUI [B, H, W, C], on CPU + result = result.cpu().permute(0, 2, 3, 1) + return (result,) + + +class SGMVFISegmentInterpolate(SGMVFIInterpolate): + """Process a numbered segment of the input batch for SGM-VFI. + + Chain multiple instances with Save nodes between them to bound peak RAM. + The model pass-through output forces sequential execution so each segment + saves and frees from RAM before the next starts. + """ + + @classmethod + def INPUT_TYPES(cls): + base = SGMVFIInterpolate.INPUT_TYPES() + base["required"]["segment_index"] = ("INT", { + "default": 0, "min": 0, "max": 10000, "step": 1, + "tooltip": "Which segment to process (0-based). Bounds RAM by only producing this segment's output frames, " + "unlike chunk_size which bounds VRAM but still assembles the full output in RAM. " + "Chain the model output to the next Segment Interpolate to force sequential execution.", + }) + base["required"]["segment_size"] = ("INT", { + "default": 500, "min": 2, "max": 10000, "step": 1, + "tooltip": "Number of input frames per segment. Adjacent segments overlap by 1 frame for seamless stitching. " + "Smaller = less peak RAM per segment. 
Save each segment's output to disk before the next runs.", + }) + return base + + RETURN_TYPES = ("IMAGE", "SGM_VFI_MODEL") + RETURN_NAMES = ("images", "model") + FUNCTION = "interpolate" + CATEGORY = "video/SGM-VFI" + + def interpolate(self, images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size, + segment_index, segment_size): + total_input = images.shape[0] + + # Compute segment boundaries (1-frame overlap) + start = segment_index * (segment_size - 1) + end = min(start + segment_size, total_input) + + if start >= total_input - 1: + # Past the end — return empty single frame + model + return (images[:1], model) + + segment_images = images[start:end] + is_continuation = segment_index > 0 + + # Delegate to the parent interpolation logic + (result,) = super().interpolate( + segment_images, model, multiplier, clear_cache_after_n_frames, + keep_device, all_on_gpu, batch_size, chunk_size, + ) + + if is_continuation: + result = result[1:] # skip duplicate boundary frame + + return (result, model) diff --git a/sgm_vfi_arch/__init__.py b/sgm_vfi_arch/__init__.py new file mode 100644 index 0000000..e872ec4 --- /dev/null +++ b/sgm_vfi_arch/__init__.py @@ -0,0 +1,5 @@ +from .feature_extractor import feature_extractor +from .flow_estimation import MultiScaleFlow +from .warplayer import clear_warp_cache + +__all__ = ['feature_extractor', 'MultiScaleFlow', 'clear_warp_cache'] diff --git a/sgm_vfi_arch/backbone.py b/sgm_vfi_arch/backbone.py new file mode 100644 index 0000000..35c2b2c --- /dev/null +++ b/sgm_vfi_arch/backbone.py @@ -0,0 +1,116 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = 
nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/sgm_vfi_arch/feature_extractor.py b/sgm_vfi_arch/feature_extractor.py new file mode 100644 index 0000000..eb1deac --- /dev/null +++ b/sgm_vfi_arch/feature_extractor.py @@ -0,0 +1,459 @@ +import torch +import torch.nn as nn +import math +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from .position import PositionEmbeddingSine + +def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + nwB, N, C = windows.shape + windows = windows.view(-1, window_size[0], window_size[1], C) + B = int(nwB / (H * W / window_size[0] / window_size[1])) + x = windows.view( + B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +def 
pad_if_needed(x, size, window_size): + n, h, w, c = size + pad_h = math.ceil(h / window_size[0]) * window_size[0] - h + pad_w = math.ceil(w / window_size[1]) * window_size[1] - w + if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes + img_mask = torch.zeros((1, h + pad_h, w + pad_w, 1)) # 1 H W 1 + h_slices = ( + slice(0, pad_h // 2), + slice(pad_h // 2, h + pad_h // 2), + slice(h + pad_h // 2, None), + ) + w_slices = ( + slice(0, pad_w // 2), + slice(pad_w // 2, w + pad_w // 2), + slice(w + pad_w // 2, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, window_size + ) # nW, window_size*window_size, 1 + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + return nn.functional.pad( + x, + (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), + ), attn_mask + return x, None + + +def depad_if_needed(x, size, window_size): + n, h, w, c = size + pad_h = math.ceil(h / window_size[0]) * window_size[0] - h + pad_w = math.ceil(w / window_size[1]) * window_size[1] - w + if pad_h > 0 or pad_w > 0: # remove the center-padding on feature + return x[:, pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w, :].contiguous() + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.relu = nn.ReLU(inplace=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if 
isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class InterFrameAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x1, x2, H, W, mask=None): + B, N, C = x1.shape + q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 
0, 3, 1, 4) + k, v = kv[0], kv[1] + attn = (q @ k.transpose(-2, -1)) * self.scale + + if mask is not None: + nW = mask.shape[0] # mask: nW, N, N + attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = attn.softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MotionFormerBlock(nn.Module): + def __init__(self, dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True, + qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): + super().__init__() + self.window_size = window_size + if not isinstance(self.window_size, (tuple, list)): + self.window_size = to_2tuple(window_size) + self.shift_size = shift_size + if not isinstance(self.shift_size, (tuple, list)): + self.shift_size = to_2tuple(shift_size) + self.bidirectional = bidirectional + self.norm1 = norm_layer(dim) + self.attn = InterFrameAttention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # BEGIN: absolute pos_embed, beneficial to local information extraction in our experiments + self.pos_embed = PositionEmbeddingSine(dim // 2) + # END: absolute pos_embed + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W, B, self_att=False): + x = x.view(2 * B, H, W, -1) + x_pad, mask = pad_if_needed(x, x.size(), self.window_size) + + if self.shift_size[0] or self.shift_size[1]: + _, H_p, W_p, C = x_pad.shape + x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + + if hasattr(self, 'HW') and self.HW.item() == H_p * W_p: + shift_mask = self.attn_mask + else: + shift_mask = torch.zeros((1, H_p, W_p, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + shift_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1) + shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + shift_mask = shift_mask.masked_fill(shift_mask != 0, + float(-100.0)).masked_fill(shift_mask == 0, + float(0.0)) + + if 
mask is not None: + shift_mask = shift_mask.masked_fill(mask != 0, + float(-100.0)) + self.register_buffer("attn_mask", shift_mask) + self.register_buffer("HW", torch.Tensor([H_p * W_p])) + else: + shift_mask = mask + + if shift_mask is not None: + shift_mask = shift_mask.to(x_pad.device) + + _, Hw, Ww, C = x_pad.shape + x_win = window_partition(x_pad, self.window_size) + + nwB = x_win.shape[0] + x_norm = self.norm1(x_win) + # BEGIN: absolute pos embed, beneficial to local information extraction in our experiments + x_norm = x_norm.view(nwB, self.window_size[0], self.window_size[1], C).permute(0, 3, 1, 2) + ape = self.pos_embed(x_norm) + x_norm = x_norm + ape + x_norm = x_norm.permute(0, 2, 3, 1).view(nwB, self.window_size[0] * self.window_size[1], C) + # END: absolute pos embed + + if self_att is False: + x_reverse = torch.cat([x_norm[nwB // 2:], x_norm[:nwB // 2]]) + x_appearence = self.attn(x_norm, x_reverse, H, W, shift_mask) + else: + x_appearence = self.attn(x_norm, x_norm, H, W, shift_mask) + + x_norm = x_norm + self.drop_path(x_appearence) + + x_back = x_norm + x_back_win = window_reverse(x_back, self.window_size, Hw, Ww) + + if self.shift_size[0] or self.shift_size[1]: + x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) + + x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2 * B, H * W, -1) + + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + return x + + +class ConvBlock(nn.Module): + def __init__(self, in_dim, out_dim, depths=2, act_layer=nn.PReLU): + super().__init__() + layers = [] + for i in range(depths): + if i == 0: + layers.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1)) + else: + layers.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1)) + layers.extend([ + act_layer(out_dim), + ]) + self.conv = nn.Sequential(*layers) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + 
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.conv(x) + return x + + +class OverlapPatchEmbed(nn.Module): + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = to_2tuple(patch_size) + + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class MotionFormer(nn.Module): + def __init__(self, in_chans=3, embed_dims=None, num_heads=None, + mlp_ratios=None, qkv_bias=True, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=None, window_sizes=None, **kwarg): + super().__init__() + self.depths = depths + self.num_stages = len(embed_dims) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + + self.conv_stages = self.num_stages - len(num_heads) + + for i in range(self.num_stages): + if i == 0: + block = ConvBlock(in_chans, embed_dims[i], depths[i]) + else: + if i < self.conv_stages: + patch_embed = nn.Sequential( + nn.Conv2d(embed_dims[i - 1], embed_dims[i], 3, 2, 1), + nn.PReLU(embed_dims[i]) + ) 
+ block = ConvBlock(embed_dims[i], embed_dims[i], depths[i]) + else: + patch_embed = OverlapPatchEmbed(patch_size=3, + stride=2, + in_chans=embed_dims[i - 1], + embed_dim=embed_dims[i]) + + block = nn.ModuleList([MotionFormerBlock( + dim=embed_dims[i], num_heads=num_heads[i - self.conv_stages], + window_size=window_sizes[i - self.conv_stages], + shift_size=0 if (j % 2) == 0 else window_sizes[i - self.conv_stages] // 2, + mlp_ratio=mlp_ratios[i - self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer) + for j in range(depths[i])]) + + norm = norm_layer(embed_dims[i]) + setattr(self, f"norm{i + 1}", norm) + setattr(self, f"patch_embed{i + 1}", patch_embed) + cur += depths[i] + + setattr(self, f"block{i + 1}", block) + + self.cor = {} + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def get_cor(self, shape, device): + k = (str(shape), str(device)) + if k not in self.cor: + tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view( + 1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1) + tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view( + 1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1) + self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device) + return self.cor[k] + + def forward(self, x1, x2): + B = x1.shape[0] + x = torch.cat([x1, x2], 0) + appearence_features = [] + xs = [] + for i in 
range(self.num_stages): + patch_embed = getattr(self, f"patch_embed{i + 1}", None) + block = getattr(self, f"block{i + 1}", None) + norm = getattr(self, f"norm{i + 1}", None) + if i < self.conv_stages: + if i > 0: + x = patch_embed(x) + x = block(x) + xs.append(x) + else: + x, H, W = patch_embed(x) + for j in range(len(block)): + x = block[j](x, H, W, B, self_att=False) + xs.append(x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous()) + x = norm(x) + x = x.reshape(2 * B, H, W, -1).permute(0, 3, 1, 2).contiguous() + appearence_features.append(x) + return appearence_features + + +class DWConv(nn.Module): + def __init__(self, dim): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + x = self.dwconv(x) + x = x.reshape(B, C, -1).transpose(1, 2) + + return x + + +def feature_extractor(**kargs): + model = MotionFormer(**kargs) + return model diff --git a/sgm_vfi_arch/flow_estimation.py b/sgm_vfi_arch/flow_estimation.py new file mode 100644 index 0000000..0a3e936 --- /dev/null +++ b/sgm_vfi_arch/flow_estimation.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .refine import * +from .matching import MatchingBlock +from .gmflow import GMFlow +from .utils import InputPadder + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, layers=4, scale=4, in_else=17): + super(IFBlock, self).__init__() + self.scale = scale + + self.conv0 = nn.Sequential( + conv(in_planes + in_else, c, 3, 1, 1), + conv(c, c, 3, 1, 1), + ) + + self.convblock = nn.Sequential( + *[conv(c, c) for _ in range(layers)] + ) + + 
+        self.lastconv = conv(c, 5)
+
+    def forward(self, x, flow=None, feature=None):
+        if self.scale != 1:
+            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False)
+        if flow is not None:
+            flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear",
+                                 align_corners=False) * 1. / self.scale
+            x = torch.cat((x, flow), 1)
+        if feature is not None:
+            x = torch.cat((x, feature), 1)
+        x = self.conv0(x)
+        x = self.convblock(x) + x
+        tmp = self.lastconv(x)
+        flow_s = tmp[:, :4]
+        tmp = F.interpolate(tmp, scale_factor=self.scale, mode="bilinear", align_corners=False)
+        flow = tmp[:, :4] * self.scale
+        mask = tmp[:, 4:5]
+        return flow, mask, flow_s
+
+
+class MultiScaleFlow(nn.Module):
+    def __init__(self, backbone, **kargs):
+        super(MultiScaleFlow, self).__init__()
+        self.flow_num_stage = len(kargs['hidden_dims'])
+        self.feature_bone = backbone
+        self.scale = [1, 2, 4, 8]
+        self.num_key_points = [kargs['num_key_points']]
+        self.block = nn.ModuleList(
+            [IFBlock(kargs['embed_dims'][-1] * 2, 128, 2, self.scale[-1], in_else=7),   # 1/8
+             IFBlock(kargs['embed_dims'][-2] * 2, 128, 2, self.scale[-2], in_else=18)])  # 1/4
+        self.contextnet = Contextnet(kargs['c'] * 2)
+        self.unet = Unet(kargs['c'] * 2)
+        self.gmflow = GMFlow(
+            num_scales=1,
+            upsample_factor=8,
+            feature_channels=128,
+            attention_type='swin',
+            num_transformer_layers=6,
+            ffn_dim_expansion=4,
+            num_head=1)
+
+        self.matching_block = nn.ModuleList([
+            MatchingBlock(scale=8, dim=kargs['embed_dims'][-1], c=kargs['c'] * 4, num_layers=1, gm=True),
+            None
+        ])
+
+        self.padding_factor = 16
+
+    def calculate_flow(self, imgs, timestep):
+        img0, img1 = imgs[:, :3], imgs[:, 3:6]
+        B = img0.size(0)
+        flow, mask = None, None
+        flow_s = None
+
+        af = self.feature_bone(img0, img1)
+        if self.gmflow is not None:
+            padder = InputPadder(img0.shape, padding_factor=self.padding_factor)
+            img0_p, img1_p = padder.pad(img0, img1)
+            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1],
+                                  pred_bidir_flow=False)
+            matching_feat = results['trans_feat']
+            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1])
+            matching_feat[0] = padder_8.unpad(matching_feat[0])
+
+        for i in range(2):
+            t = (img0[:B, :1].clone() * 0 + 1) * timestep
+            af0 = af[-1 - i][:B]
+            af1 = af[-1 - i][B:]
+            if flow is not None:
+                flow_d, mask_d, flow_s_d = self.block[i](
+                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
+                    flow,
+                    torch.cat([af0, af1], 1),
+                )
+                flow = flow + flow_d
+                mask = mask + mask_d
+                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
+                flow_s = flow_s + flow_s_d
+            else:
+                flow, mask, flow_s = self.block[i](
+                    torch.cat((img0, img1, t), 1),
+                    None,
+                    torch.cat([af0, af1], 1))
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+            if self.matching_block[i] is not None:
+                match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i],
+                                                   init_flow=flow, init_flow_s=flow_s, init_mask=mask,
+                                                   warped_img0=warped_img0, warped_img1=warped_img1,
+                                                   num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
+                                                   timestep=timestep)
+                flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
+                flow = flow + flow_t
+                mask = mask + mask_t
+
+        warped_img0 = warp(img0, flow[:, :2])
+        warped_img1 = warp(img1, flow[:, 2:4])
+        return flow, mask
+
+    def coraseWarp_and_Refine(self, imgs, flow, mask):
+        img0, img1 = imgs[:, :3], imgs[:, 3:6]
+        warped_img0 = warp(img0, flow[:, :2])
+        warped_img1 = warp(img1, flow[:, 2:4])
+        c0 = self.contextnet(img0, flow[:, :2])
+        c1 = self.contextnet(img1, flow[:, 2:4])
+        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+        res = tmp[:, :3] * 2 - 1
+        mask_ = torch.sigmoid(mask)
+        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
+        pred = torch.clamp(merged + res, 0, 1)
+        return pred
+
+    def forward(self, x, timestep=0.5):
+        img0, img1 = x[:, :3], x[:, 3:6]
+        B = x.size(0)
+        flow_list, mask_list = [], []
+        merged, merged_fine = [], []
+        warped_img0, warped_img1 = img0, img1
+        flow, mask, flow_s = None, None, None
+        flow_matching_list = []
+        matching_feat = []
+        af = self.feature_bone(img0, img1)
+        if self.gmflow is not None:
+            padder = InputPadder(img0.shape, padding_factor=self.padding_factor, additional_pad=False)
+            img0_p, img1_p = padder.pad(img0, img1)
+            results = self.gmflow(img0_p, img1_p, attn_splits_list=[1], pred_bidir_flow=False)
+            matching_feat = results['trans_feat']
+            padder_8 = InputPadder(af[-1].shape, padding_factor=self.padding_factor // self.scale[-1], additional_pad=False)
+            matching_feat[0] = padder_8.unpad(matching_feat[0])
+
+        for i in range(2):
+            af0 = af[-1 - i][:B]
+            af1 = af[-1 - i][B:]
+            t = (img0[:B, :1].clone() * 0 + 1) * timestep
+            if flow is not None:
+                flow_d, mask_d, flow_s_d = self.block[i](
+                    torch.cat((img0, img1, warped_img0, warped_img1, mask, t), 1),
+                    flow,
+                    torch.cat([af0, af1], 1),
+                )
+                flow = flow + flow_d
+                mask = mask + mask_d
+                flow_s = F.interpolate(flow_s, scale_factor=2, mode="bilinear", align_corners=False) * 2
+                flow_s = flow_s + flow_s_d
+            else:
+                flow, mask, flow_s = self.block[i](
+                    torch.cat((img0, img1, t), 1),
+                    None,
+                    torch.cat([af0, af1], 1))
+            mask_list.append(torch.sigmoid(mask))
+            flow_list.append(flow)
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
+            if self.matching_block[i] is not None:
+                match_out = self.matching_block[i](img0=img0, img1=img1, x=matching_feat[i], main_x=af[-1 - i].detach(),
+                                                   init_flow=flow.detach(), init_flow_s=flow_s.detach(), init_mask=mask.detach(),
+                                                   warped_img0=warped_img0.detach(), warped_img1=warped_img1.detach(),
+                                                   num_key_points=self.num_key_points[i], scale_factor=self.scale[-1 - i],
+                                                   timestep=0.5)
+                flow_t, mask_t = match_out['flow_t'], match_out['mask_t']
+                flow = flow + flow_t
+                mask = mask + mask_t
+                mask_list[i] = torch.sigmoid(mask)
+                warped_img0_fine = warp(img0, flow[:, 0:2])
+                warped_img1_fine = warp(img1, flow[:, 2:4])
+                merged_fine.append(warped_img0_fine * mask_list[i] + warped_img1_fine * (1 - mask_list[i]))
+                warped_img0, warped_img1 = warped_img0_fine, warped_img1_fine  # NOTE: for next iteration training
+        c0 = self.contextnet(img0, flow[:, :2])
+        c1 = self.contextnet(img1, flow[:, 2:4])
+        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+        res = tmp[:, :3] * 2 - 1
+        pred = torch.clamp(merged[-1] + res, 0, 1)
+        merged.extend(merged_fine)
+        return flow_list, mask_list, merged, pred, flow_matching_list
diff --git a/sgm_vfi_arch/geometry.py b/sgm_vfi_arch/geometry.py
new file mode 100644
index 0000000..207e98f
--- /dev/null
+++ b/sgm_vfi_arch/geometry.py
@@ -0,0 +1,96 @@
+import torch
+import torch.nn.functional as F
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')  # [H, W]
+
+    stacks = [x, y]
+
+    if homogeneous:
+        ones = torch.ones_like(x)  # [H, W]
+        stacks.append(ones)
+
+    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
+
+    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
+
+    if device is not None:
+        grid = grid.to(device)
+
+    return grid
+
+
+def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
+    assert device is not None
+
+    x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
+                           torch.linspace(h_min, h_max, len_h, device=device)],
+                          indexing='ij')
+    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]
+
+    return grid
+
+
+def normalize_coords(coords, h, w):
+    # coords: [B, H, W, 2]
+    c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
+    return (coords - c) / c  # [-1, 1]
+
+
+def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
+    # img: [B, C, H, W]
+    # sample_coords: [B, 2, H, W] in image scale
+    if sample_coords.size(1) != 2:  # [B, H, W,
2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ diff --git a/sgm_vfi_arch/gmflow.py b/sgm_vfi_arch/gmflow.py new file mode 100644 index 0000000..2bec572 --- /dev/null +++ b/sgm_vfi_arch/gmflow.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer, FeatureFlowAttention +from .utils import 
feature_add_position
+
+
+class GMFlow(nn.Module):
+    def __init__(self,
+                 num_scales=1,
+                 upsample_factor=8,
+                 feature_channels=128,
+                 attention_type='swin',
+                 num_transformer_layers=6,
+                 ffn_dim_expansion=4,
+                 num_head=1,
+                 **kwargs,
+                 ):
+        super(GMFlow, self).__init__()
+
+        self.num_scales = num_scales
+        self.feature_channels = feature_channels
+        self.upsample_factor = upsample_factor
+        self.attention_type = attention_type
+        self.num_transformer_layers = num_transformer_layers
+
+        # CNN backbone
+        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
+
+        # Transformer
+        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
+                                              d_model=feature_channels,
+                                              nhead=num_head,
+                                              attention_type=attention_type,
+                                              ffn_dim_expansion=ffn_dim_expansion,
+                                              )
+
+    def extract_feature(self, img0, img1):
+        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
+        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low
+
+        # reverse: resolution from low to high
+        features = features[::-1]
+
+        feature0, feature1 = [], []
+
+        for i in range(len(features)):
+            feature = features[i]
+            chunks = torch.chunk(feature, 2, 0)  # tuple
+            feature0.append(chunks[0])
+            feature1.append(chunks[1])
+
+        return feature0, feature1
+
+    def forward(self, img0, img1,
+                attn_splits_list=None,
+                corr_radius_list=None,
+                prop_radius_list=None,
+                pred_bidir_flow=False,
+                **kwargs,
+                ):
+        # Inference-only path: only the transformer features are returned, so the
+        # unused training-time flow accumulators from the original code were dropped.
+        results_dict = {}
+        transformer_features = []
+
+        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features
+
+        for scale_idx in range(self.num_scales):
+            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
+
+            attn_splits = attn_splits_list[scale_idx]
+
+            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
+
+            # Transformer
+            feature0, feature1 = self.transformer(feature0, feature1,
attn_num_splits=attn_splits) + transformer_features.append(torch.cat([feature0, feature1], 0)) + + results_dict.update({'trans_feat': transformer_features}) + + return results_dict diff --git a/sgm_vfi_arch/matching.py b/sgm_vfi_arch/matching.py new file mode 100644 index 0000000..fb17099 --- /dev/null +++ b/sgm_vfi_arch/matching.py @@ -0,0 +1,278 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .warplayer import warp as backwarp +from .softsplat import softsplat +from .geometry import coords_grid + + +# for random sample ablation +def random_sample(feature, num_points=256): + rand_ind = torch.randint(low=0, high=feature.shape[1], size=(feature.shape[0], num_points)).unsqueeze(-1).to( + feature.device) + kp = torch.gather(feature, dim=1, index=rand_ind.expand(-1, -1, feature.shape[2])) + return rand_ind, kp + +def sample_key_points(importance_map, feature, num_points=256): + importance_map = importance_map.view(-1, 1, importance_map.shape[2] * importance_map.shape[3]).permute(0, 2, 1) + _, kp_ind = torch.topk(importance_map, num_points, dim=1) + kp = torch.gather(feature, dim=1, index=kp_ind.expand(-1, -1, feature.shape[2])) + return kp_ind, kp + + +def forward_warp(tenIn, tenFlow, z=None): + if z is None: + z = torch.ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]).to(tenIn.device) + else: + z = torch.where(z == 0, -20, 1) + out = softsplat(tenIn, tenFlow, tenMetric=z, strMode='soft') + return out + + +def warp_twice(imgA, target, flow_tA, flow_tB): + It_warp = backwarp(imgA, flow_tA) # backward warp(I1,Ft1) + z = torch.ones([imgA.shape[0], 1, imgA.shape[2], imgA.shape[3]]).to(imgA.device) + IB_warp = softsplat(tenIn=It_warp, tenFlow=flow_tB, tenMetric=z, strMode='soft') + return IB_warp + + +def build_map(imgA, imgB, flow_tA, flow_tB): + # build map for img B + IB_warp = warp_twice(imgA, imgB, flow_tA, flow_tB) + difference_map = IB_warp - imgB # [B, 3, H, W], difference map on IB + difference_map = 
torch.sum(torch.abs(difference_map), dim=1, keepdim=True) # B, 1, H, W + return difference_map + + +def build_hole_mask(img_template, flow_tA, flow_tB): + # build hole mask + with torch.no_grad(): + ones = torch.ones(img_template.shape[0], 1, img_template.shape[2], img_template.shape[3]).to( + img_template.device) + out = warp_twice(ones, ones, flow_tA, flow_tB) + hole_mask = torch.where(out == 0, 0, 1) + return hole_mask + + +def gen_importance_map(img0, img1, flow): + I1_dmap = build_map(img0, img1, flow[:, 0:2], flow[:, 2:4]) + I0_dmap = build_map(img1, img0, flow[:, 2:4], flow[:, 0:2]) + + I1_hole_mask = build_hole_mask(img0, flow[:, 0:2], flow[:, 2:4]) + I0_hole_mask = build_hole_mask(img1, flow[:, 2:4], flow[:, 0:2]) + + I1_dmap = I1_dmap * I1_hole_mask + I0_dmap = I0_dmap * I0_hole_mask + + I0_prob = warp_twice(I1_dmap, I1_dmap, flow[:, 2:4], flow[:, 0:2]) + I1_prob = warp_twice(I0_dmap, I0_dmap, flow[:, 0:2], flow[:, 2:4]) + + importance_map = torch.cat([I0_prob, I1_prob], dim=0) # 2B, 1, H, W + return importance_map + + +def global_matching(key_feature, global_feature, key_index, H, W): + b, n, c = global_feature.shape + query = key_feature + key = global_feature + correlation = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, k, H*W] + + prob = F.softmax(correlation, dim=-1) + init_grid = coords_grid(b, H, W, homogeneous=False, device=global_feature.device) + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + out = torch.matmul(prob, grid) # B, k, 2 + if key_index is not None: + flow_fix = torch.zeros_like(grid) + # key_index: [B, K, 1], out: [B, K, 2], flow_fix: [B, H*W, 2] + flow_fix = torch.scatter(flow_fix, dim=1, index=key_index.expand(-1, -1, 2), src=out) + flow_fix = flow_fix.view(b, H, W, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # for grid, points in grid and not in key_index, set to 0 + grid_new = torch.zeros_like(grid) + key_pos = torch.ones_like(out) + grid_new = torch.scatter(grid_new, dim=1, 
index=key_index.expand(-1, -1, 2), src=key_pos) + grid = (grid * grid_new).reshape(b, H, W, 2).permute(0, 3, 1, 2) + flow_fix = flow_fix - grid + else: + flow_fix = out.view(b, H, W, 2).permute(0, 3, 1, 2) + flow_fix = flow_fix - init_grid + return flow_fix, prob + + +def extract_topk(foo, k): + b, _, h, w = foo.shape + foo = foo.view(b, 1, h * w).permute(0, 2, 1) + kp, kp_ind = torch.topk(foo, k, dim=1) + grid = torch.zeros(b, h * w, 1).to(foo.device) + out = torch.scatter(grid, dim=1, index=kp_ind, src=kp) + out = out.permute(0, 2, 1).reshape(b, 1, h, w) + return out + + +def flow_shift(flow_fix, timestep, num_key_points=None, select_topk=False): + B = flow_fix.shape[0] // 2 + z = torch.where(flow_fix == 0, 0, 1).detach().sum(1, keepdim=True) / 2 + zt0, zt1 = z[B:], z[:B] + flow_fix_t0 = forward_warp(flow_fix[B:] * timestep, flow_fix[B:] * (1 - timestep), z=zt0) + flow_fix_t1 = forward_warp(flow_fix[:B] * (1 - timestep), flow_fix[:B] * timestep, z=zt1) + flow_fix_t = torch.cat([flow_fix_t0, flow_fix_t1], 0) + if select_topk and num_key_points != -1: + warp_map_t0 = softsplat(zt0, flow_fix[B:] * (1 - timestep), None, 'sum') + warp_map_t1 = softsplat(zt1, flow_fix[:B] * timestep, None, 'sum') + + warp_map = torch.cat([warp_map_t0, warp_map_t1], 0) + warp_map_topk = extract_topk(warp_map, num_key_points) + warp_map_topk = torch.where(warp_map_topk != 0, 1, 0) + flow_fix_t = flow_fix_t * warp_map_topk + return flow_fix_t + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +def deconv(in_planes=64, out_planes=64, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, bias=True), + nn.PReLU(out_planes) + ) + +class FlowRefine(nn.Module): + def 
__init__(self, in_planes, scale=4, c=64, n_layers=8): + super(FlowRefine, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c, 3, 1, 1), + conv(c, c, 3, 1, 1), + ) + self.convblock = nn.Sequential( + *[conv(c, c) for _ in range(n_layers)] + ) + self.lastconv = conv(c, 5) + self.scale = scale + + def forward(self, x, flow_s, flow): + if self.scale != 1: + x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor=1. / self.scale, mode="bilinear", + align_corners=False) * 1. / self.scale + x = torch.cat((x, flow), 1) + if flow_s is not None: + x = torch.cat((x, flow_s), 1) + x = self.conv0(x) + x = self.convblock(x) + x + x = self.lastconv(x) + tmp = F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * self.scale + mask = tmp[:, 4:5] + return flow, mask + + +class MergingBlock(nn.Module): + def __init__(self, radius=3, input_dim=256, hidden_dim=256): + super(MergingBlock, self).__init__() + self.r = radius + self.rf = radius ** 2 + self.conv = nn.Sequential(nn.Conv2d(8 + 2*input_dim, hidden_dim, 3, 1, 1), + nn.PReLU(hidden_dim), + nn.Conv2d(hidden_dim, 2*2*self.rf, 1, 1, 0)) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(0.1 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, feature, init_flow, flow_fix): + """ + :param feature: [B, C, H, W] -> (local feature) or (local feature + matching feature) + :param init_flow: [B, 2, H, W] -> (local init flow) + :param flow_fix: [B, 2, H, W] -> (matching output, flow_fix (after patching, no hollows)) + """ + b, flow_channel, h, w = init_flow.shape + concat = torch.cat((init_flow, flow_fix, feature), dim=1) + mask = self.conv(concat) + assert init_flow.shape == 
flow_fix.shape, f"different flow shape not implemented yet" + mask = mask.view(b, 1, 2 * 2 * self.rf, h, w) + mask0 = mask[:, :, :2 * self.rf, :, :] + mask1 = mask[:, :, 2 * self.rf:, :, :] + mask = torch.cat([mask0, mask1], dim=0) + mask = torch.softmax(mask, dim=2) + + init_flow_all = torch.cat([init_flow[:, 0:2], init_flow[:, 2:4]], dim=0) + flow_fix_all = torch.cat([flow_fix[:, 0:2], flow_fix[:, 2:4]], dim=0) + + init_flow_grid = F.unfold(init_flow_all, [self.r, self.r], padding=self.r//2) + init_flow_grid = init_flow_grid.view(2*b, 2, self.rf, h, w) # [B, 2, 9, H, W] + flow_fix_grid = F.unfold(flow_fix_all, [self.r, self.r], padding=self.r//2) + flow_fix_grid = flow_fix_grid.view(2*b, 2, self.rf, h, w) # [B, 2, 9, H, W] + + flow_grid = torch.cat([init_flow_grid, flow_fix_grid], dim=2) # [B, 2, 2*9, H, W] + + merge_flow = torch.sum(mask * flow_grid, dim=2) # [B, 2, H, W] + return merge_flow + + +class MatchingBlock(nn.Module): + def __init__(self, scale, c, dim, num_layers=2, gm=True): + super(MatchingBlock, self).__init__() + self.gm = gm + self.dim = dim + self.scale = scale + self.merge = MergingBlock(radius=3, input_dim=dim+128, hidden_dim=256) + self.refine_block = FlowRefine(27, scale, c, num_layers) + + def forward(self, img0, img1, x, main_x, init_flow, init_flow_s, init_mask, + warped_img0, warped_img1, num_key_points, scale_factor, timestep=0.5): + result_dict = {} + + _, c, h, w = x.shape + B = main_x.shape[0] // 2 + # NOTE: + # 1. we stop sparse selecting points when the image resolution + # becomes too small (1/8 feature map resolution <= 32, i.e., h <= 256) + # (see `random_rescale` in train_x4k.py) + # 2. 
This limitation should be deleted when evaluating on low-resolution images (<=256x256) + if num_key_points != -1 and h > 32: + num_key_points = int(num_key_points * (h * w)) + else: + num_key_points = -1 # -1 stands for global matching + + feature = x.permute(0, 2, 3, 1).reshape(2 * B, h*w, c) + feature_reverse = torch.cat([feature[B:], feature[:B]], 0) + + if num_key_points == -1: + flow_fix_norm, _ = global_matching(feature, feature_reverse, None, h, w) + else: + imap = gen_importance_map(img0, img1, init_flow) + imap_s = F.interpolate(imap, size=(h, w), mode="bilinear", align_corners=False) + kp_ind, kp_feature = sample_key_points(imap_s, feature, num_key_points) + flow_fix_norm, _ = global_matching(kp_feature, feature_reverse, kp_ind, h, w) + + flow_fix = flow_shift(flow_fix_norm, timestep, num_key_points, select_topk=True) + flow_fix = torch.cat([flow_fix[:B], flow_fix[B:]], 1) + flow_r = torch.where(flow_fix == 0, init_flow_s, flow_fix) + flow_merge = self.merge(torch.cat([x[:B], x[B:], main_x[:B], main_x[B:]], dim=1), init_flow_s, flow_r) + flow_merge = torch.cat([flow_merge[:B], flow_merge[B:]], dim=1) + img0_s = F.interpolate(img0, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False) + img1_s = F.interpolate(img1, scale_factor=1 / scale_factor, mode="bilinear", align_corners=False) + warped_img0_fine_s_m = backwarp(img0_s, flow_merge[:, 0:2]) + warped_img1_fine_s_m = backwarp(img1_s, flow_merge[:, 2:4]) + + flow_t, mask_t = self.refine_block(torch.cat((img0, img1, warped_img0, warped_img1, init_mask), 1), + torch.cat([warped_img0_fine_s_m, warped_img1_fine_s_m, flow_merge], 1), + init_flow) + + result_dict.update({'flow_t': flow_t}) + result_dict.update({'mask_t': mask_t}) + return result_dict diff --git a/sgm_vfi_arch/position.py b/sgm_vfi_arch/position.py new file mode 100644 index 0000000..14a6da4 --- /dev/null +++ b/sgm_vfi_arch/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/sgm_vfi_arch/refine.py b/sgm_vfi_arch/refine.py new file mode 100644 index 0000000..6dd5dd5 --- /dev/null +++ b/sgm_vfi_arch/refine.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from 
timm.models.layers import trunc_normal_ +from .warplayer import warp + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + nn.PReLU(out_planes) + ) + +class Conv2(nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + +class Contextnet(nn.Module): + def __init__(self, c=16): + super(Contextnet, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x, flow): + x = self.conv1(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, + recompute_scale_factor=False) * 0.5 + f1 = warp(x, flow) + x = self.conv2(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, + recompute_scale_factor=False) * 0.5 + f2 = warp(x, flow) + x = self.conv3(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, + recompute_scale_factor=False) * 0.5 + f3 = warp(x, flow) + x = self.conv4(x) + flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, + recompute_scale_factor=False) * 0.5 + f4 = warp(x, flow) + return [f1, f2, f3, f4] + +class Unet(nn.Module): + def __init__(self, c=16, out=3): + super(Unet, self).__init__() + self.down0 = Conv2(17, 2*c) + self.down1 = Conv2(4*c, 4*c) + self.down2 = Conv2(8*c, 
8*c) + self.down3 = Conv2(16*c, 16*c) + self.up0 = deconv(32*c, 8*c) + self.up1 = deconv(16*c, 4*c) + self.up2 = deconv(8*c, 2*c) + self.up3 = deconv(4*c, c) + self.conv = nn.Conv2d(c, out, 3, 1, 1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): + s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) + s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) + s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) + s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) + x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) + x = self.up1(torch.cat((x, s2), 1)) + x = self.up2(torch.cat((x, s1), 1)) + x = self.up3(torch.cat((x, s0), 1)) + x = self.conv(x) + return torch.sigmoid(x) diff --git a/sgm_vfi_arch/softsplat.py b/sgm_vfi_arch/softsplat.py new file mode 100644 index 0000000..eeccb88 --- /dev/null +++ b/sgm_vfi_arch/softsplat.py @@ -0,0 +1,530 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable 
in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + 
print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + 
strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + # end + + return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'], + options=tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 
'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = 
intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def 
backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} 
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const 
{{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * 
(({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/sgm_vfi_arch/transformer.py b/sgm_vfi_arch/transformer.py new file mode 100644 index 0000000..82f4426 --- /dev/null +++ b/sgm_vfi_arch/transformer.py @@ -0,0 +1,450 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + 
+def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0, f'h: {h}, w: {w}, num_splits: {num_splits}' + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=None): 
+ # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h_slice in h_slices: + for w_slice in w_slices: + img_mask[:, h_slice, w_slice, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, 
                      num_splits=num_splits, channel_last=True)
+    v = split_feature(v, num_splits=num_splits, channel_last=True)
+
+    scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
+                          ) / scale_factor  # [B*K*K, H/K*W/K, H/K*W/K]
+
+    if with_shift:
+        scores += attn_mask.repeat(b, 1, 1)
+
+    attn = torch.softmax(scores, dim=-1)
+
+    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]
+
+    out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
+                       num_splits=num_splits, channel_last=True)  # [B, H, W, C]
+
+    # shift back
+    if with_shift:
+        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
+
+    out = out.view(b, -1, c)
+
+    return out
+
+
+class TransformerLayer(nn.Module):
+    def __init__(self,
+                 d_model=256,
+                 nhead=1,
+                 attention_type='swin',
+                 no_ffn=False,
+                 ffn_dim_expansion=4,
+                 with_shift=False,
+                 **kwargs,
+                 ):
+        super(TransformerLayer, self).__init__()
+
+        self.dim = d_model
+        self.nhead = nhead
+        self.attention_type = attention_type
+        self.no_ffn = no_ffn
+
+        self.with_shift = with_shift
+
+        # multi-head attention
+        self.q_proj = nn.Linear(d_model, d_model, bias=False)
+        self.k_proj = nn.Linear(d_model, d_model, bias=False)
+        self.v_proj = nn.Linear(d_model, d_model, bias=False)
+
+        self.merge = nn.Linear(d_model, d_model, bias=False)
+
+        self.norm1 = nn.LayerNorm(d_model)
+
+        # no ffn after self-attn, with ffn after cross-attn
+        if not self.no_ffn:
+            in_channels = d_model * 2
+            self.mlp = nn.Sequential(
+                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
+                nn.GELU(),
+                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
+            )
+
+            self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, source, target,
+                height=None,
+                width=None,
+                shifted_window_attn_mask=None,
+                attn_num_splits=None,
+                **kwargs,
+                ):
+        # source, target: [B, L, C]
+        query, key, value = source, target, target
+
+        # single-head attention
+        query = self.q_proj(query)  # [B, L, C]
+        key = self.k_proj(key)  # [B, L, C]
+        value = self.v_proj(value)  # [B, L, C]
+
+        if self.attention_type == 'swin' and attn_num_splits > 1:
+            if self.nhead > 1:
+                # we observe that multihead attention slows down the speed and increases the memory consumption
+                # without bringing obvious performance gains and thus the implementation is removed
+                raise NotImplementedError
+            else:
+                message = single_head_split_window_attention(query, key, value,
+                                                             num_splits=attn_num_splits,
+                                                             with_shift=self.with_shift,
+                                                             h=height,
+                                                             w=width,
+                                                             attn_mask=shifted_window_attn_mask,
+                                                             )
+        else:
+            message = single_head_full_attention(query, key, value)  # [B, L, C]
+
+        message = self.merge(message)  # [B, L, C]
+        message = self.norm1(message)
+
+        if not self.no_ffn:
+            message = self.mlp(torch.cat([source, message], dim=-1))
+            message = self.norm2(message)
+
+        return source + message
+
+
+class TransformerBlock(nn.Module):
+    """self attention + cross attention + FFN"""
+
+    def __init__(self,
+                 d_model=256,
+                 nhead=1,
+                 attention_type='swin',
+                 ffn_dim_expansion=4,
+                 with_shift=False,
+                 **kwargs,
+                 ):
+        super(TransformerBlock, self).__init__()
+
+        self.self_attn = TransformerLayer(d_model=d_model,
+                                          nhead=nhead,
+                                          attention_type=attention_type,
+                                          no_ffn=True,
+                                          ffn_dim_expansion=ffn_dim_expansion,
+                                          with_shift=with_shift,
+                                          )
+
+        self.cross_attn_ffn = TransformerLayer(d_model=d_model,
+                                               nhead=nhead,
+                                               attention_type=attention_type,
+                                               ffn_dim_expansion=ffn_dim_expansion,
+                                               with_shift=with_shift,
+                                               )
+
+    def forward(self, source, target,
+                height=None,
+                width=None,
+                shifted_window_attn_mask=None,
+                attn_num_splits=None,
+                **kwargs,
+                ):
+        # source, target: [B, L, C]
+
+        # self attention
+        source = self.self_attn(source, source,
+                                height=height,
+                                width=width,
+                                shifted_window_attn_mask=shifted_window_attn_mask,
+                                attn_num_splits=attn_num_splits,
+                                )
+
+        # cross attention and ffn
+        source = self.cross_attn_ffn(source, target,
+                                     height=height,
+                                     width=width,
+                                     shifted_window_attn_mask=shifted_window_attn_mask,
+                                     attn_num_splits=attn_num_splits,
+                                     )
+
+        return source
+
+
+class FeatureTransformer(nn.Module):
+    def __init__(self,
+                 num_layers=6,
+                 d_model=128,
+                 nhead=1,
+                 attention_type='swin',
+                 ffn_dim_expansion=4,
+                 **kwargs,
+                 ):
+        super(FeatureTransformer, self).__init__()
+
+        self.attention_type = attention_type
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+        self.layers = nn.ModuleList([
+            TransformerBlock(d_model=d_model,
+                             nhead=nhead,
+                             attention_type=attention_type,
+                             ffn_dim_expansion=ffn_dim_expansion,
+                             with_shift=True if attention_type == 'swin' and i % 2 == 1 else False,
+                             )
+            for i in range(num_layers)])
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, feature0, feature1,
+                attn_num_splits=None,
+                **kwargs,
+                ):
+
+        b, c, h, w = feature0.shape
+        assert self.d_model == c
+
+        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
+        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
+
+        if self.attention_type == 'swin' and attn_num_splits > 1:
+            # global and refine use different number of splits
+            window_size_h = h // attn_num_splits
+            window_size_w = w // attn_num_splits
+
+            # compute attn mask once
+            shifted_window_attn_mask = generate_shift_window_attn_mask(
+                input_resolution=(h, w),
+                window_size_h=window_size_h,
+                window_size_w=window_size_w,
+                shift_size_h=window_size_h // 2,
+                shift_size_w=window_size_w // 2,
+                device=feature0.device,
+            )  # [K*K, H/K*W/K, H/K*W/K]
+        else:
+            shifted_window_attn_mask = None
+
+        # concat feature0 and feature1 in batch dimension to compute in parallel
+        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
+        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]
+
+        for layer in self.layers:
+            concat0 = layer(concat0, concat1,
+                            height=h,
+                            width=w,
+                            shifted_window_attn_mask=shifted_window_attn_mask,
+                            attn_num_splits=attn_num_splits,
+                            )
+
+            # update feature1
+            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
+
+        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]
+
+        # reshape back
+        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
+        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # [B, C, H, W]
+
+        return feature0, feature1
+
+
+class FeatureFlowAttention(nn.Module):
+    """
+    flow propagation with self-attention on feature
+    query: feature0, key: feature0, value: flow
+    """
+
+    def __init__(self, in_channels,
+                 **kwargs,
+                 ):
+        super(FeatureFlowAttention, self).__init__()
+
+        self.q_proj = nn.Linear(in_channels, in_channels)
+        self.k_proj = nn.Linear(in_channels, in_channels)
+
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, feature0, flow,
+                local_window_attn=False,
+                local_window_radius=1,
+                **kwargs,
+                ):
+        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
+        if local_window_attn:
+            return self.forward_local_window_attn(feature0, flow,
+                                                  local_window_radius=local_window_radius)
+
+        b, c, h, w = feature0.size()
+
+        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]
+
+        query = self.q_proj(query)  # [B, H*W, C]
+        key = self.k_proj(query)  # [B, H*W, C]
+
+        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]
+
+        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
+        prob = torch.softmax(scores, dim=-1)
+
+        out = torch.matmul(prob, value)  # [B, H*W, 2]
+        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]
+
+        return out
+
+    def forward_local_window_attn(self, feature0, flow,
+                                  local_window_radius=1,
+                                  ):
+        assert flow.size(1) == 2
+        assert local_window_radius > 0
+
+        b, c, h, w = feature0.size()
+
+        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
+                                       ).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]
+
+        kernel_size = 2 * local_window_radius + 1
+
+        feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
+
+        feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
+                                   padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]
+
+        feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
+            0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2)  # [B*H*W, C, (2R+1)^2]
+
+        flow_window = F.unfold(flow, kernel_size=kernel_size,
+                               padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]
+
+        flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute(
+            0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2)  # [B*H*W, (2R+1)^2, 2]
+
+        scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]
+
+        prob = torch.softmax(scores, dim=-1)
+
+        out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]
+
+        return out
diff --git a/sgm_vfi_arch/trident_conv.py b/sgm_vfi_arch/trident_conv.py
new file mode 100644
index 0000000..29a2a73
--- /dev/null
+++ b/sgm_vfi_arch/trident_conv.py
@@ -0,0 +1,90 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair
+
+
+class MultiScaleTridentConv(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        strides=1,
+        paddings=0,
+        dilations=1,
+        dilation=1,
+        groups=1,
+        num_branch=1,
+        test_branch_idx=-1,
+        bias=False,
+        norm=None,
+        activation=None,
+    ):
+        super(MultiScaleTridentConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.num_branch = num_branch
+        self.stride = _pair(stride)
+        self.groups = groups
+        self.with_bias = bias
+        self.dilation = dilation
+        if isinstance(paddings, int):
+            paddings = [paddings] * self.num_branch
+        if isinstance(dilations, int):
+            dilations = [dilations] * self.num_branch
+        if isinstance(strides, int):
+            strides = [strides] * self.num_branch
+        self.paddings = [_pair(padding) for padding in paddings]
+        self.dilations = [_pair(dilation) for dilation in dilations]
+        self.strides = [_pair(stride) for stride in strides]
+        self.test_branch_idx = test_branch_idx
+        self.norm = norm
+        self.activation = activation
+
+        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.bias = None
+
+        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
+        if self.bias is not None:
+            nn.init.constant_(self.bias, 0)
+
+    def forward(self, inputs):
+        num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
+        assert len(inputs) == num_branch
+
+        if self.training or self.test_branch_idx == -1:
+            outputs = [
+                F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
+                for input, stride, padding in zip(inputs, self.strides, self.paddings)
+            ]
+        else:
+            # this branch only runs when test_branch_idx != -1, so use that branch's stride/padding
+            outputs = [
+                F.conv2d(
+                    inputs[0],
+                    self.weight,
+                    self.bias,
+                    self.strides[self.test_branch_idx] if self.test_branch_idx != -1 else self.strides[-1],
+                    self.paddings[self.test_branch_idx] if self.test_branch_idx != -1 else self.paddings[-1],
+                    self.dilation,
+                    self.groups,
+                )
+            ]
+
+        if self.norm is not None:
+            outputs = [self.norm(x) for x in outputs]
+        if self.activation is not None:
+            outputs = [self.activation(x) for x in outputs]
+        return outputs
diff --git a/sgm_vfi_arch/utils.py b/sgm_vfi_arch/utils.py
new file mode 100644
index 0000000..c796c3a
--- /dev/null
+++ b/sgm_vfi_arch/utils.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn.functional as F
+from .position import PositionEmbeddingSine
+from .geometry import coords_grid, generate_window_grid, normalize_coords
+
+
+def split_feature(feature,
+                  num_splits=2,
+                  channel_last=False,
+                  ):
+    if channel_last:  # [B, H, W, C]
+        b, h, w, c = feature.size()
+        assert h % num_splits == 0 and w % num_splits == 0
+
+        b_new = b * num_splits * num_splits
+        h_new = h // num_splits
+        w_new = w // num_splits
+
+        feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
+                               ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c)  # [B*K*K, H/K, W/K, C]
+    else:  # [B, C, H, W]
+        b, c, h, w = feature.size()
+        assert h % num_splits == 0 and w % num_splits == 0
+
+        b_new = b * num_splits * num_splits
+        h_new = h // num_splits
+        w_new = w // num_splits
+
+        feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
+                               ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new)  # [B*K*K, C, H/K, W/K]
+
+    return feature
+
+def merge_splits(splits,
+                 num_splits=2,
+                 channel_last=False,
+                 ):
+    if channel_last:  # [B*K*K, H/K, W/K, C]
+        b, h, w, c = splits.size()
+        new_b = b // num_splits // num_splits
+
+        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
+        merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
+            new_b, num_splits * h, num_splits * w, c)  # [B, H, W, C]
+    else:  # [B*K*K, C, H/K, W/K]
+        b, c, h, w = splits.size()
+        new_b = b // num_splits // num_splits
+
+        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
+        merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
+            new_b, c, num_splits * h, num_splits * w)  # [B, C, H, W]
+
+    return merge
+
+
+def feature_add_position(feature0, feature1, attn_splits, feature_channels):
+    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
+
+    if attn_splits > 1:  # add position in split windows
+        feature0_splits = split_feature(feature0, num_splits=attn_splits)
+        feature1_splits = split_feature(feature1, num_splits=attn_splits)
+
+        position = pos_enc(feature0_splits)
+
+        feature0_splits = feature0_splits + position
+        feature1_splits = feature1_splits + position
+
+        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
+        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
+    else:
+        position = pos_enc(feature0)
+
+        feature0 = feature0 + position
+        feature1 = feature1 + position
+
+    return feature0, feature1
+
+
+class InputPadder:
+    """ Pads images such that dimensions are divisible by 8 """
+
+    def __init__(self, dims, mode='sintel', padding_factor=8, additional_pad=False):
+        self.ht, self.wd = dims[-2:]
+        add_pad = padding_factor * 2 if additional_pad else 0
+        pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor + add_pad
+        pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor + add_pad
+        if mode == 'sintel':
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+        else:
+            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self, x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
diff --git a/sgm_vfi_arch/warplayer.py b/sgm_vfi_arch/warplayer.py
new file mode 100644
index 0000000..16fcbbf
--- /dev/null
+++ b/sgm_vfi_arch/warplayer.py
@@ -0,0 +1,25 @@
+import torch
+
+backwarp_tenGrid = {}
+
+
+def clear_warp_cache():
+    """Free all cached grid tensors (call between frame pairs to reclaim VRAM)."""
+    backwarp_tenGrid.clear()
+
+
+def warp(tenInput, tenFlow):
+    k = (str(tenFlow.device), str(tenFlow.size()))
+    if k not in backwarp_tenGrid:
+        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
+            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
+            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+        backwarp_tenGrid[k] = torch.cat(
+            [tenHorizontal, tenVertical], 1).to(tenFlow.device)
+
+    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
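
For reviewers unfamiliar with the backward-warp scheme used by `warplayer.warp` above, here is a minimal standalone sketch of the same idea — build a normalized `[-1, 1]` sampling grid, offset it by the pixel-space flow, and sample with `grid_sample`. `backward_warp` is a hypothetical helper name for illustration only (it builds the grid inline rather than caching it per device/size), not part of the vendored module:

```python
import torch
import torch.nn.functional as F


def backward_warp(x, flow):
    # x: [B, C, H, W] source frame; flow: [B, 2, H, W] in pixels (x-flow, y-flow).
    b, _, h, w = flow.shape
    # identity grid in normalized [-1, 1] coordinates
    xs = torch.linspace(-1.0, 1.0, w, device=flow.device).view(1, 1, 1, w).expand(b, 1, h, w)
    ys = torch.linspace(-1.0, 1.0, h, device=flow.device).view(1, 1, h, 1).expand(b, 1, h, w)
    grid = torch.cat([xs, ys], dim=1)  # [B, 2, H, W]
    # convert pixel displacements to normalized coordinates
    flow = torch.cat([flow[:, 0:1] / ((w - 1.0) / 2.0),
                      flow[:, 1:2] / ((h - 1.0) / 2.0)], dim=1)
    g = (grid + flow).permute(0, 2, 3, 1)  # [B, H, W, 2]
    return F.grid_sample(x, g, mode='bilinear', padding_mode='border', align_corners=True)


# sanity check: zero flow returns the input unchanged
img = torch.rand(1, 3, 8, 8)
assert torch.allclose(backward_warp(img, torch.zeros(1, 2, 8, 8)), img, atol=1e-5)
```

The vendored `warp` adds a module-level cache keyed on device and flow shape so the grid is built once per resolution; `clear_warp_cache()` exists to drop those entries between frame pairs.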