Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
0
gimm_vfi_arch/generalizable_INR/modules/__init__.py
Normal file
0
gimm_vfi_arch/generalizable_INR/modules/__init__.py
Normal file
91
gimm_vfi_arch/generalizable_INR/modules/coord_sampler.py
Normal file
91
gimm_vfi_arch/generalizable_INR/modules/coord_sampler.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CoordSampler3D(nn.Module):
|
||||
def __init__(self, coord_range, t_coord_only=False):
|
||||
super().__init__()
|
||||
self.coord_range = coord_range
|
||||
self.t_coord_only = t_coord_only
|
||||
|
||||
def shape2coordinate(
|
||||
self,
|
||||
batch_size,
|
||||
spatial_shape,
|
||||
t_ids,
|
||||
coord_range=(-1.0, 1.0),
|
||||
upsample_ratio=1,
|
||||
device=None,
|
||||
):
|
||||
coords = []
|
||||
assert isinstance(t_ids, list)
|
||||
_coords = torch.tensor(t_ids, device=device) / 1.0
|
||||
coords.append(_coords.to(torch.float32))
|
||||
for num_s in spatial_shape:
|
||||
num_s = int(num_s * upsample_ratio)
|
||||
_coords = (0.5 + torch.arange(num_s, device=device)) / num_s
|
||||
_coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
|
||||
coords.append(_coords)
|
||||
coords = torch.meshgrid(*coords, indexing="ij")
|
||||
coords = torch.stack(coords, dim=-1)
|
||||
ones_like_shape = (1,) * coords.ndim
|
||||
coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
|
||||
return coords # (B,T,H,W,3)
|
||||
|
||||
def batchshape2coordinate(
|
||||
self,
|
||||
batch_size,
|
||||
spatial_shape,
|
||||
t_ids,
|
||||
coord_range=(-1.0, 1.0),
|
||||
upsample_ratio=1,
|
||||
device=None,
|
||||
):
|
||||
coords = []
|
||||
_coords = torch.tensor(1, device=device)
|
||||
coords.append(_coords.to(torch.float32))
|
||||
for num_s in spatial_shape:
|
||||
num_s = int(num_s * upsample_ratio)
|
||||
_coords = (0.5 + torch.arange(num_s, device=device)) / num_s
|
||||
_coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
|
||||
coords.append(_coords)
|
||||
coords = torch.meshgrid(*coords, indexing="ij")
|
||||
coords = torch.stack(coords, dim=-1)
|
||||
ones_like_shape = (1,) * coords.ndim
|
||||
# Now coords b,1,h,w,3, coords[...,0]=1.
|
||||
coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
|
||||
# assign per-sample timestep within the batch
|
||||
coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
|
||||
return coords
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size,
|
||||
s_shape,
|
||||
t_ids,
|
||||
coord_range=None,
|
||||
upsample_ratio=1.0,
|
||||
device=None,
|
||||
):
|
||||
coord_range = self.coord_range if coord_range is None else coord_range
|
||||
if isinstance(t_ids, list):
|
||||
coords = self.shape2coordinate(
|
||||
batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
|
||||
)
|
||||
elif isinstance(t_ids, torch.Tensor):
|
||||
coords = self.batchshape2coordinate(
|
||||
batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
|
||||
)
|
||||
if self.t_coord_only:
|
||||
coords = coords[..., :1]
|
||||
return coords
|
||||
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
340
gimm_vfi_arch/generalizable_INR/modules/fi_components.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# amt: https://github.com/MCG-NKU/AMT
|
||||
# motif: https://github.com/sichun233746/MoTIF
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .fi_utils import warp, resize
|
||||
|
||||
|
||||
class LateralBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(LateralBlock, self).__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.layers(x)
|
||||
return x + res
|
||||
|
||||
|
||||
def convrelu(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(out_channels),
|
||||
)
|
||||
|
||||
|
||||
def multi_flow_combine(
|
||||
comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
|
||||
):
|
||||
assert mean is None
|
||||
b, c, h, w = flow0.shape
|
||||
num_flows = c // 2
|
||||
flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
|
||||
flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
|
||||
|
||||
mask = (
|
||||
mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
|
||||
if mask is not None
|
||||
else None
|
||||
)
|
||||
img_res = (
|
||||
img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
|
||||
if img_res is not None
|
||||
else 0
|
||||
)
|
||||
img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
|
||||
img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
|
||||
mean = (
|
||||
torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
|
||||
if mean is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
img0_warp = warp(img0, flow0)
|
||||
img1_warp = warp(img1, flow1)
|
||||
img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
|
||||
img_warps = img_warps.reshape(b, num_flows, 3, h, w)
|
||||
|
||||
res = comb_block(img_warps.view(b, -1, h, w))
|
||||
imgt_pred = img_warps.mean(1) + res
|
||||
|
||||
imgt_pred = (imgt_pred + 1.0) / 2
|
||||
|
||||
return imgt_pred
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, side_channels, bias=True):
|
||||
super(ResBlock, self).__init__()
|
||||
self.side_channels = side_channels
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
),
|
||||
nn.PReLU(in_channels),
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
side_channels,
|
||||
side_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(side_channels),
|
||||
)
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
),
|
||||
nn.PReLU(in_channels),
|
||||
)
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
side_channels,
|
||||
side_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
),
|
||||
nn.PReLU(side_channels),
|
||||
)
|
||||
self.conv5 = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
||||
)
|
||||
self.prelu = nn.PReLU(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
|
||||
res_feat = out[:, : -self.side_channels, ...]
|
||||
side_feat = out[:, -self.side_channels :, :, :]
|
||||
side_feat = self.conv2(side_feat)
|
||||
out = self.conv3(torch.cat([res_feat, side_feat], 1))
|
||||
|
||||
res_feat = out[:, : -self.side_channels, ...]
|
||||
side_feat = out[:, -self.side_channels :, :, :]
|
||||
side_feat = self.conv4(side_feat)
|
||||
out = self.conv5(torch.cat([res_feat, side_feat], 1))
|
||||
|
||||
out = self.prelu(x + out)
|
||||
return out
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cdim,
|
||||
hidden_dim,
|
||||
flow_dim,
|
||||
corr_dim,
|
||||
corr_dim2,
|
||||
fc_dim,
|
||||
corr_levels=4,
|
||||
radius=3,
|
||||
scale_factor=None,
|
||||
out_num=1,
|
||||
):
|
||||
super(BasicUpdateBlock, self).__init__()
|
||||
cor_planes = corr_levels * (2 * radius + 1) ** 2
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
|
||||
self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
|
||||
self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
|
||||
self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
|
||||
self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
|
||||
|
||||
self.gru = nn.Sequential(
|
||||
nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
self.feat_head = nn.Sequential(
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, cdim, 3, padding=1),
|
||||
)
|
||||
|
||||
self.flow_head = nn.Sequential(
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||
nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
|
||||
)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
|
||||
def forward(self, net, flow, corr):
|
||||
net = (
|
||||
resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
|
||||
)
|
||||
cor = self.lrelu(self.convc1(corr))
|
||||
cor = self.lrelu(self.convc2(cor))
|
||||
flo = self.lrelu(self.convf1(flow))
|
||||
flo = self.lrelu(self.convf2(flo))
|
||||
cor_flo = torch.cat([cor, flo], dim=1)
|
||||
inp = self.lrelu(self.conv(cor_flo))
|
||||
inp = torch.cat([inp, flow, net], dim=1)
|
||||
|
||||
out = self.gru(inp)
|
||||
delta_net = self.feat_head(out)
|
||||
delta_flow = self.flow_head(out)
|
||||
|
||||
if self.scale_factor is not None:
|
||||
delta_net = resize(delta_net, scale_factor=self.scale_factor)
|
||||
delta_flow = self.scale_factor * resize(
|
||||
delta_flow, scale_factor=self.scale_factor
|
||||
)
|
||||
return delta_net, delta_flow
|
||||
|
||||
|
||||
def get_bn():
|
||||
return nn.BatchNorm2d
|
||||
|
||||
|
||||
class NewInitDecoder(nn.Module):
|
||||
def __init__(self, in_ch, skip_ch):
|
||||
super().__init__()
|
||||
norm_layer = get_bn()
|
||||
|
||||
self.upsample = nn.Sequential(
|
||||
nn.PixelShuffle(2),
|
||||
convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 2),
|
||||
nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
|
||||
norm_layer(in_ch // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
in_ch = in_ch // 2
|
||||
self.convblock = nn.Sequential(
|
||||
convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
ResBlock(in_ch, skip_ch),
|
||||
nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True),
|
||||
)
|
||||
|
||||
def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
|
||||
f0 = self.upsample(f0)
|
||||
f1 = self.upsample(f1)
|
||||
f0_warp_ks = warp(f0, flow0_in)
|
||||
f1_warp_ks = warp(f1, flow1_in)
|
||||
|
||||
f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)
|
||||
|
||||
assert img0 is not None
|
||||
assert img1 is not None
|
||||
scale_factor = f_in.shape[2] / img0.shape[2]
|
||||
img0 = resize(img0, scale_factor=scale_factor)
|
||||
img1 = resize(img1, scale_factor=scale_factor)
|
||||
warped_img0 = warp(img0, flow0_in)
|
||||
warped_img1 = warp(img1, flow1_in)
|
||||
f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
|
||||
|
||||
out = self.convblock(f_in)
|
||||
ft_ = out[:, 4:, ...]
|
||||
flow0 = flow0_in + out[:, :2, ...]
|
||||
flow1 = flow1_in + out[:, 2:4, ...]
|
||||
return flow0, flow1, ft_
|
||||
|
||||
|
||||
class NewMultiFlowDecoder(nn.Module):
|
||||
def __init__(self, in_ch, skip_ch, num_flows=3):
|
||||
super(NewMultiFlowDecoder, self).__init__()
|
||||
norm_layer = get_bn()
|
||||
|
||||
self.upsample = nn.Sequential(
|
||||
nn.PixelShuffle(2),
|
||||
nn.PixelShuffle(2),
|
||||
convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 4),
|
||||
convrelu(in_ch // 4, in_ch // 2),
|
||||
nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
|
||||
norm_layer(in_ch // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.num_flows = num_flows
|
||||
ch_factor = 2
|
||||
self.convblock = nn.Sequential(
|
||||
convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
ResBlock(in_ch * ch_factor, skip_ch),
|
||||
nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
|
||||
f0 = self.upsample(f0)
|
||||
# print([f1.shape,f0.shape])
|
||||
f1 = self.upsample(f1)
|
||||
n = self.num_flows
|
||||
flow0 = 4.0 * resize(flow0, scale_factor=4.0)
|
||||
flow1 = 4.0 * resize(flow1, scale_factor=4.0)
|
||||
|
||||
ft_ = resize(ft_, scale_factor=4.0)
|
||||
mask = resize(mask, scale_factor=4.0)
|
||||
f0_warp = warp(f0, flow0)
|
||||
f1_warp = warp(f1, flow1)
|
||||
|
||||
f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)
|
||||
|
||||
assert mask is not None
|
||||
f_in = torch.cat([f_in, mask], 1)
|
||||
|
||||
assert img0 is not None
|
||||
assert img1 is not None
|
||||
warped_img0 = warp(img0, flow0)
|
||||
warped_img1 = warp(img1, flow1)
|
||||
f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
|
||||
|
||||
out = self.convblock(f_in)
|
||||
delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
|
||||
out, [2 * n, 2 * n, n, 3 * n], 1
|
||||
)
|
||||
mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1)
|
||||
mask = torch.sigmoid(mask)
|
||||
flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1)
|
||||
flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1)
|
||||
|
||||
return flow0, flow1, mask, img_res
|
||||
81
gimm_vfi_arch/generalizable_INR/modules/fi_utils.py
Normal file
81
gimm_vfi_arch/generalizable_INR/modules/fi_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# raft: https://github.com/princeton-vl/RAFT
|
||||
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
backwarp_tenGrid = {}
|
||||
|
||||
|
||||
def warp(tenInput, tenFlow):
|
||||
k = (str(tenFlow.device), str(tenFlow.size()))
|
||||
if k not in backwarp_tenGrid:
|
||||
tenHorizontal = (
|
||||
torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device)
|
||||
.view(1, 1, 1, tenFlow.shape[3])
|
||||
.expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
||||
)
|
||||
tenVertical = (
|
||||
torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device)
|
||||
.view(1, 1, tenFlow.shape[2], 1)
|
||||
.expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
||||
)
|
||||
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(tenFlow.device)
|
||||
|
||||
tenFlow = torch.cat(
|
||||
[
|
||||
tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
||||
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
||||
return torch.nn.functional.grid_sample(
|
||||
input=tenInput,
|
||||
grid=g,
|
||||
mode="bilinear",
|
||||
padding_mode="border",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
|
||||
def normalize_flow(flows):
|
||||
# FIXME: MULTI-DIMENSION
|
||||
flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape(
|
||||
-1, 1, 1, 1, 1
|
||||
)
|
||||
flows = flows / flow_scaler # [-1,1]
|
||||
# # Adapt to [0,1]
|
||||
flows = (flows + 1.0) / 2.0
|
||||
return flows, flow_scaler
|
||||
|
||||
|
||||
def unnormalize_flow(flows, flow_scaler):
|
||||
return (flows * 2.0 - 1.0) * flow_scaler
|
||||
|
||||
|
||||
def resize(x, scale_factor):
|
||||
return F.interpolate(
|
||||
x, scale_factor=scale_factor, mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
return coords[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
def build_coord(img):
|
||||
N, C, H, W = img.shape
|
||||
coords = coords_grid(N, H // 8, W // 8)
|
||||
return coords
|
||||
198
gimm_vfi_arch/generalizable_INR/modules/hyponet.py
Normal file
198
gimm_vfi_arch/generalizable_INR/modules/hyponet.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||
# --------------------------------------------------------
|
||||
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ..configs import HypoNetConfig
|
||||
from .utils import create_params_with_init, create_activation
|
||||
|
||||
|
||||
class HypoNet(nn.Module):
|
||||
r"""
|
||||
The Hyponetwork with a coordinate-based MLP to be modulated.
|
||||
"""
|
||||
|
||||
def __init__(self, config: HypoNetConfig, add_coord_dim=32):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_bias = config.use_bias
|
||||
self.init_config = config.initialization
|
||||
self.num_layer = config.n_layer
|
||||
self.hidden_dims = config.hidden_dim
|
||||
self.add_coord_dim = add_coord_dim
|
||||
|
||||
if len(self.hidden_dims) == 1:
|
||||
self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
|
||||
self.num_layer - 1
|
||||
) # exclude output layer
|
||||
else:
|
||||
assert len(self.hidden_dims) == self.num_layer - 1
|
||||
|
||||
if self.config.activation.type == "siren":
|
||||
assert self.init_config.weight_init_type == "siren"
|
||||
assert self.init_config.bias_init_type == "siren"
|
||||
|
||||
# after computes the shape of trainable parameters, initialize them
|
||||
self.params_dict = None
|
||||
self.params_shape_dict = self.compute_params_shape()
|
||||
self.activation = create_activation(self.config.activation)
|
||||
self.build_base_params_dict(self.config.initialization)
|
||||
self.output_bias = config.output_bias
|
||||
|
||||
self.normalize_weight = config.normalize_weight
|
||||
|
||||
self.ignore_base_param_dict = {name: False for name in self.params_dict}
|
||||
|
||||
@staticmethod
|
||||
def subsample_coords(coords, subcoord_idx=None):
|
||||
if subcoord_idx is None:
|
||||
return coords
|
||||
|
||||
batch_size = coords.shape[0]
|
||||
sub_coords = []
|
||||
coords = coords.view(batch_size, -1, coords.shape[-1])
|
||||
for idx in range(batch_size):
|
||||
sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
|
||||
sub_coords = torch.cat(sub_coords, dim=0)
|
||||
return sub_coords
|
||||
|
||||
def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
|
||||
sub_idx = None
|
||||
if isinstance(coord, tuple):
|
||||
coord, sub_idx = coord[0], coord[1]
|
||||
|
||||
if modulation_params_dict is not None:
|
||||
self.check_valid_param_keys(modulation_params_dict)
|
||||
|
||||
batch_size, coord_shape, input_dim = (
|
||||
coord.shape[0],
|
||||
coord.shape[1:-1],
|
||||
coord.shape[-1],
|
||||
)
|
||||
coord = coord.view(batch_size, -1, input_dim) # flatten the coordinates
|
||||
assert pixel_latent is not None
|
||||
pixel_latent = F.interpolate(
|
||||
pixel_latent.permute(0, 3, 1, 2),
|
||||
size=(coord_shape[1], coord_shape[2]),
|
||||
mode="bilinear",
|
||||
).permute(0, 2, 3, 1)
|
||||
pixel_latent_dim = pixel_latent.shape[-1]
|
||||
pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)
|
||||
hidden = coord
|
||||
|
||||
hidden = torch.cat([pixel_latent, hidden], dim=-1)
|
||||
|
||||
hidden = self.subsample_coords(hidden, sub_idx)
|
||||
|
||||
for idx in range(self.config.n_layer):
|
||||
param_key = f"linear_wb{idx}"
|
||||
base_param = einops.repeat(
|
||||
self.params_dict[param_key], "n m -> b n m", b=batch_size
|
||||
)
|
||||
|
||||
if (modulation_params_dict is not None) and (
|
||||
param_key in modulation_params_dict.keys()
|
||||
):
|
||||
modulation_param = modulation_params_dict[param_key]
|
||||
else:
|
||||
if self.config.use_bias:
|
||||
modulation_param = torch.ones_like(base_param[:, :-1])
|
||||
else:
|
||||
modulation_param = torch.ones_like(base_param)
|
||||
|
||||
if self.config.use_bias:
|
||||
ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device)
|
||||
hidden = torch.cat([hidden, ones], dim=-1)
|
||||
|
||||
base_param_w, base_param_b = (
|
||||
base_param[:, :-1, :],
|
||||
base_param[:, -1:, :],
|
||||
)
|
||||
|
||||
if self.ignore_base_param_dict[param_key]:
|
||||
base_param_w = 1.0
|
||||
param_w = base_param_w * modulation_param
|
||||
if self.normalize_weight:
|
||||
param_w = F.normalize(param_w, dim=1)
|
||||
modulated_param = torch.cat([param_w, base_param_b], dim=1)
|
||||
else:
|
||||
if self.ignore_base_param_dict[param_key]:
|
||||
base_param = 1.0
|
||||
if self.normalize_weight:
|
||||
modulated_param = F.normalize(base_param * modulation_param, dim=1)
|
||||
else:
|
||||
modulated_param = base_param * modulation_param
|
||||
# print([param_key,hidden.shape,modulated_param.shape])
|
||||
hidden = torch.bmm(hidden, modulated_param)
|
||||
|
||||
if idx < (self.config.n_layer - 1):
|
||||
hidden = self.activation(hidden)
|
||||
|
||||
outputs = hidden + self.output_bias
|
||||
if sub_idx is None:
|
||||
outputs = outputs.view(batch_size, *coord_shape, -1)
|
||||
return outputs
|
||||
|
||||
def compute_params_shape(self):
|
||||
"""
|
||||
Computes the shape of MLP parameters.
|
||||
The computed shapes are used to build the initial weights by `build_base_params_dict`.
|
||||
"""
|
||||
config = self.config
|
||||
use_bias = self.use_bias
|
||||
|
||||
param_shape_dict = dict()
|
||||
|
||||
fan_in = config.input_dim
|
||||
add_dim = self.add_coord_dim
|
||||
fan_in = fan_in + add_dim
|
||||
fan_in = fan_in + 1 if use_bias else fan_in
|
||||
|
||||
for i in range(config.n_layer - 1):
|
||||
fan_out = self.hidden_dims[i]
|
||||
param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
|
||||
fan_in = fan_out + 1 if use_bias else fan_out
|
||||
|
||||
param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
|
||||
return param_shape_dict
|
||||
|
||||
def build_base_params_dict(self, init_config):
|
||||
assert self.params_shape_dict
|
||||
params_dict = nn.ParameterDict()
|
||||
for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
|
||||
is_first = idx == 0
|
||||
params = create_params_with_init(
|
||||
shape,
|
||||
init_type=init_config.weight_init_type,
|
||||
include_bias=self.use_bias,
|
||||
bias_init_type=init_config.bias_init_type,
|
||||
is_first=is_first,
|
||||
siren_w0=self.config.activation.siren_w0, # valid only for siren
|
||||
)
|
||||
params = nn.Parameter(params)
|
||||
params_dict[name] = params
|
||||
self.set_params_dict(params_dict)
|
||||
|
||||
def check_valid_param_keys(self, params_dict):
|
||||
predefined_params_keys = self.params_shape_dict.keys()
|
||||
for param_key in params_dict.keys():
|
||||
if param_key in predefined_params_keys:
|
||||
continue
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
def set_params_dict(self, params_dict):
|
||||
self.check_valid_param_keys(params_dict)
|
||||
self.params_dict = params_dict
|
||||
42
gimm_vfi_arch/generalizable_INR/modules/layers.py
Normal file
42
gimm_vfi_arch/generalizable_INR/modules/layers.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
# define siren layer & Siren model
|
||||
class Sine(nn.Module):
|
||||
"""Sine activation with scaling.
|
||||
|
||||
Args:
|
||||
w0 (float): Omega_0 parameter from SIREN paper.
|
||||
"""
|
||||
|
||||
def __init__(self, w0=1.0):
|
||||
super().__init__()
|
||||
self.w0 = w0
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sin(self.w0 * x)
|
||||
|
||||
|
||||
# Damping activation from http://arxiv.org/abs/2306.15242
|
||||
class Damping(nn.Module):
|
||||
"""Sine activation with sublinear factor
|
||||
|
||||
Args:
|
||||
w0 (float): Omega_0 parameter from SIREN paper.
|
||||
"""
|
||||
|
||||
def __init__(self, w0=1.0):
|
||||
super().__init__()
|
||||
self.w0 = w0
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.clamp(x, min=1e-30)
|
||||
return torch.sin(self.w0 * x) * torch.sqrt(x.abs())
|
||||
52
gimm_vfi_arch/generalizable_INR/modules/module_config.py
Normal file
52
gimm_vfi_arch/generalizable_INR/modules/module_config.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||
# --------------------------------------------------------
|
||||
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from omegaconf import MISSING
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypoNetActivationConfig:
|
||||
type: str = "relu"
|
||||
siren_w0: Optional[float] = 30.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypoNetInitConfig:
|
||||
weight_init_type: Optional[str] = "kaiming_uniform"
|
||||
bias_init_type: Optional[str] = "zero"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypoNetConfig:
|
||||
type: str = "mlp"
|
||||
n_layer: int = 5
|
||||
hidden_dim: List[int] = MISSING
|
||||
use_bias: bool = True
|
||||
input_dim: int = 2
|
||||
output_dim: int = 3
|
||||
output_bias: float = 0.5
|
||||
activation: HypoNetActivationConfig = field(default_factory=HypoNetActivationConfig)
|
||||
initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig)
|
||||
|
||||
normalize_weight: bool = True
|
||||
linear_interpo: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoordSamplerConfig:
|
||||
data_type: str = "image"
|
||||
t_coord_only: bool = False
|
||||
coord_range: List[float] = MISSING
|
||||
time_range: List[float] = MISSING
|
||||
train_strategy: Optional[str] = MISSING
|
||||
val_strategy: Optional[str] = MISSING
|
||||
patch_size: Optional[int] = MISSING
|
||||
672
gimm_vfi_arch/generalizable_INR/modules/softsplat.py
Normal file
672
gimm_vfi_arch/generalizable_INR/modules/softsplat.py
Normal file
@@ -0,0 +1,672 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# softmax-splatting: https://github.com/sniklaus/softmax-splatting
|
||||
# --------------------------------------------------------
|
||||
|
||||
import collections
|
||||
import cupy
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import typing
|
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
objCudacache = {}
|
||||
|
||||
|
||||
def cuda_int32(intIn: int):
|
||||
return cupy.int32(intIn)
|
||||
|
||||
|
||||
# end
|
||||
|
||||
|
||||
def cuda_float32(fltIn: float):
|
||||
return cupy.float32(fltIn)
|
||||
|
||||
|
||||
# end
|
||||
|
||||
|
||||
def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
|
||||
if "device" not in objCudacache:
|
||||
objCudacache["device"] = torch.cuda.get_device_name()
|
||||
# end
|
||||
|
||||
strKey = strFunction
|
||||
|
||||
for strVariable in objVariables:
|
||||
objValue = objVariables[strVariable]
|
||||
|
||||
strKey += strVariable
|
||||
|
||||
if objValue is None:
|
||||
continue
|
||||
|
||||
elif type(objValue) == int:
|
||||
strKey += str(objValue)
|
||||
|
||||
elif type(objValue) == float:
|
||||
strKey += str(objValue)
|
||||
|
||||
elif type(objValue) == bool:
|
||||
strKey += str(objValue)
|
||||
|
||||
elif type(objValue) == str:
|
||||
strKey += objValue
|
||||
|
||||
elif type(objValue) == torch.Tensor:
|
||||
strKey += str(objValue.dtype)
|
||||
strKey += str(objValue.shape)
|
||||
strKey += str(objValue.stride())
|
||||
|
||||
elif True:
|
||||
print(strVariable, type(objValue))
|
||||
assert False
|
||||
|
||||
# end
|
||||
# end
|
||||
|
||||
strKey += objCudacache["device"]
|
||||
|
||||
if strKey not in objCudacache:
|
||||
for strVariable in objVariables:
|
||||
objValue = objVariables[strVariable]
|
||||
|
||||
if objValue is None:
|
||||
continue
|
||||
|
||||
elif type(objValue) == int:
|
||||
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
|
||||
|
||||
elif type(objValue) == float:
|
||||
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
|
||||
|
||||
elif type(objValue) == bool:
|
||||
strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
|
||||
|
||||
elif type(objValue) == str:
|
||||
strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
|
||||
strKernel = strKernel.replace("{{type}}", "unsigned char")
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
|
||||
strKernel = strKernel.replace("{{type}}", "half")
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
|
||||
strKernel = strKernel.replace("{{type}}", "float")
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
|
||||
strKernel = strKernel.replace("{{type}}", "double")
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
|
||||
strKernel = strKernel.replace("{{type}}", "int")
|
||||
|
||||
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
|
||||
strKernel = strKernel.replace("{{type}}", "long")
|
||||
|
||||
elif type(objValue) == torch.Tensor:
|
||||
print(strVariable, objValue.dtype)
|
||||
assert False
|
||||
|
||||
elif True:
|
||||
print(strVariable, type(objValue))
|
||||
assert False
|
||||
|
||||
# end
|
||||
# end
|
||||
|
||||
while True:
|
||||
objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
|
||||
|
||||
if objMatch is None:
|
||||
break
|
||||
# end
|
||||
|
||||
intArg = int(objMatch.group(2))
|
||||
|
||||
strTensor = objMatch.group(4)
|
||||
intSizes = objVariables[strTensor].size()
|
||||
|
||||
strKernel = strKernel.replace(
|
||||
objMatch.group(),
|
||||
str(
|
||||
intSizes[intArg]
|
||||
if torch.is_tensor(intSizes[intArg]) == False
|
||||
else intSizes[intArg].item()
|
||||
),
|
||||
)
|
||||
# end
|
||||
|
||||
while True:
|
||||
objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel)
|
||||
|
||||
if objMatch is None:
|
||||
break
|
||||
# end
|
||||
|
||||
intStart = objMatch.span()[1]
|
||||
intStop = objMatch.span()[1]
|
||||
intParentheses = 1
|
||||
|
||||
while True:
|
||||
intParentheses += 1 if strKernel[intStop] == "(" else 0
|
||||
intParentheses -= 1 if strKernel[intStop] == ")" else 0
|
||||
|
||||
if intParentheses == 0:
|
||||
break
|
||||
# end
|
||||
|
||||
intStop += 1
|
||||
# end
|
||||
|
||||
intArgs = int(objMatch.group(2))
|
||||
strArgs = strKernel[intStart:intStop].split(",")
|
||||
|
||||
assert intArgs == len(strArgs) - 1
|
||||
|
||||
strTensor = strArgs[0]
|
||||
intStrides = objVariables[strTensor].stride()
|
||||
|
||||
strIndex = []
|
||||
|
||||
for intArg in range(intArgs):
|
||||
strIndex.append(
|
||||
"(("
|
||||
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
|
||||
+ ")*"
|
||||
+ str(
|
||||
intStrides[intArg]
|
||||
if torch.is_tensor(intStrides[intArg]) == False
|
||||
else intStrides[intArg].item()
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
# end
|
||||
|
||||
strKernel = strKernel.replace(
|
||||
"OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
|
||||
"(" + str.join("+", strIndex) + ")",
|
||||
)
|
||||
# end
|
||||
|
||||
while True:
|
||||
objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel)
|
||||
|
||||
if objMatch is None:
|
||||
break
|
||||
# end
|
||||
|
||||
intStart = objMatch.span()[1]
|
||||
intStop = objMatch.span()[1]
|
||||
intParentheses = 1
|
||||
|
||||
while True:
|
||||
intParentheses += 1 if strKernel[intStop] == "(" else 0
|
||||
intParentheses -= 1 if strKernel[intStop] == ")" else 0
|
||||
|
||||
if intParentheses == 0:
|
||||
break
|
||||
# end
|
||||
|
||||
intStop += 1
|
||||
# end
|
||||
|
||||
intArgs = int(objMatch.group(2))
|
||||
strArgs = strKernel[intStart:intStop].split(",")
|
||||
|
||||
assert intArgs == len(strArgs) - 1
|
||||
|
||||
strTensor = strArgs[0]
|
||||
intStrides = objVariables[strTensor].stride()
|
||||
|
||||
strIndex = []
|
||||
|
||||
for intArg in range(intArgs):
|
||||
strIndex.append(
|
||||
"(("
|
||||
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
|
||||
+ ")*"
|
||||
+ str(
|
||||
intStrides[intArg]
|
||||
if torch.is_tensor(intStrides[intArg]) == False
|
||||
else intStrides[intArg].item()
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
# end
|
||||
|
||||
strKernel = strKernel.replace(
|
||||
"VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
|
||||
strTensor + "[" + str.join("+", strIndex) + "]",
|
||||
)
|
||||
# end
|
||||
|
||||
objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
|
||||
# end
|
||||
|
||||
return strKey
|
||||
|
||||
|
||||
# end
|
||||
|
||||
|
||||
@cupy.memoize(for_each_device=True)
|
||||
@torch.compiler.disable()
|
||||
def cuda_launch(strKey: str):
|
||||
try:
|
||||
os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
|
||||
except Exception:
|
||||
if "CUDA_HOME" not in os.environ:
|
||||
raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")
|
||||
|
||||
strKernel = objCudacache[strKey]["strKernel"]
|
||||
strFunction = objCudacache[strKey]["strFunction"]
|
||||
|
||||
return cupy.RawModule(
|
||||
code=strKernel,
|
||||
options=(
|
||||
"-I " + os.environ["CUDA_HOME"],
|
||||
"-I " + os.environ["CUDA_HOME"] + "/include",
|
||||
),
|
||||
).get_function(strFunction)
|
||||
|
||||
|
||||
|
||||
##########################################################
|
||||
|
||||
|
||||
@torch.compiler.disable()
|
||||
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
|
||||
assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
|
||||
|
||||
if strMode == "sum":
|
||||
assert tenMetric is None
|
||||
if strMode == "avg":
|
||||
assert tenMetric is None
|
||||
if strMode.split("-")[0] == "linear":
|
||||
assert tenMetric is not None
|
||||
if strMode.split("-")[0] == "softmax":
|
||||
assert tenMetric is not None
|
||||
|
||||
if strMode == "avg":
|
||||
tenIn = torch.cat(
|
||||
[
|
||||
tenIn,
|
||||
tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
elif strMode.split("-")[0] == "linear":
|
||||
tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
|
||||
|
||||
elif strMode.split("-")[0] == "softmax":
|
||||
tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
|
||||
|
||||
# end
|
||||
if torch.isnan(tenIn).any():
|
||||
print("NaN values detected during training in tenIn. Exiting.")
|
||||
assert False
|
||||
|
||||
tenOut = softsplat_func.apply(tenIn, tenFlow)
|
||||
|
||||
if torch.isnan(tenOut).any():
|
||||
print("NaN values detected during training in tenOut_1. Exiting.")
|
||||
assert False
|
||||
|
||||
if strMode.split("-")[0] in ["avg", "linear", "softmax"]:
|
||||
tenNormalize = tenOut[:, -1:, :, :]
|
||||
|
||||
if len(strMode.split("-")) == 1:
|
||||
tenNormalize = tenNormalize + 0.0000001
|
||||
|
||||
elif strMode.split("-")[1] == "addeps":
|
||||
tenNormalize = tenNormalize + 0.0000001
|
||||
|
||||
elif strMode.split("-")[1] == "zeroeps":
|
||||
tenNormalize[tenNormalize == 0.0] = 1.0
|
||||
|
||||
elif strMode.split("-")[1] == "clipeps":
|
||||
tenNormalize = tenNormalize.clip(0.0000001, None)
|
||||
|
||||
# end
|
||||
|
||||
if return_norm:
|
||||
return tenOut[:, :-1, :, :], tenNormalize
|
||||
|
||||
tenOut = tenOut[:, :-1, :, :] / tenNormalize
|
||||
|
||||
if torch.isnan(tenOut).any():
|
||||
print("NaN values detected during training in tenOut_2. Exiting.")
|
||||
assert False
|
||||
|
||||
# end
|
||||
|
||||
return tenOut
|
||||
|
||||
|
||||
# end
|
||||
|
||||
|
||||
class softsplat_func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
|
||||
def forward(self, tenIn, tenFlow):
|
||||
tenOut = tenIn.new_zeros(
|
||||
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
||||
)
|
||||
|
||||
if tenIn.is_cuda == True:
|
||||
cuda_launch(
|
||||
cuda_kernel(
|
||||
"softsplat_out",
|
||||
"""
|
||||
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
|
||||
const int n,
|
||||
const {{type}}* __restrict__ tenIn,
|
||||
const {{type}}* __restrict__ tenFlow,
|
||||
{{type}}* __restrict__ tenOut
|
||||
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
|
||||
const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
|
||||
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
|
||||
const int intX = ( intIndex ) % SIZE_3(tenOut);
|
||||
|
||||
assert(SIZE_1(tenFlow) == 2);
|
||||
|
||||
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
||||
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
||||
|
||||
if (isfinite(fltX) == false) { return; }
|
||||
if (isfinite(fltY) == false) { return; }
|
||||
|
||||
{{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
|
||||
|
||||
int intNorthwestX = (int) (floor(fltX));
|
||||
int intNorthwestY = (int) (floor(fltY));
|
||||
int intNortheastX = intNorthwestX + 1;
|
||||
int intNortheastY = intNorthwestY;
|
||||
int intSouthwestX = intNorthwestX;
|
||||
int intSouthwestY = intNorthwestY + 1;
|
||||
int intSoutheastX = intNorthwestX + 1;
|
||||
int intSoutheastY = intNorthwestY + 1;
|
||||
|
||||
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
|
||||
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
|
||||
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
|
||||
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
|
||||
|
||||
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
|
||||
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
|
||||
}
|
||||
|
||||
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
|
||||
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
|
||||
}
|
||||
|
||||
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
|
||||
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
|
||||
}
|
||||
|
||||
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
|
||||
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
|
||||
}
|
||||
} }
|
||||
""",
|
||||
{"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut},
|
||||
)
|
||||
)(
|
||||
grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
|
||||
block=tuple([512, 1, 1]),
|
||||
args=[
|
||||
cuda_int32(tenOut.nelement()),
|
||||
tenIn.data_ptr(),
|
||||
tenFlow.data_ptr(),
|
||||
tenOut.data_ptr(),
|
||||
],
|
||||
stream=collections.namedtuple("Stream", "ptr")(
|
||||
torch.cuda.current_stream().cuda_stream
|
||||
),
|
||||
)
|
||||
|
||||
elif tenIn.is_cuda != True:
|
||||
assert False
|
||||
|
||||
# end
|
||||
|
||||
self.save_for_backward(tenIn, tenFlow)
|
||||
|
||||
return tenOut
|
||||
|
||||
# end
|
||||
|
||||
@staticmethod
|
||||
@torch.compiler.disable()
|
||||
@torch.amp.custom_bwd(device_type="cuda")
|
||||
def backward(self, tenOutgrad):
|
||||
tenIn, tenFlow = self.saved_tensors
|
||||
|
||||
tenOutgrad = tenOutgrad.contiguous()
|
||||
assert tenOutgrad.is_cuda == True
|
||||
|
||||
tenIngrad = (
|
||||
tenIn.new_zeros(
|
||||
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
|
||||
)
|
||||
if self.needs_input_grad[0] == True
|
||||
else None
|
||||
)
|
||||
tenFlowgrad = (
|
||||
tenFlow.new_zeros(
|
||||
[tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]
|
||||
)
|
||||
if self.needs_input_grad[1] == True
|
||||
else None
|
||||
)
|
||||
|
||||
if tenIngrad is not None:
|
||||
cuda_launch(
|
||||
cuda_kernel(
|
||||
"softsplat_ingrad",
|
||||
"""
|
||||
extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
|
||||
const int n,
|
||||
const {{type}}* __restrict__ tenIn,
|
||||
const {{type}}* __restrict__ tenFlow,
|
||||
const {{type}}* __restrict__ tenOutgrad,
|
||||
{{type}}* __restrict__ tenIngrad,
|
||||
{{type}}* __restrict__ tenFlowgrad
|
||||
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
|
||||
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
|
||||
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
|
||||
const int intX = ( intIndex ) % SIZE_3(tenIngrad);
|
||||
|
||||
assert(SIZE_1(tenFlow) == 2);
|
||||
|
||||
{{type}} fltIngrad = 0.0f;
|
||||
|
||||
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
||||
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
||||
|
||||
if (isfinite(fltX) == false) { return; }
|
||||
if (isfinite(fltY) == false) { return; }
|
||||
|
||||
int intNorthwestX = (int) (floor(fltX));
|
||||
int intNorthwestY = (int) (floor(fltY));
|
||||
int intNortheastX = intNorthwestX + 1;
|
||||
int intNortheastY = intNorthwestY;
|
||||
int intSouthwestX = intNorthwestX;
|
||||
int intSouthwestY = intNorthwestY + 1;
|
||||
int intSoutheastX = intNorthwestX + 1;
|
||||
int intSoutheastY = intNorthwestY + 1;
|
||||
|
||||
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
|
||||
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
|
||||
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
|
||||
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
|
||||
|
||||
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
|
||||
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
|
||||
}
|
||||
|
||||
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
|
||||
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
|
||||
}
|
||||
|
||||
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
|
||||
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
|
||||
}
|
||||
|
||||
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
|
||||
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
|
||||
}
|
||||
|
||||
tenIngrad[intIndex] = fltIngrad;
|
||||
} }
|
||||
""",
|
||||
{
|
||||
"tenIn": tenIn,
|
||||
"tenFlow": tenFlow,
|
||||
"tenOutgrad": tenOutgrad,
|
||||
"tenIngrad": tenIngrad,
|
||||
"tenFlowgrad": tenFlowgrad,
|
||||
},
|
||||
)
|
||||
)(
|
||||
grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
|
||||
block=tuple([512, 1, 1]),
|
||||
args=[
|
||||
cuda_int32(tenIngrad.nelement()),
|
||||
tenIn.data_ptr(),
|
||||
tenFlow.data_ptr(),
|
||||
tenOutgrad.data_ptr(),
|
||||
tenIngrad.data_ptr(),
|
||||
None,
|
||||
],
|
||||
stream=collections.namedtuple("Stream", "ptr")(
|
||||
torch.cuda.current_stream().cuda_stream
|
||||
),
|
||||
)
|
||||
# end
|
||||
|
||||
if tenFlowgrad is not None:
|
||||
cuda_launch(
|
||||
cuda_kernel(
|
||||
"softsplat_flowgrad",
|
||||
"""
|
||||
extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
|
||||
const int n,
|
||||
const {{type}}* __restrict__ tenIn,
|
||||
const {{type}}* __restrict__ tenFlow,
|
||||
const {{type}}* __restrict__ tenOutgrad,
|
||||
{{type}}* __restrict__ tenIngrad,
|
||||
{{type}}* __restrict__ tenFlowgrad
|
||||
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
||||
const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
|
||||
const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
|
||||
const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
|
||||
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
|
||||
|
||||
assert(SIZE_1(tenFlow) == 2);
|
||||
|
||||
{{type}} fltFlowgrad = 0.0f;
|
||||
|
||||
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
|
||||
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
|
||||
|
||||
if (isfinite(fltX) == false) { return; }
|
||||
if (isfinite(fltY) == false) { return; }
|
||||
|
||||
int intNorthwestX = (int) (floor(fltX));
|
||||
int intNorthwestY = (int) (floor(fltY));
|
||||
int intNortheastX = intNorthwestX + 1;
|
||||
int intNortheastY = intNorthwestY;
|
||||
int intSouthwestX = intNorthwestX;
|
||||
int intSouthwestY = intNorthwestY + 1;
|
||||
int intSoutheastX = intNorthwestX + 1;
|
||||
int intSoutheastY = intNorthwestY + 1;
|
||||
|
||||
{{type}} fltNorthwest = 0.0f;
|
||||
{{type}} fltNortheast = 0.0f;
|
||||
{{type}} fltSouthwest = 0.0f;
|
||||
{{type}} fltSoutheast = 0.0f;
|
||||
|
||||
if (intC == 0) {
|
||||
fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
|
||||
fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
|
||||
fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
|
||||
fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
|
||||
|
||||
} else if (intC == 1) {
|
||||
fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
|
||||
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
|
||||
fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
|
||||
fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
|
||||
|
||||
}
|
||||
|
||||
for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
|
||||
{{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
|
||||
|
||||
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
|
||||
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
|
||||
}
|
||||
|
||||
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
|
||||
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
|
||||
}
|
||||
|
||||
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
|
||||
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
|
||||
}
|
||||
|
||||
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
|
||||
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
|
||||
}
|
||||
}
|
||||
|
||||
tenFlowgrad[intIndex] = fltFlowgrad;
|
||||
} }
|
||||
""",
|
||||
{
|
||||
"tenIn": tenIn,
|
||||
"tenFlow": tenFlow,
|
||||
"tenOutgrad": tenOutgrad,
|
||||
"tenIngrad": tenIngrad,
|
||||
"tenFlowgrad": tenFlowgrad,
|
||||
},
|
||||
)
|
||||
)(
|
||||
grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
|
||||
block=tuple([512, 1, 1]),
|
||||
args=[
|
||||
cuda_int32(tenFlowgrad.nelement()),
|
||||
tenIn.data_ptr(),
|
||||
tenFlow.data_ptr(),
|
||||
tenOutgrad.data_ptr(),
|
||||
None,
|
||||
tenFlowgrad.data_ptr(),
|
||||
],
|
||||
stream=collections.namedtuple("Stream", "ptr")(
|
||||
torch.cuda.current_stream().cuda_stream
|
||||
),
|
||||
)
|
||||
# end
|
||||
|
||||
return tenIngrad, tenFlowgrad
|
||||
|
||||
# end
|
||||
|
||||
|
||||
# end
|
||||
76
gimm_vfi_arch/generalizable_INR/modules/utils.py
Normal file
76
gimm_vfi_arch/generalizable_INR/modules/utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .layers import Sine, Damping
|
||||
|
||||
|
||||
def convert_int_to_list(size, len_list=2):
|
||||
if isinstance(size, int):
|
||||
return [size] * len_list
|
||||
else:
|
||||
assert len(size) == len_list
|
||||
return size
|
||||
|
||||
|
||||
def initialize_params(params, init_type, **kwargs):
|
||||
fan_in, fan_out = params.shape[0], params.shape[1]
|
||||
if init_type is None or init_type == "normal":
|
||||
nn.init.normal_(params)
|
||||
elif init_type == "kaiming_uniform":
|
||||
nn.init.kaiming_uniform_(params, a=math.sqrt(5))
|
||||
elif init_type == "uniform_fan_in":
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
nn.init.uniform_(params, -bound, bound)
|
||||
elif init_type == "zero":
|
||||
nn.init.zeros_(params)
|
||||
elif "siren" == init_type:
|
||||
assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys()
|
||||
w0 = kwargs["siren_w0"]
|
||||
if kwargs["is_first"]:
|
||||
w_std = 1 / fan_in
|
||||
else:
|
||||
w_std = math.sqrt(6.0 / fan_in) / w0
|
||||
nn.init.uniform_(params, -w_std, w_std)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_params_with_init(
|
||||
shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
|
||||
):
|
||||
if not include_bias:
|
||||
params = torch.empty([shape[0], shape[1]])
|
||||
initialize_params(params, init_type, **kwargs)
|
||||
return params
|
||||
else:
|
||||
params = torch.empty([shape[0] - 1, shape[1]])
|
||||
bias = torch.empty([1, shape[1]])
|
||||
|
||||
initialize_params(params, init_type, **kwargs)
|
||||
initialize_params(bias, bias_init_type, **kwargs)
|
||||
return torch.cat([params, bias], dim=0)
|
||||
|
||||
|
||||
def create_activation(config):
|
||||
if config.type == "relu":
|
||||
activation = nn.ReLU()
|
||||
elif config.type == "siren":
|
||||
activation = Sine(config.siren_w0)
|
||||
elif config.type == "silu":
|
||||
activation = nn.SiLU()
|
||||
elif config.type == "damp":
|
||||
activation = Damping(config.siren_w0)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return activation
|
||||
Reference in New Issue
Block a user