Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
508
gimm_vfi_arch/generalizable_INR/gimmvfi_r.py
Normal file
508
gimm_vfi_arch/generalizable_INR/gimmvfi_r.py
Normal file
@@ -0,0 +1,508 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# amt: https://github.com/MCG-NKU/AMT
|
||||
# motif: https://github.com/sichun233746/MoTIF
|
||||
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .configs import GIMMVFIConfig
|
||||
from .modules.coord_sampler import CoordSampler3D
|
||||
from .modules.hyponet import HypoNet
|
||||
from .modules.fi_components import *
|
||||
from .modules.fi_utils import (
|
||||
normalize_flow,
|
||||
unnormalize_flow,
|
||||
warp,
|
||||
resize,
|
||||
build_coord,
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .raft.corr import BidirCorrBlock
|
||||
from .modules.softsplat import softsplat
|
||||
|
||||
|
||||
class GIMMVFI_R(nn.Module):
    """GIMM-VFI frame interpolator with a RAFT optical-flow front end.

    Pipeline (see ``forward``): bidirectional RAFT flow between the two
    input frames -> GIMM implicit motion model (``hyponet``) predicts the
    motion field at arbitrary timesteps t -> AMT-style coarse-to-fine
    decoding synthesizes each intermediate frame.
    """

    # Config class used by the loader to construct this model.
    Config = GIMMVFIConfig

    def __init__(self, dtype, config: GIMMVFIConfig):
        """
        Args:
            dtype: torch dtype the external flow estimator is run in
                (inputs are cast via ``.to(self.dtype)``).
            config: GIMMVFIConfig providing ``hyponet``, ``coord_range``
                and ``fwarp_type``.
        """
        super().__init__()
        self.config = config = config.copy()
        self.hyponet_config = config.hyponet
        # Default number of RAFT refinement iterations (per-call override
        # possible through forward(iters=...)).
        self.raft_iter = 20

        ######### Encoder and Decoder Settings #########
        # NOTE(review): the RAFT estimator is attached externally by the
        # model loader — ``self.flow_estimator`` must exist before
        # forward() is called. The in-place initializer stays disabled:
        # self.flow_estimator = initialize_RAFT()
        cur_f_dims = [128, 96]  # channel dims of the incoming RAFT features
        f_dims = [256, 128]  # projected dims consumed by the AMT decoders
        self.dtype = dtype

        skip_channels = f_dims[-1] // 2
        # Number of candidate flows predicted per pixel by the final decoder.
        self.num_flows = 3

        # 1x1 projections of the RAFT context / correlation features.
        self.amt_last_cproj = nn.Conv2d(cur_f_dims[0], f_dims[0], 1)
        self.amt_second_last_cproj = nn.Conv2d(cur_f_dims[1], f_dims[1], 1)
        self.amt_fproj = nn.Conv2d(f_dims[0], f_dims[0], 1)
        self.amt_init_decoder = NewInitDecoder(f_dims[0], skip_channels)
        self.amt_final_decoder = NewMultiFlowDecoder(f_dims[1], skip_channels)

        # Flow/feature refinement blocks used at 1/4 scale (the first one
        # upsamples its correlation input by 2.0, the second does not).
        self.amt_update4_low = self._get_updateblock(f_dims[0] // 2, 2.0)
        self.amt_update4_high = self._get_updateblock(f_dims[0] // 2, None)

        # Fuses the num_flows warped RGB candidates into one output frame.
        self.amt_comb_block = nn.Sequential(
            nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3),
            nn.PReLU(6 * self.num_flows),
            nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3),
        )

        ################ GIMM settings #################
        self.coord_sampler = CoordSampler3D(config.coord_range)

        # Fixed 3x3 smoothing kernel (weights sum to 1), applied per
        # channel via conv3d in cal_splatting_weights to estimate local
        # flow variance. Registered as a non-trainable Parameter so it
        # moves with the module across devices/dtypes.
        self.g_filter = torch.nn.Parameter(
            torch.FloatTensor(
                [
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                    [1.0 / 8.0, 1.0 / 4.0, 1.0 / 8.0],
                    [1.0 / 16.0, 1.0 / 8.0, 1.0 / 16.0],
                ]
            ).reshape(1, 1, 1, 3, 3),
            requires_grad=False,
        )
        self.fwarp_type = config.fwarp_type

        # Learnable scales for the variance / flow-consistency splatting
        # metrics (see cal_splatting_weights).
        self.alpha_v = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)
        self.alpha_fe = torch.nn.Parameter(torch.FloatTensor([1]), requires_grad=True)

        # Encodes a 2-channel flow map into a 16-channel pixel latent.
        channel = 32
        in_dim = 2
        self.cnn_encoder = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            LateralBlock(channel),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )
        # Residual fusion of [latent_0 (16) | latent_1 (16) | splatted (32)]
        # -> 32 channels, matching hyponet's add_coord_dim below.
        channel = 64
        in_dim = 64
        self.res_conv = nn.Sequential(
            nn.Conv2d(in_dim, channel // 2, 3, 1, 1, bias=True, groups=1),
            nn.Conv2d(channel // 2, channel, 3, 1, 1, bias=True, groups=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            LateralBlock(channel),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(
                channel, channel // 2, 3, 1, 1, padding_mode="reflect", bias=True
            ),
        )

        # Implicit neural motion model conditioned on the 32-ch pixel latent.
        self.hyponet = HypoNet(config.hyponet, add_coord_dim=32)
|
||||
|
||||
def _get_updateblock(self, cdim, scale_factor=None):
    """Build a BasicUpdateBlock flow-refinement module.

    Args:
        cdim: context-feature channel count fed to the block.
        scale_factor: optional upsampling factor applied inside the block
            (``None`` disables it).

    Returns:
        A configured ``BasicUpdateBlock`` with the fixed GIMM-VFI widths.
    """
    block_config = dict(
        cdim=cdim,
        hidden_dim=192,
        flow_dim=64,
        corr_dim=256,
        corr_dim2=192,
        fc_dim=188,
        scale_factor=scale_factor,
        corr_levels=4,
        radius=4,
    )
    return BasicUpdateBlock(**block_config)
|
||||
|
||||
def cal_bidirection_flow(self, im0, im1, iters=20):
    """Run RAFT in both directions and prepare inputs for the GIMM stage.

    Fix: the original hard-coded ``iters=20`` inside both estimator calls,
    silently ignoring the ``iters`` argument — so ``forward(iters=...)``
    and ``self.raft_iter`` had no effect. Both calls now honor ``iters``.

    Args:
        im0, im1: the two input frames (scaled to [0, 255] by the caller).
        iters: number of RAFT refinement iterations.

    Returns:
        Tuple of:
          - normalized flows for the INR (backward flow negated so both
            directions point along +t, then scaled by ``normalize_flow``),
          - original flows stacked on dim 2,
          - the flow scalers needed to unnormalize predictions,
          - projected context features for frame 0 and frame 1,
          - bidirectional correlation lookup object,
          - the raw RAFT flows stacked on dim 2 (preserved for output).
    """
    f01, features0, fnet0 = self.flow_estimator(
        im0.to(self.dtype), im1.to(self.dtype), return_feat=True, iters=iters
    )
    f10, features1, fnet1 = self.flow_estimator(
        im1.to(self.dtype), im0.to(self.dtype), return_feat=True, iters=iters
    )
    corr_fn = BidirCorrBlock(self.amt_fproj(fnet0), self.amt_fproj(fnet1), radius=4)
    # Project RAFT context features to the AMT decoder widths.
    features0 = [
        self.amt_second_last_cproj(features0[0]),
        self.amt_last_cproj(features0[1]),
    ]
    features1 = [
        self.amt_second_last_cproj(features1[0]),
        self.amt_last_cproj(features1[1]),
    ]
    flow01 = f01.unsqueeze(2)
    flow10 = f10.unsqueeze(2)
    # Sign convention: negate the backward flow so both flows share a
    # common temporal direction before normalization.
    normal_flows = torch.cat([flow01, -flow10], dim=2)
    normal_flows, flow_scalers = normalize_flow(normal_flows)

    ori_flows = torch.cat([flow01, flow10], dim=2)
    return (
        normal_flows,
        ori_flows,
        flow_scalers,
        features0,
        features1,
        corr_fn,
        torch.cat([f01.unsqueeze(2), f10.unsqueeze(2)], dim=2),
    )
|
||||
|
||||
def predict_flow(self, f, coord, t, flows):
    """Predict normalized motion fields at each requested timestep.

    Args:
        f: normalized bidirectional flows with the frame pair on dim 2
            (indexed as ``f[:, :, 0]`` / ``f[:, :, 1]``), from
            ``cal_bidirection_flow``.
        coord: list of coordinate tuples ``(coords, subsample_or_None)``,
            one per timestep.
        t: list of timestep tensors aligned with ``coord``.
        flows: unnormalized RAFT flows (frame pair on dim 2) used to
            forward-splat the pixel latents; detached so splatting does
            not backprop into the flow estimator.

    Returns:
        List of per-timestep hyponet outputs (normalized flows).
    """
    raft_flow01 = flows[:, :, 0].detach()
    raft_flow10 = flows[:, :, 1].detach()

    # calculate splatting metrics
    weights1, weights2 = self.cal_splatting_weights(raft_flow01, raft_flow10)
    strtype = self.fwarp_type + "-zeroeps"

    # Encode each frame's normalized flow into a pixel latent (b,c,h,w).
    pixel_latent_0 = self.cnn_encoder(f[:, :, 0])
    pixel_latent_1 = self.cnn_encoder(f[:, :, 1])
    pixel_latent = []

    # Forward-splat both latents to each target time t, then fuse them
    # with a residual conv over [latent_0 | latent_1 | splatted pair].
    for i, cur_t in enumerate(t):
        cur_t = cur_t.reshape(-1, 1, 1, 1)

        tmp_pixel_latent_0 = softsplat(
            tenIn=pixel_latent_0,
            tenFlow=raft_flow01 * cur_t,
            tenMetric=weights1,
            strMode=strtype,
        )
        tmp_pixel_latent_1 = softsplat(
            tenIn=pixel_latent_1,
            tenFlow=raft_flow10 * (1 - cur_t),
            tenMetric=weights2,
            strMode=strtype,
        )

        tmp_pixel_latent = torch.cat(
            [tmp_pixel_latent_0, tmp_pixel_latent_1], dim=1
        )
        tmp_pixel_latent = tmp_pixel_latent + self.res_conv(
            torch.cat([pixel_latent_0, pixel_latent_1, tmp_pixel_latent], dim=1)
        )
        # Channels-last layout expected by the hyponet.
        pixel_latent.append(tmp_pixel_latent.permute(0, 2, 3, 1))

    all_outputs = []
    permute_idx_range = [i for i in range(1, f.ndim - 1)]
    for idx, c in enumerate(coord):
        # Sanity check: the coordinate grid's t value matches t[idx].
        assert c[0][0, 0, 0, 0, 0] == t[idx][0].squeeze()
        assert isinstance(c, tuple)

        if c[1] is None:
            # Full grid: move the channel axis back to position 1.
            outputs = self.hyponet(
                c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
            ).permute(0, -1, *permute_idx_range)
        else:
            # Subsampled coords (training supervision): keep as-is.
            outputs = self.hyponet(
                c, modulation_params_dict=None, pixel_latent=pixel_latent[idx]
            )
        all_outputs.append(outputs)

    return all_outputs
|
||||
|
||||
def warp_w_mask(self, img0, img1, ft0, ft1, mask, scale=1):
    """Upsample flows and mask by ``scale``, backward-warp both frames,
    and blend them with the sigmoid of the mask.

    Args:
        img0, img1: the two frames at full resolution.
        ft0, ft1: t->0 / t->1 flows at 1/scale resolution.
        mask: blending logits at 1/scale resolution.
        scale: spatial upsampling factor (flow values are scaled too).

    Returns:
        The blended warped frame ``m * warp(img0) + (1-m) * warp(img1)``.
    """
    flow0_up = scale * resize(ft0, scale_factor=scale)
    flow1_up = scale * resize(ft1, scale_factor=scale)
    blend = resize(mask, scale_factor=scale).sigmoid()
    warped0 = warp(img0, flow0_up)
    warped1 = warp(img1, flow1_up)
    return blend * warped0 + (1 - blend) * warped1
|
||||
|
||||
@torch.compiler.disable()
def frame_synthesize(
    self, img_xs, flow_t, features0, features1, corr_fn, cur_t, full_img=None
):
    """Synthesize the intermediate frame at time ``cur_t``.

    AMT-style coarse-to-fine decoding: initialize flows/features at 1/4
    scale, refine twice with correlation lookups, then decode multi-flow
    candidates at full scale and fuse them into one RGB frame.

    Args:
        img_xs: frame pair in [0, 1], layout (b, c, t=2, h, w).
        flow_t: predicted motion field at time t — b,2,h,w.
        cur_t: timestep — b,1,1,1.
        features0, features1: projected context feature lists (coarsest last).
        corr_fn: bidirectional correlation lookup from cal_bidirection_flow.
        full_img: optional full-resolution frame pair; when given (the
            ``ds_factor`` path), flows/mask/residual are upsampled so the
            output matches its size.

    Returns:
        (imgt_pred, [flowt0 at 1/1 and 1/4], [flowt1 at 1/1 and 1/4],
         [1/4-scale warped preview image]).
    """
    batch_size = img_xs.shape[0]  # b,c,t,h,w
    # Map inputs from [0, 1] to [-1, 1] for the decoders.
    img0 = 2 * img_xs[:, :, 0] - 1.0
    img1 = 2 * img_xs[:, :, 1] - 1.0

    ##################### update the predicted flow #####################
    ## initialize coordinates for correlation lookup
    lookup_coord = build_coord(img_xs[:, :, 0]).to(
        img_xs[:, :, 0].device
    )  # H//8,W//8

    # Split the single motion field into t->0 and t->1 flows
    # (linear-motion assumption).
    flow_t0_fullsize = flow_t * (-cur_t)
    flow_t1_fullsize = flow_t * (1.0 - cur_t)

    inv = 1 / 4
    flow_t0_inr4 = inv * resize(flow_t0_fullsize, inv)
    flow_t1_inr4 = inv * resize(flow_t1_fullsize, inv)

    ############################# scale 1/4 #############################
    # i. Initialize feature t at scale 1/4
    flowt0_4, flowt1_4, ft_4_ = self.amt_init_decoder(
        features0[-1],
        features1[-1],
        flow_t0_inr4,
        flow_t1_inr4,
        img0=img0,
        img1=img1,
    )
    # Consume the coarsest feature level.
    features0, features1 = features0[:-1], features1[:-1]

    # First channel is the blending mask, the rest is the feature map.
    mask_4_, ft_4_ = ft_4_[:, :1], ft_4_[:, 1:]
    img_warp_4 = self.warp_w_mask(img0, img1, flowt0_4, flowt1_4, mask_4_, scale=4)
    # Back to [0, 1] for the preview output.
    img_warp_4 = (img_warp_4 + 1.0) / 2
    img_warp_4 = torch.clamp(img_warp_4, 0, 1)

    # ii. residue update with correlation looked up at 1/8 scale
    corr_4, flow_4_lr = self._amt_corr_scale_lookup(
        corr_fn, lookup_coord, flowt0_4, flowt1_4, cur_t, downsample=2
    )

    delta_ft_4_, delta_flow_4 = self.amt_update4_low(ft_4_, flow_4_lr, corr_4)
    delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
    flowt0_4 = flowt0_4 + delta_flow0_4
    flowt1_4 = flowt1_4 + delta_flow1_4
    ft_4_ = ft_4_ + delta_ft_4_

    # iii. residue update with lookup corr upsampled back to 1/4
    corr_4 = resize(corr_4, scale_factor=2.0)

    flow_4 = torch.cat([flowt0_4, flowt1_4], dim=1)
    delta_ft_4_, delta_flow_4 = self.amt_update4_high(ft_4_, flow_4, corr_4)
    flowt0_4 = flowt0_4 + delta_flow_4[:, :2]
    flowt1_4 = flowt1_4 + delta_flow_4[:, 2:4]
    ft_4_ = ft_4_ + delta_ft_4_

    ############################# scale 1/1 #############################
    # Decode num_flows candidate flow pairs, mask and RGB residual.
    flowt0_1, flowt1_1, mask, img_res = self.amt_final_decoder(
        ft_4_,
        features0[0],
        features1[0],
        flowt0_4,
        flowt1_4,
        mask=mask_4_,
        img0=img0,
        img1=img1,
    )

    if full_img is not None:
        # ds_factor path: rescale everything to the original resolution.
        img0 = 2 * full_img[:, :, 0] - 1.0
        img1 = 2 * full_img[:, :, 1] - 1.0
        inv = img1.shape[2] / flowt0_1.shape[2]
        flowt0_1 = inv * resize(flowt0_1, scale_factor=inv)
        flowt1_1 = inv * resize(flowt1_1, scale_factor=inv)
        flow_t0_fullsize = inv * resize(flow_t0_fullsize, scale_factor=inv)
        flow_t1_fullsize = inv * resize(flow_t1_fullsize, scale_factor=inv)
        mask = resize(mask, scale_factor=inv)
        img_res = resize(img_res, scale_factor=inv)

    # Fuse the multi-flow warped candidates into the final frame.
    imgt_pred = multi_flow_combine(
        self.amt_comb_block, img0, img1, flowt0_1, flowt1_1, mask, img_res, None
    )
    imgt_pred = torch.clamp(imgt_pred, 0, 1)

    ######################################################################

    # Expose the per-candidate flows as (b, num_flows, 2, h, w).
    flowt0_1 = flowt0_1.reshape(
        batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
    )
    flowt1_1 = flowt1_1.reshape(
        batch_size, self.num_flows, 2, img0.shape[-2], img0.shape[-1]
    )

    flowt0_pred = [flowt0_1, flowt0_4]
    flowt1_pred = [flowt1_1, flowt1_4]
    other_pred = [img_warp_4]
    return imgt_pred, flowt0_pred, flowt1_pred, other_pred
|
||||
|
||||
def forward(self, img_xs, coord=None, t=None, iters=None, ds_factor=None):
    """Interpolate frames at every timestep in ``t`` in a single pass.

    Args:
        img_xs: frame pair in [0, 1], layout (b, c, t=2, h, w).
        coord: list of coordinate tuples from ``sample_coord_input``, one
            per timestep. When the first tuple carries a subsample index
            (``coord[0][1] is not None``, training), the first two entries
            are reconstruction-supervision coords and are excluded from
            frame synthesis.
        t: list of timestep tensors aligned with ``coord``.
        iters: RAFT iterations; defaults to ``self.raft_iter``.
        ds_factor: optional downscale factor applied to the inputs before
            flow estimation; the original pair is kept and handed to
            ``frame_synthesize`` so outputs come back at full resolution.

    Returns:
        Dict of per-timestep lists (``imgt_pred``, ``other_pred``,
        ``flowt0_pred``, ``flowt1_pred``) plus flow diagnostics
        (``raft_flow``, ``ninrflow``, ``nflow``, ``flowt``).
    """
    assert isinstance(t, list)
    assert isinstance(coord, list)
    assert len(t) == len(coord)
    full_size_img = None
    if ds_factor is not None:
        # Keep the full-resolution pair; estimate flow on a smaller copy.
        full_size_img = img_xs.clone()
        img_xs = torch.cat(
            [
                resize(img_xs[:, :, 0], scale_factor=ds_factor).unsqueeze(2),
                resize(img_xs[:, :, 1], scale_factor=ds_factor).unsqueeze(2),
            ],
            dim=2,
        )

    iters = self.raft_iter if iters is None else iters
    # RAFT expects [0, 255] input, hence the 255 * scaling.
    (
        normal_flows,
        flows,
        flow_scalers,
        features0,
        features1,
        corr_fn,
        preserved_raft_flows,
    ) = self.cal_bidirection_flow(
        255 * img_xs[:, :, 0], 255 * img_xs[:, :, 1], iters=iters
    )
    assert coord is not None

    # List of predicted normalized flows, one entry per coord/timestep.
    normal_inr_flows = self.predict_flow(normal_flows, coord, t, flows)

    ############ Unnormalize the predicted/reconstructed flow ############
    start_idx = 0
    if coord[0][1] is not None:
        # Subsampled flows for reconstruction supervision in the GIMM module.
        # In that case, by convention, the first two coords are the
        # subsampled supervision coords and are skipped for synthesis.
        # (Normalized flow is defined along the positive t-axis.)
        assert len(coord) > 2
        flow_t = [
            unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
            for i in range(2, len(coord))
        ]
        start_idx = 2
    else:
        flow_t = [
            unnormalize_flow(normal_inr_flows[i], flow_scalers).squeeze()
            for i in range(len(coord))
        ]

    imgt_preds, flowt0_preds, flowt1_preds, all_others = [], [], [], []

    for idx in range(start_idx, len(coord)):
        # flow_t is offset by start_idx relative to coord/t.
        cur_flow_t = flow_t[idx - start_idx]
        cur_t = t[idx].reshape(-1, 1, 1, 1)
        # squeeze() above can drop a singleton batch dim; restore it.
        if cur_flow_t.ndim != 4:
            cur_flow_t = cur_flow_t.unsqueeze(0)
        assert cur_flow_t.ndim == 4

        imgt_pred, flowt0_pred, flowt1_pred, others = self.frame_synthesize(
            img_xs,
            cur_flow_t,
            features0,
            features1,
            corr_fn,
            cur_t,
            full_img=full_size_img,
        )

        imgt_preds.append(imgt_pred)
        flowt0_preds.append(flowt0_pred)
        flowt1_preds.append(flowt1_pred)
        all_others.append(others)

    return {
        "imgt_pred": imgt_preds,
        "other_pred": all_others,
        "flowt0_pred": flowt0_preds,
        "flowt1_pred": flowt1_preds,
        "raft_flow": preserved_raft_flows,
        "ninrflow": normal_inr_flows,
        "nflow": normal_flows,
        "flowt": flow_t,
    }
|
||||
|
||||
def warp_frame(self, frame, flow):
    """Backward-warp ``frame`` by ``flow`` (thin wrapper over fi_utils.warp)."""
    return warp(frame, flow)
|
||||
|
||||
def compute_psnr(self, preds, targets, reduction="mean"):
    """Compute PSNR between ``preds`` and ``targets`` (assumed in [0, 1],
    so the peak value is 1 and PSNR is just ``-10 * log10(MSE)``).

    Args:
        preds, targets: tensors of identical shape, batch-first.
        reduction: "mean" | "sum" | "none" — how to reduce the per-sample
            PSNR values over the batch.

    Returns:
        A scalar for "mean"/"sum", or a (batch,) tensor for "none".
    """
    assert reduction in ["mean", "sum", "none"]
    num_samples = preds.shape[0]
    # MSE per sample, flattening all non-batch dims.
    per_sample_mse = ((preds - targets) ** 2).reshape(num_samples, -1).mean(dim=-1)
    per_sample_psnr = -10 * torch.log10(per_sample_mse)

    if reduction == "mean":
        return per_sample_psnr.mean()
    if reduction == "sum":
        return per_sample_psnr.sum()
    return per_sample_psnr
|
||||
|
||||
def sample_coord_input(
    self,
    batch_size,
    s_shape,
    t_ids,
    coord_range=None,
    upsample_ratio=1.0,
    device=None,
):
    """Sample (t, y, x) coordinate inputs via the model's CoordSampler3D.

    Args:
        batch_size: number of coordinate grids to produce.
        s_shape: spatial shape of the grid.
        t_ids: timestep ids for the temporal axis.
        coord_range: must be ``None`` — the range was fixed at
            construction time from ``config.coord_range``.
        upsample_ratio: spatial upsampling ratio for the grid.
        device: target device (required).

    Returns:
        Coordinate inputs as produced by ``self.coord_sampler``.
    """
    assert device is not None
    assert coord_range is None
    return self.coord_sampler(
        batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
    )
|
||||
|
||||
def cal_splatting_weights(self, raft_flow01, raft_flow10):
    """Compute per-pixel confidence metrics for softmax splatting.

    Two cues are combined for each flow direction:
      * local flow spread — std of the flow under the fixed 3x3
        ``g_filter`` smoothing (high spread = unreliable motion);
      * forward-backward consistency error — a flow backward-warped by
        the opposite flow and negated should reproduce the original.
    Each cue c is mapped through 1 / (1 + alpha * c), so larger cues give
    smaller splatting weights; the two terms are summed per direction.

    Returns:
        (weights1, weights2): metrics for splatting with flow01 / flow10.
    """
    batch_size = raft_flow01.shape[0]
    # Process both directions in a single batch of size 2b.
    raft_flows = torch.cat([raft_flow01, raft_flow10], dim=0)

    ## flow variance metric
    # Stack [f^2, f] on the channel axis, reflect-pad spatially, and run
    # one conv3d with the (1,3,3) smoothing kernel over all 4 channels at
    # once. Splitting in two recovers the smoothed f^2 and the smoothed f.
    # NOTE(review): the names are swapped typos — `sqaure_mean` holds
    # E[f^2] (mean of squares) and `mean_square` holds E[f].
    sqaure_mean, mean_square = torch.split(
        F.conv3d(
            F.pad(
                torch.cat([raft_flows**2, raft_flows], 1),
                (1, 1, 1, 1),
                mode="reflect",
            ).unsqueeze(1),
            self.g_filter,
        ).squeeze(1),
        2,
        dim=1,
    )
    # Per-pixel std = sqrt(E[f^2] - E[f]^2), clamped for numerical
    # safety, then averaged over the two flow channels.
    var = (
        (sqaure_mean - mean_square**2)
        .clamp(1e-9, None)
        .sqrt()
        .mean(1)
        .unsqueeze(1)
    )
    var01 = var[:batch_size]
    var10 = var[batch_size:]

    ## flow warp metric (forward-backward consistency)
    f01_warp = -warp(raft_flow10, raft_flow01)
    f10_warp = -warp(raft_flow01, raft_flow10)
    err01 = (
        torch.nn.functional.l1_loss(
            input=f01_warp, target=raft_flow01, reduction="none"
        )
        .mean(1)
        .unsqueeze(1)
    )
    # NOTE(review): `err02` is a misnomer — it is the 1->0 error (err10).
    err02 = (
        torch.nn.functional.l1_loss(
            input=f10_warp, target=raft_flow10, reduction="none"
        )
        .mean(1)
        .unsqueeze(1)
    )

    weights1 = 1 / (1 + err01 * self.alpha_fe) + 1 / (1 + var01 * self.alpha_v)
    weights2 = 1 / (1 + err02 * self.alpha_fe) + 1 / (1 + var10 * self.alpha_v)

    return weights1, weights2
|
||||
|
||||
def _amt_corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
    """Look up bidirectional correlation volumes at flow-displaced coords.

    The t->0 and t->1 flows are rescaled to full 0->1 / 1->0 flows under
    a linear-motion assumption (divide by embt and 1-embt respectively)
    before indexing the correlation pyramid.

    Args:
        corr_fn: BidirCorrBlock built from the two frames' features.
        coord: base lookup coordinates at the correlation resolution.
        flow0, flow1: t->0 and t->1 flows.
        embt: timestep t, broadcastable against the flows.
        downsample: extra spatial downsampling applied to the flows
            before the lookup (1 = none).

    Returns:
        (corr, flow): concatenated correlation features and the
        (possibly downsampled) concatenated flows, both on dim 1.
    """
    # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
    scale_t0 = 1.0 / embt
    scale_t1 = 1.0 / (1.0 - embt)

    if downsample != 1:
        shrink = 1 / downsample
        flow0 = shrink * resize(flow0, scale_factor=shrink)
        flow1 = shrink * resize(flow1, scale_factor=shrink)

    corr0, corr1 = corr_fn(coord + flow1 * scale_t1, coord + flow0 * scale_t0)
    return torch.cat([corr0, corr1], dim=1), torch.cat([flow0, flow1], dim=1)
|
||||
Reference in New Issue
Block a user