Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation

Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 13:11:45 +01:00
parent 3c3d4b2537
commit d642255e70
56 changed files with 9774 additions and 1 deletions

View File

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import torch
import torch.nn as nn
class CoordSampler3D(nn.Module):
    """Samples normalized (t, y, x) coordinate grids for the implicit motion model.

    Spatial coordinates are pixel centers affinely mapped into ``coord_range``;
    temporal coordinates come directly from ``t_ids`` — either one list of
    timesteps shared by the whole batch, or a tensor with one timestep per
    batch sample.
    """

    def __init__(self, coord_range, t_coord_only=False):
        super().__init__()
        # (low, high) range for the spatial axes
        self.coord_range = coord_range
        # if True, forward() returns only the temporal channel
        self.t_coord_only = t_coord_only

    def shape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Build a (B, T, H, W, 3) grid where every sample shares ``t_ids`` (a list)."""
        coords = []
        assert isinstance(t_ids, list)
        coords.append(torch.tensor(t_ids, device=device).to(torch.float32))
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            # pixel-center positions in [0, 1), affinely mapped into coord_range
            _coords = (0.5 + torch.arange(num_s, device=device)) / num_s
            _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
            coords.append(_coords)
        coords = torch.meshgrid(*coords, indexing="ij")
        coords = torch.stack(coords, dim=-1)
        ones_like_shape = (1,) * coords.ndim
        coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
        return coords  # (B,T,H,W,3)

    def batchshape2coordinate(
        self,
        batch_size,
        spatial_shape,
        t_ids,
        coord_range=(-1.0, 1.0),
        upsample_ratio=1,
        device=None,
    ):
        """Build a (B, 1, H, W, 3) grid with a distinct timestep per batch sample.

        ``t_ids`` is a tensor of shape (B,); the temporal channel is first set
        to 1 everywhere and then scaled per-sample.
        """
        coords = []
        _coords = torch.tensor(1, device=device)
        coords.append(_coords.to(torch.float32))
        for num_s in spatial_shape:
            num_s = int(num_s * upsample_ratio)
            _coords = (0.5 + torch.arange(num_s, device=device)) / num_s
            _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
            coords.append(_coords)
        coords = torch.meshgrid(*coords, indexing="ij")
        coords = torch.stack(coords, dim=-1)
        ones_like_shape = (1,) * coords.ndim
        # Now coords b,1,h,w,3, coords[...,0]=1.
        coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
        # assign per-sample timestep within the batch
        coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
        return coords

    def forward(
        self,
        batch_size,
        s_shape,
        t_ids,
        coord_range=None,
        upsample_ratio=1.0,
        device=None,
    ):
        """Dispatch on ``t_ids``: list -> shared timesteps, tensor -> per-sample.

        Raises:
            TypeError: if ``t_ids`` is neither a list nor a torch.Tensor
                (previously this fell through to an opaque NameError).
        """
        coord_range = self.coord_range if coord_range is None else coord_range
        if isinstance(t_ids, list):
            coords = self.shape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        elif isinstance(t_ids, torch.Tensor):
            coords = self.batchshape2coordinate(
                batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
            )
        else:
            raise TypeError(
                f"t_ids must be a list or torch.Tensor, got {type(t_ids)}"
            )
        if self.t_coord_only:
            coords = coords[..., :1]
        return coords

View File

@@ -0,0 +1,340 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# amt: https://github.com/MCG-NKU/AMT
# motif: https://github.com/sichun233746/MoTIF
# --------------------------------------------------------
import torch
import torch.nn as nn
from .fi_utils import warp, resize
class LateralBlock(nn.Module):
    """Residual block: two 3x3 convolutions with a LeakyReLU in between,
    added back onto the input."""

    def __init__(self, dim):
        super(LateralBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=True),
        )

    def forward(self, x):
        # skip connection around the conv stack
        return self.layers(x) + x
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Conv2d followed by a channel-wise PReLU, packaged as one Sequential."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias=bias,
    )
    return nn.Sequential(conv, nn.PReLU(out_channels))
def multi_flow_combine(
    comb_block, img0, img1, flow0, flow1, mask=None, img_res=None, mean=None
):
    """Blend ``num_flows`` candidate bidirectional warps into one predicted frame.

    flow0/flow1 stack num_flows 2-channel flows along the channel axis. Each
    candidate warps img0/img1, the pair is blended by its mask and offset by
    its image residual; candidates are averaged and refined by ``comb_block``.
    The result is mapped from [-1, 1] to [0, 1].
    """
    # NOTE(review): mean is asserted None, so the mean-related branch below is
    # dead code retained from upstream.
    assert mean is None
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the candidate dimension into the batch so warp() handles all at once.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    mask = (
        mask.reshape(b, num_flows, 1, h, w).reshape(-1, 1, h, w)
        if mask is not None
        else None
    )
    img_res = (
        img_res.reshape(b, num_flows, 3, h, w).reshape(-1, 3, h, w)
        if img_res is not None
        else 0
    )
    # Duplicate the two input frames once per candidate flow.
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = (
        torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1)
        if mean is not None
        else 0
    )
    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    # Per-candidate blend of the two warped images.
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # comb_block sees all candidates stacked along channels and outputs a residual.
    res = comb_block(img_warps.view(b, -1, h, w))
    imgt_pred = img_warps.mean(1) + res
    imgt_pred = (imgt_pred + 1.0) / 2
    return imgt_pred
class ResBlock(nn.Module):
    """Residual block that keeps a narrow "side" channel group which gets
    extra convolutions before being merged back with the main features."""

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels

        def conv_prelu(ch):
            # 3x3 conv + channel-wise PReLU, preserving resolution
            return nn.Sequential(
                nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1, bias=bias),
                nn.PReLU(ch),
            )

        self.conv1 = conv_prelu(in_channels)
        self.conv2 = conv_prelu(side_channels)
        self.conv3 = conv_prelu(in_channels)
        self.conv4 = conv_prelu(side_channels)
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        sc = self.side_channels
        feat = self.conv1(x)
        # split off the trailing side channels, refine them, merge back
        main, side = feat[:, :-sc], feat[:, -sc:]
        side = self.conv2(side)
        feat = self.conv3(torch.cat((main, side), dim=1))
        main, side = feat[:, :-sc], feat[:, -sc:]
        side = self.conv4(side)
        feat = self.conv5(torch.cat((main, side), dim=1))
        # residual connection around the whole block
        return self.prelu(x + feat)
class BasicUpdateBlock(nn.Module):
    """GRU-style update block: fuses correlation and flow features and emits
    residual updates for the hidden features and the (bidirectional) flow."""

    def __init__(
        self,
        cdim,
        hidden_dim,
        flow_dim,
        corr_dim,
        corr_dim2,
        fc_dim,
        corr_levels=4,
        radius=3,
        scale_factor=None,
        out_num=1,
    ):
        super(BasicUpdateBlock, self).__init__()
        # channels of one correlation pyramid sample; doubled below (two directions)
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)
        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )
        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )
        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
        )
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return (delta_net, delta_flow) residual updates.

        When scale_factor is set, the update is computed at a coarser scale
        and the deltas are resized (and the flow rescaled) back up.
        """
        if self.scale_factor is not None:
            net = resize(net, 1 / self.scale_factor)
        corr_feat = self.lrelu(self.convc2(self.lrelu(self.convc1(corr))))
        flow_feat = self.lrelu(self.convf2(self.lrelu(self.convf1(flow))))
        merged = self.lrelu(self.conv(torch.cat((corr_feat, flow_feat), dim=1)))
        gru_in = torch.cat((merged, flow, net), dim=1)
        hidden = self.gru(gru_in)
        delta_net = self.feat_head(hidden)
        delta_flow = self.flow_head(hidden)
        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(
                delta_flow, scale_factor=self.scale_factor
            )
        return delta_net, delta_flow
def get_bn():
    """Return the norm-layer factory used by the decoders (BatchNorm2d)."""
    return nn.BatchNorm2d
class NewInitDecoder(nn.Module):
    """Initial decoder: 2x-upsamples pyramid features and produces the first
    refined bidirectional flows plus an intermediate feature map ``ft_``."""

    def __init__(self, in_ch, skip_ch):
        super().__init__()
        norm_layer = get_bn()
        # PixelShuffle(2) doubles resolution while dividing channels by 4.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            convrelu(in_ch // 4, in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )
        in_ch = in_ch // 2  # feature width after the upsample stack
        self.convblock = nn.Sequential(
            # +16: two 2-channel flows (4) plus four RGB images (12), see forward.
            convrelu(in_ch * 2 + 16, in_ch, kernel_size=1, padding=0),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            ResBlock(in_ch, skip_ch),
            nn.Conv2d(in_ch, in_ch + 5, 3, 1, 1, 1, 1, True),
        )

    def forward(self, f0, f1, flow0_in, flow1_in, img0=None, img1=None):
        """Refine the incoming flows at the upsampled feature resolution.

        Returns (flow0, flow1, ft_): residual-updated flows and the remaining
        prediction-head channels.
        """
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        # Warp each frame's features with its incoming flow.
        f0_warp_ks = warp(f0, flow0_in)
        f1_warp_ks = warp(f1, flow1_in)
        f_in = torch.cat([f0_warp_ks, f1_warp_ks, flow0_in, flow1_in], dim=1)
        assert img0 is not None
        assert img1 is not None
        # Bring the input images to the feature resolution before warping.
        scale_factor = f_in.shape[2] / img0.shape[2]
        img0 = resize(img0, scale_factor=scale_factor)
        img1 = resize(img1, scale_factor=scale_factor)
        warped_img0 = warp(img0, flow0_in)
        warped_img1 = warp(img1, flow1_in)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
        out = self.convblock(f_in)
        # out channels: [0:2] flow0 residual, [2:4] flow1 residual, [4:] ft_.
        ft_ = out[:, 4:, ...]
        flow0 = flow0_in + out[:, :2, ...]
        flow1 = flow1_in + out[:, 2:4, ...]
        return flow0, flow1, ft_
class NewMultiFlowDecoder(nn.Module):
    """Final decoder: 4x-upsamples features and predicts ``num_flows`` candidate
    flow pairs together with blending masks and per-candidate image residuals."""

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(NewMultiFlowDecoder, self).__init__()
        norm_layer = get_bn()
        # Two PixelShuffle(2) stages give a 4x spatial upsampling.
        self.upsample = nn.Sequential(
            nn.PixelShuffle(2),
            nn.PixelShuffle(2),
            convrelu(in_ch // (4 * 4), in_ch // 4, 5, 1, 2),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 4),
            convrelu(in_ch // 4, in_ch // 2),
            nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=1),
            norm_layer(in_ch // 2),
            nn.ReLU(inplace=True),
        )
        self.num_flows = num_flows
        ch_factor = 2
        self.convblock = nn.Sequential(
            # +17: two 2-channel flows (4), mask (1), four RGB images (12).
            convrelu(in_ch * ch_factor + 17, in_ch * ch_factor),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            ResBlock(in_ch * ch_factor, skip_ch),
            # 8 channels per candidate: 2+2 flow deltas, 1 mask delta, 3 image residual.
            nn.Conv2d(in_ch * ch_factor, 8 * num_flows, kernel_size=3, padding=1),
        )

    def forward(self, ft_, f0, f1, flow0, flow1, mask=None, img0=None, img1=None):
        """Predict num_flows refined flow pairs at full resolution.

        Returns (flow0, flow1, mask, img_res), each with the candidate
        predictions stacked along the channel dimension.
        """
        f0 = self.upsample(f0)
        f1 = self.upsample(f1)
        n = self.num_flows
        # Inputs come from 1/4 resolution: scale both the grid and the magnitudes.
        flow0 = 4.0 * resize(flow0, scale_factor=4.0)
        flow1 = 4.0 * resize(flow1, scale_factor=4.0)
        ft_ = resize(ft_, scale_factor=4.0)
        mask = resize(mask, scale_factor=4.0)
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)
        assert mask is not None
        f_in = torch.cat([f_in, mask], 1)
        assert img0 is not None
        assert img1 is not None
        warped_img0 = warp(img0, flow0)
        warped_img1 = warp(img1, flow1)
        f_in = torch.cat([f_in, img0, img1, warped_img0, warped_img1], dim=1)
        out = self.convblock(f_in)
        delta_flow0, delta_flow1, delta_mask, img_res = torch.split(
            out, [2 * n, 2 * n, n, 3 * n], 1
        )
        # Residual update: the single base prediction is repeated per candidate.
        mask = delta_mask + mask.repeat(1, self.num_flows, 1, 1)
        mask = torch.sigmoid(mask)
        flow0 = delta_flow0 + flow0.repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + flow1.repeat(1, self.num_flows, 1, 1)
        return flow0, flow1, mask, img_res

View File

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# raft: https://github.com/princeton-vl/RAFT
# ema-vfi: https://github.com/MCG-NJU/EMA-VFI
# --------------------------------------------------------
import torch
import torch.nn.functional as F
# Cache of base sampling grids, keyed by (device, flow size), so repeated calls
# with the same geometry reuse the grid instead of rebuilding it.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` with the per-pixel flow ``tenFlow`` (in pixels).

    tenFlow[:, 0] is the horizontal (x) displacement and tenFlow[:, 1] the
    vertical (y) one. Sampling is bilinear with border padding.
    """
    cache_key = (str(tenFlow.device), str(tenFlow.size()))
    if cache_key not in backwarp_tenGrid:
        n, _, h, w = tenFlow.shape
        xs = (
            torch.linspace(-1.0, 1.0, w, device=tenFlow.device)
            .view(1, 1, 1, w)
            .expand(n, -1, h, -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, h, device=tenFlow.device)
            .view(1, 1, h, 1)
            .expand(n, -1, -1, w)
        )
        backwarp_tenGrid[cache_key] = torch.cat([xs, ys], 1)
    # Convert pixel displacements into grid_sample's [-1, 1] coordinate system.
    norm_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )
    grid = (backwarp_tenGrid[cache_key] + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
def normalize_flow(flows):
    """Map 5-D flows into [0, 1] using each sample's max absolute value.

    Returns the normalized flows and the per-sample scaler needed to undo the
    mapping via ``unnormalize_flow``.
    """
    # FIXME: MULTI-DIMENSION
    per_sample_max = torch.abs(flows).flatten(1).max(dim=-1).values
    flow_scaler = per_sample_max.reshape(-1, 1, 1, 1, 1)
    normed = flows / flow_scaler  # now in [-1, 1]
    return (normed + 1.0) / 2.0, flow_scaler
def unnormalize_flow(flows, flow_scaler):
    """Invert ``normalize_flow``: map [0, 1] values back to pixel-space flow."""
    return flow_scaler * (2.0 * flows - 1.0)
def resize(x, scale_factor):
    """Bilinearly rescale a NCHW tensor by ``scale_factor`` (align_corners=False)."""
    return F.interpolate(
        x,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of (x, y) pixel coordinates.

    Channel 0 holds the x (column) coordinate, channel 1 the y (row)
    coordinate, matching the RAFT convention.
    """
    # Explicit indexing="ij" keeps the historical meshgrid behavior while
    # silencing the PyTorch deprecation warning; the codebase already uses
    # indexing="ij" elsewhere, so the requirement is met.
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack((xs, ys), dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
def build_coord(img):
    """Build a coordinate grid at 1/8 of the input image's resolution."""
    batch, _, height, width = img.shape
    return coords_grid(batch, height // 8, width // 8)

View File

@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from ..configs import HypoNetConfig
from .utils import create_params_with_init, create_activation
class HypoNet(nn.Module):
    r"""
    The Hyponetwork with a coordinate-based MLP to be modulated.

    Base weights are created here and shared across the batch; each layer can
    additionally be modulated per-sample via ``modulation_params_dict`` in
    :meth:`forward`.
    """

    def __init__(self, config: HypoNetConfig, add_coord_dim=32):
        # add_coord_dim: width of the per-pixel latent concatenated to the coords.
        super().__init__()
        self.config = config
        self.use_bias = config.use_bias
        self.init_config = config.initialization
        self.num_layer = config.n_layer
        self.hidden_dims = config.hidden_dim
        self.add_coord_dim = add_coord_dim
        if len(self.hidden_dims) == 1:
            # A single entry is broadcast to all hidden layers.
            self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
                self.num_layer - 1
            )  # exclude output layer
        else:
            assert len(self.hidden_dims) == self.num_layer - 1
        if self.config.activation.type == "siren":
            # SIREN requires its dedicated weight/bias initialization.
            assert self.init_config.weight_init_type == "siren"
            assert self.init_config.bias_init_type == "siren"
        # after computes the shape of trainable parameters, initialize them
        self.params_dict = None
        self.params_shape_dict = self.compute_params_shape()
        self.activation = create_activation(self.config.activation)
        self.build_base_params_dict(self.config.initialization)
        self.output_bias = config.output_bias
        self.normalize_weight = config.normalize_weight
        # Per-layer switch: when True, the base weight is replaced by 1.0 so
        # the modulation alone defines the layer.
        self.ignore_base_param_dict = {name: False for name in self.params_dict}

    @staticmethod
    def subsample_coords(coords, subcoord_idx=None):
        """Select a per-sample subset of flattened coordinates (no-op when None)."""
        if subcoord_idx is None:
            return coords
        batch_size = coords.shape[0]
        sub_coords = []
        coords = coords.view(batch_size, -1, coords.shape[-1])
        for idx in range(batch_size):
            sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
        sub_coords = torch.cat(sub_coords, dim=0)
        return sub_coords

    def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
        """Evaluate the (optionally modulated) MLP at ``coord``.

        ``coord`` is a (B, *spatial, input_dim) grid, or a (coord, sub_idx)
        tuple to evaluate only a per-sample subset of coordinates.
        ``pixel_latent`` (required) is resized to the coordinate resolution
        and concatenated to every coordinate.
        """
        sub_idx = None
        if isinstance(coord, tuple):
            coord, sub_idx = coord[0], coord[1]
        if modulation_params_dict is not None:
            self.check_valid_param_keys(modulation_params_dict)
        batch_size, coord_shape, input_dim = (
            coord.shape[0],
            coord.shape[1:-1],
            coord.shape[-1],
        )
        coord = coord.view(batch_size, -1, input_dim)  # flatten the coordinates
        assert pixel_latent is not None
        # NOTE(review): interpolation targets coord_shape[1:3], so coord is
        # expected to carry at least three spatial dims (e.g. T,H,W) — confirm
        # with callers.
        pixel_latent = F.interpolate(
            pixel_latent.permute(0, 3, 1, 2),
            size=(coord_shape[1], coord_shape[2]),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        pixel_latent_dim = pixel_latent.shape[-1]
        pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)
        hidden = coord
        hidden = torch.cat([pixel_latent, hidden], dim=-1)
        hidden = self.subsample_coords(hidden, sub_idx)
        for idx in range(self.config.n_layer):
            param_key = f"linear_wb{idx}"
            # Broadcast the shared base weight across the batch.
            base_param = einops.repeat(
                self.params_dict[param_key], "n m -> b n m", b=batch_size
            )
            if (modulation_params_dict is not None) and (
                param_key in modulation_params_dict.keys()
            ):
                modulation_param = modulation_params_dict[param_key]
            else:
                # No modulation provided for this layer: multiply by ones.
                if self.config.use_bias:
                    modulation_param = torch.ones_like(base_param[:, :-1])
                else:
                    modulation_param = torch.ones_like(base_param)
            if self.config.use_bias:
                # Bias trick: append a constant-1 input so the last row of the
                # weight matrix acts as the (unmodulated) bias.
                ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device)
                hidden = torch.cat([hidden, ones], dim=-1)
                base_param_w, base_param_b = (
                    base_param[:, :-1, :],
                    base_param[:, -1:, :],
                )
                if self.ignore_base_param_dict[param_key]:
                    base_param_w = 1.0
                param_w = base_param_w * modulation_param
                if self.normalize_weight:
                    param_w = F.normalize(param_w, dim=1)
                modulated_param = torch.cat([param_w, base_param_b], dim=1)
            else:
                if self.ignore_base_param_dict[param_key]:
                    base_param = 1.0
                if self.normalize_weight:
                    modulated_param = F.normalize(base_param * modulation_param, dim=1)
                else:
                    modulated_param = base_param * modulation_param
            hidden = torch.bmm(hidden, modulated_param)
            # Activation on every layer except the output layer.
            if idx < (self.config.n_layer - 1):
                hidden = self.activation(hidden)
        outputs = hidden + self.output_bias
        if sub_idx is None:
            outputs = outputs.view(batch_size, *coord_shape, -1)
        return outputs

    def compute_params_shape(self):
        """
        Computes the shape of MLP parameters.
        The computed shapes are used to build the initial weights by `build_base_params_dict`.
        """
        config = self.config
        use_bias = self.use_bias
        param_shape_dict = dict()
        fan_in = config.input_dim
        add_dim = self.add_coord_dim
        # The input layer also receives the concatenated pixel latent.
        fan_in = fan_in + add_dim
        # +1 accounts for the bias-trick row appended to each weight matrix.
        fan_in = fan_in + 1 if use_bias else fan_in
        for i in range(config.n_layer - 1):
            fan_out = self.hidden_dims[i]
            param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
            fan_in = fan_out + 1 if use_bias else fan_out
        param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
        return param_shape_dict

    def build_base_params_dict(self, init_config):
        """Create and register the base (unmodulated) weights for every layer."""
        assert self.params_shape_dict
        params_dict = nn.ParameterDict()
        for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
            is_first = idx == 0
            params = create_params_with_init(
                shape,
                init_type=init_config.weight_init_type,
                include_bias=self.use_bias,
                bias_init_type=init_config.bias_init_type,
                is_first=is_first,
                siren_w0=self.config.activation.siren_w0,  # valid only for siren
            )
            params = nn.Parameter(params)
            params_dict[name] = params
        self.set_params_dict(params_dict)

    def check_valid_param_keys(self, params_dict):
        """Raise KeyError if ``params_dict`` contains a key this net did not declare."""
        predefined_params_keys = self.params_shape_dict.keys()
        for param_key in params_dict.keys():
            if param_key in predefined_params_keys:
                continue
            else:
                raise KeyError

    def set_params_dict(self, params_dict):
        """Validate and install the base parameter dict."""
        self.check_valid_param_keys(params_dict)
        self.params_dict = params_dict

View File

@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
from torch import nn
import torch
# define siren layer & Siren model
class Sine(nn.Module):
    """Sine activation with frequency scaling (SIREN).

    Args:
        w0 (float): Omega_0 parameter from SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(x * self.w0)
# Damping activation from http://arxiv.org/abs/2306.15242
class Damping(nn.Module):
    """Sine activation with a sublinear sqrt damping factor.

    From http://arxiv.org/abs/2306.15242.

    Args:
        w0 (float): Omega_0 parameter from SIREN paper.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        # Inputs are clamped to a tiny positive floor before the sine/sqrt.
        clamped = x.clamp(min=1e-30)
        return torch.sqrt(clamped.abs()) * torch.sin(self.w0 * clamped)

View File

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
from typing import List, Optional
from dataclasses import dataclass, field
from omegaconf import MISSING
@dataclass
class HypoNetActivationConfig:
    # Activation used between hypo-net layers; "siren" requires siren init.
    type: str = "relu"
    # Omega_0 frequency for siren-type activations (unused otherwise).
    siren_w0: Optional[float] = 30.0


@dataclass
class HypoNetInitConfig:
    # Initialization schemes for the base MLP weights and biases.
    weight_init_type: Optional[str] = "kaiming_uniform"
    bias_init_type: Optional[str] = "zero"


@dataclass
class HypoNetConfig:
    """Configuration of the coordinate-based MLP (hypo-network)."""

    type: str = "mlp"
    n_layer: int = 5
    # A single entry is broadcast to all hidden layers; otherwise provide
    # n_layer - 1 entries.
    hidden_dim: List[int] = MISSING
    use_bias: bool = True
    input_dim: int = 2
    output_dim: int = 3
    # Constant added to the final layer output.
    output_bias: float = 0.5
    activation: HypoNetActivationConfig = field(default_factory=HypoNetActivationConfig)
    initialization: HypoNetInitConfig = field(default_factory=HypoNetInitConfig)
    # L2-normalize modulated weight matrices (see HypoNet.forward).
    normalize_weight: bool = True
    linear_interpo: bool = False


@dataclass
class CoordSamplerConfig:
    # Coordinate sampling ranges/strategies; MISSING fields must be supplied
    # by the loaded model config.
    data_type: str = "image"
    t_coord_only: bool = False
    coord_range: List[float] = MISSING
    time_range: List[float] = MISSING
    train_strategy: Optional[str] = MISSING
    val_strategy: Optional[str] = MISSING
    patch_size: Optional[int] = MISSING

View File

@@ -0,0 +1,672 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# softmax-splatting: https://github.com/sniklaus/softmax-splatting
# --------------------------------------------------------
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
objCudacache = {}
def cuda_int32(intIn: int):
    """Wrap a Python int as a CUDA-kernel int32 argument."""
    return cupy.int32(intIn)


def cuda_float32(fltIn: float):
    """Wrap a Python float as a CUDA-kernel float32 argument."""
    return cupy.float32(fltIn)
def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
    """Specialize and cache a templated CUDA kernel source.

    A cache key is built from the function name, each argument's name plus its
    value (scalars) or dtype/shape/stride (tensors), and the device name. On a
    cache miss the template is specialized: ``{{name}}``/``{{type}}``
    placeholders are substituted and the ``SIZE_n`` / ``OFFSET_n`` / ``VALUE_n``
    macros are expanded into literal sizes, stride-based linear offsets, and
    indexed element reads. Returns the cache key for use with ``cuda_launch``.
    """
    if "device" not in objCudacache:
        objCudacache["device"] = torch.cuda.get_device_name()
    # end
    strKey = strFunction
    for strVariable in objVariables:
        objValue = objVariables[strVariable]
        strKey += strVariable
        if objValue is None:
            continue
        elif type(objValue) == int:
            strKey += str(objValue)
        elif type(objValue) == float:
            strKey += str(objValue)
        elif type(objValue) == bool:
            strKey += str(objValue)
        elif type(objValue) == str:
            strKey += objValue
        elif type(objValue) == torch.Tensor:
            # Tensors contribute metadata only, not data.
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())
        elif True:
            print(strVariable, type(objValue))
            assert False
        # end
    # end
    strKey += objCudacache["device"]
    if strKey not in objCudacache:
        # Substitute template placeholders with concrete values/types.
        for strVariable in objVariables:
            objValue = objVariables[strVariable]
            if objValue is None:
                continue
            elif type(objValue) == int:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
            elif type(objValue) == float:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
            elif type(objValue) == bool:
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
            elif type(objValue) == str:
                strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
                strKernel = strKernel.replace("{{type}}", "unsigned char")
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
                strKernel = strKernel.replace("{{type}}", "half")
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
                strKernel = strKernel.replace("{{type}}", "float")
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
                strKernel = strKernel.replace("{{type}}", "double")
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
                strKernel = strKernel.replace("{{type}}", "int")
            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
                strKernel = strKernel.replace("{{type}}", "long")
            elif type(objValue) == torch.Tensor:
                print(strVariable, objValue.dtype)
                assert False
            elif True:
                print(strVariable, type(objValue))
                assert False
            # end
        # end
        # Expand SIZE_d(tensor) into the literal size of dimension d.
        while True:
            objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
            if objMatch is None:
                break
            # end
            intArg = int(objMatch.group(2))
            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()
            strKernel = strKernel.replace(
                objMatch.group(),
                str(
                    intSizes[intArg]
                    if torch.is_tensor(intSizes[intArg]) == False
                    else intSizes[intArg].item()
                ),
            )
        # end
        # Expand OFFSET_d(tensor, i0, ..., i{d-1}) into a stride-based offset.
        while True:
            objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel)
            if objMatch is None:
                break
            # end
            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1
            # Scan forward to the matching closing parenthesis.
            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0
                if intParentheses == 0:
                    break
                # end
                intStop += 1
            # end
            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")
            assert intArgs == len(strArgs) - 1
            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()
            strIndex = []
            for intArg in range(intArgs):
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end
            strKernel = strKernel.replace(
                "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                "(" + str.join("+", strIndex) + ")",
            )
        # end
        # Expand VALUE_d(tensor, ...) into an indexed element access.
        while True:
            objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel)
            if objMatch is None:
                break
            # end
            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1
            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0
                if intParentheses == 0:
                    break
                # end
                intStop += 1
            # end
            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")
            assert intArgs == len(strArgs) - 1
            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()
            strIndex = []
            for intArg in range(intArgs):
                strIndex.append(
                    "(("
                    + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
                    + ")*"
                    + str(
                        intStrides[intArg]
                        if torch.is_tensor(intStrides[intArg]) == False
                        else intStrides[intArg].item()
                    )
                    + ")"
                )
            # end
            strKernel = strKernel.replace(
                "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                strTensor + "[" + str.join("+", strIndex) + "]",
            )
        # end
        objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}
    # end
    return strKey
# end
@cupy.memoize(for_each_device=True)
@torch.compiler.disable()
def cuda_launch(strKey: str):
    """Compile the cached kernel source for ``strKey`` and return the raw CUDA function.

    Memoized per device by cupy. Requires CUDA_HOME to locate the toolkit
    headers; when cupy cannot determine the path and the variable is unset,
    fails with an explicit RuntimeError.
    """
    try:
        os.environ.setdefault("CUDA_HOME", cupy.cuda.get_cuda_path())
    except Exception:
        if "CUDA_HOME" not in os.environ:
            raise RuntimeError("'CUDA_HOME' not set, unable to find cuda-toolkit installation.")
    strKernel = objCudacache[strKey]["strKernel"]
    strFunction = objCudacache[strKey]["strFunction"]
    return cupy.RawModule(
        code=strKernel,
        options=(
            "-I " + os.environ["CUDA_HOME"],
            "-I " + os.environ["CUDA_HOME"] + "/include",
        ),
    ).get_function(strFunction)
##########################################################
@torch.compiler.disable()
def softsplat(tenIn, tenFlow, tenMetric, strMode, return_norm=False):
    """Forward-splat ``tenIn`` along ``tenFlow`` (softmax splatting).

    ``strMode`` selects the aggregation — "sum", "avg", "linear" or "softmax" —
    optionally suffixed with "-addeps" / "-zeroeps" / "-clipeps" to choose how
    the normalization denominator is protected against zeros. ``tenMetric`` is
    the per-pixel weighting map (required for linear/softmax, forbidden for
    sum/avg). With ``return_norm=True`` the unnormalized output and the
    normalization map are returned instead of dividing.
    """
    assert strMode.split("-")[0] in ["sum", "avg", "linear", "softmax"]
    if strMode == "sum":
        assert tenMetric is None
    if strMode == "avg":
        assert tenMetric is None
    if strMode.split("-")[0] == "linear":
        assert tenMetric is not None
    if strMode.split("-")[0] == "softmax":
        assert tenMetric is not None
    # Append a (weighted) ones channel so the CUDA kernel also accumulates the
    # normalization denominator.
    if strMode == "avg":
        tenIn = torch.cat(
            [
                tenIn,
                tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]),
            ],
            1,
        )
    elif strMode.split("-")[0] == "linear":
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
    elif strMode.split("-")[0] == "softmax":
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
    # end
    if torch.isnan(tenIn).any():
        print("NaN values detected during training in tenIn. Exiting.")
        assert False
    tenOut = softsplat_func.apply(tenIn, tenFlow)
    if torch.isnan(tenOut).any():
        print("NaN values detected during training in tenOut_1. Exiting.")
        assert False
    if strMode.split("-")[0] in ["avg", "linear", "softmax"]:
        # Last channel carries the accumulated normalization weights.
        tenNormalize = tenOut[:, -1:, :, :]
        if len(strMode.split("-")) == 1:
            tenNormalize = tenNormalize + 0.0000001
        elif strMode.split("-")[1] == "addeps":
            tenNormalize = tenNormalize + 0.0000001
        elif strMode.split("-")[1] == "zeroeps":
            tenNormalize[tenNormalize == 0.0] = 1.0
        elif strMode.split("-")[1] == "clipeps":
            tenNormalize = tenNormalize.clip(0.0000001, None)
        # end
        if return_norm:
            return tenOut[:, :-1, :, :], tenNormalize
        tenOut = tenOut[:, :-1, :, :] / tenNormalize
        if torch.isnan(tenOut).any():
            print("NaN values detected during training in tenOut_2. Exiting.")
            assert False
    # end
    return tenOut
# end
class softsplat_func(torch.autograd.Function):
    """Summation splatting (forward warping) with analytic gradients, CUDA only.

    forward() scatters every pixel of the input feature map to the four
    integer neighbours of its flow-displaced position, bilinearly weighted
    and accumulated with atomicAdd.  backward() evaluates the matching
    analytic gradients for both the features and the flow.  The kernels are
    templated and compiled lazily through cuda_kernel()/cuda_launch();
    CPU tensors trip an assertion (there is no CPU fallback).
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        """Splat tenIn (N,C,H,W) along tenFlow (N,2,H,W) -> (N,C,H,W).

        NOTE: ``self`` is the autograd context object (ctx) kept under its
        historical name.  Flow channel 0 is the x offset, channel 1 the y
        offset — the kernel asserts SIZE_1(tenFlow) == 2.  Pixels whose
        displaced coordinates are non-finite are silently dropped.
        """
        # Output accumulator: same shape/dtype/device as the input features.
        tenOut = tenIn.new_zeros(
            [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
        )
        if tenIn.is_cuda == True:
            cuda_launch(
                cuda_kernel(
                    "softsplat_out",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_out(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        {{type}}* __restrict__ tenOut
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                        const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
                        const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                        const int intX = ( intIndex ) % SIZE_3(tenOut);
                        assert(SIZE_1(tenFlow) == 2);
                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }
                        {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;
                        {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                        {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                        {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                        {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
                        }
                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
                        }
                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
                        }
                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                            atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
                        }
                    } }
                    """,
                    {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut},
                )
            )(
                # One thread per output element; 512-thread blocks to match
                # the kernel's __launch_bounds__(512).
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenOut.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOut.data_ptr(),
                ],
                # cuda_launch expects an object with a .ptr attribute naming
                # the CUDA stream; launch on PyTorch's current stream.
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        elif tenIn.is_cuda != True:
            assert False  # CPU execution is intentionally unsupported
        # end
        self.save_for_backward(tenIn, tenFlow)
        return tenOut
    # end

    @staticmethod
    @torch.compiler.disable()
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(self, tenOutgrad):
        """Return (tenIngrad, tenFlowgrad) for the splatting op.

        Either gradient is None when the corresponding forward input does not
        require grad; each is computed by its own dedicated kernel.
        """
        tenIn, tenFlow = self.saved_tensors
        tenOutgrad = tenOutgrad.contiguous()
        assert tenOutgrad.is_cuda == True
        # Allocate gradient buffers only for inputs that need them.
        tenIngrad = (
            tenIn.new_zeros(
                [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
            )
            if self.needs_input_grad[0] == True
            else None
        )
        tenFlowgrad = (
            tenFlow.new_zeros(
                [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]
            )
            if self.needs_input_grad[1] == True
            else None
        )
        if tenIngrad is not None:
            # d(out)/d(in): gather the bilinear weights back from the four
            # splat targets of each source pixel.
            cuda_launch(
                cuda_kernel(
                    "softsplat_ingrad",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        const {{type}}* __restrict__ tenOutgrad,
                        {{type}}* __restrict__ tenIngrad,
                        {{type}}* __restrict__ tenFlowgrad
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
                        const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
                        const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
                        const int intX = ( intIndex ) % SIZE_3(tenIngrad);
                        assert(SIZE_1(tenFlow) == 2);
                        {{type}} fltIngrad = 0.0f;
                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }
                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;
                        {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                        {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                        {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                        {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
                        }
                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
                        }
                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
                        }
                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                            fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
                        }
                        tenIngrad[intIndex] = fltIngrad;
                    } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenIngrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenIngrad.data_ptr(),
                    None,  # tenFlowgrad pointer unused by this kernel
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end
        if tenFlowgrad is not None:
            # d(out)/d(flow): the bilinear-weight derivatives w.r.t. x (intC==0)
            # or y (intC==1), summed over every feature channel.
            cuda_launch(
                cuda_kernel(
                    "softsplat_flowgrad",
                    """
                    extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
                        const int n,
                        const {{type}}* __restrict__ tenIn,
                        const {{type}}* __restrict__ tenFlow,
                        const {{type}}* __restrict__ tenOutgrad,
                        {{type}}* __restrict__ tenIngrad,
                        {{type}}* __restrict__ tenFlowgrad
                    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                        const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
                        const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
                        const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
                        const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
                        assert(SIZE_1(tenFlow) == 2);
                        {{type}} fltFlowgrad = 0.0f;
                        {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                        {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
                        if (isfinite(fltX) == false) { return; }
                        if (isfinite(fltY) == false) { return; }
                        int intNorthwestX = (int) (floor(fltX));
                        int intNorthwestY = (int) (floor(fltY));
                        int intNortheastX = intNorthwestX + 1;
                        int intNortheastY = intNorthwestY;
                        int intSouthwestX = intNorthwestX;
                        int intSouthwestY = intNorthwestY + 1;
                        int intSoutheastX = intNorthwestX + 1;
                        int intSoutheastY = intNorthwestY + 1;
                        {{type}} fltNorthwest = 0.0f;
                        {{type}} fltNortheast = 0.0f;
                        {{type}} fltSouthwest = 0.0f;
                        {{type}} fltSoutheast = 0.0f;
                        if (intC == 0) {
                            fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                            fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                            fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                            fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
                        } else if (intC == 1) {
                            fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                            fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                            fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                            fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
                        }
                        for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                            {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
                            if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                            }
                            if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                            }
                            if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                            }
                            if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                                fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                            }
                        }
                        tenFlowgrad[intIndex] = fltFlowgrad;
                    } }
                    """,
                    {
                        "tenIn": tenIn,
                        "tenFlow": tenFlow,
                        "tenOutgrad": tenOutgrad,
                        "tenIngrad": tenIngrad,
                        "tenFlowgrad": tenFlowgrad,
                    },
                )
            )(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(tenFlowgrad.nelement()),
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    None,  # tenIngrad pointer unused by this kernel
                    tenFlowgrad.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )
        # end
        return tenIngrad, tenFlowgrad
    # end
# end

View File

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from .layers import Sine, Damping
def convert_int_to_list(size, len_list=2):
    """Broadcast *size* to a list of length *len_list*.

    A bare int is repeated len_list times; any other sequence is validated
    to already contain exactly len_list entries and returned unchanged.
    """
    if not isinstance(size, int):
        # Caller passed an explicit per-dimension sequence; just sanity-check it.
        assert len(size) == len_list
        return size
    return [size] * len_list
def initialize_params(params, init_type, **kwargs):
    """Initialize a 2-D parameter tensor in place.

    Args:
        params: tensor of shape (fan_in, fan_out) to fill in place.
        init_type: one of None/"normal", "kaiming_uniform", "uniform_fan_in",
            "zero", or "siren".
        **kwargs: for "siren", requires ``siren_w0`` (frequency scale) and
            ``is_first`` (whether this is the network's first layer).

    Raises:
        NotImplementedError: for any unrecognized init_type.
    """
    # fan_in is the leading dimension; the unused fan_out was dropped.
    fan_in = params.shape[0]
    if init_type is None or init_type == "normal":
        nn.init.normal_(params)
    elif init_type == "kaiming_uniform":
        nn.init.kaiming_uniform_(params, a=math.sqrt(5))
    elif init_type == "uniform_fan_in":
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(params, -bound, bound)
    elif init_type == "zero":
        nn.init.zeros_(params)
    elif init_type == "siren":
        assert "siren_w0" in kwargs and "is_first" in kwargs
        w0 = kwargs["siren_w0"]
        # SIREN scheme: first layer uses U(-1/fan_in, 1/fan_in); deeper
        # layers use U(-sqrt(6/fan_in)/w0, sqrt(6/fan_in)/w0).
        if kwargs["is_first"]:
            w_std = 1 / fan_in
        else:
            w_std = math.sqrt(6.0 / fan_in) / w0
        nn.init.uniform_(params, -w_std, w_std)
    else:
        raise NotImplementedError
def create_params_with_init(
    shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
):
    """Allocate and initialize a 2-D parameter tensor of the given shape.

    When ``include_bias`` is True the final row of ``shape`` is treated as a
    bias row: the first shape[0]-1 rows are filled with ``init_type``, the
    last row with ``bias_init_type``, and the two are concatenated.  Extra
    keyword arguments (e.g. siren_w0, is_first) are forwarded to
    initialize_params.
    """
    if include_bias:
        # Weight rows and the single bias row get independent init schemes.
        weight = torch.empty([shape[0] - 1, shape[1]])
        bias = torch.empty([1, shape[1]])
        initialize_params(weight, init_type, **kwargs)
        initialize_params(bias, bias_init_type, **kwargs)
        return torch.cat([weight, bias], dim=0)

    weight = torch.empty([shape[0], shape[1]])
    initialize_params(weight, init_type, **kwargs)
    return weight
def create_activation(config):
    """Build the activation module described by ``config.type``.

    Supported types: "relu", "siren" (Sine with config.siren_w0), "silu",
    and "damp" (Damping with config.siren_w0).  Anything else raises
    NotImplementedError.
    """
    kind = config.type
    if kind == "relu":
        return nn.ReLU()
    if kind == "siren":
        return Sine(config.siren_w0)
    if kind == "silu":
        return nn.SiLU()
    if kind == "damp":
        return Damping(config.siren_w0)
    raise NotImplementedError