Integrate EMA-VFI alongside the existing BIM-VFI with three new ComfyUI nodes: Load EMA-VFI Model, EMA-VFI Interpolate, and EMA-VFI Segment Interpolate. Architecture files are vendored from MCG-NJU/EMA-VFI with device-awareness fixes (hardcoded .cuda() calls removed), warp cache management, and relative imports. InputPadder is extended to support EMA-VFI's replicate center-symmetric padding. The timm dependency is auto-installed on first load.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
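As context for the InputPadder change: EMA-VFI expects inputs whose spatial dimensions are multiples of a fixed divisor, padded center-symmetrically with edge replication. A minimal sketch of that behavior follows; the EMAInputPadder name and the divisor default are illustrative assumptions, not the repository's actual API.

import torch.nn.functional as F

class EMAInputPadder:
    """Sketch: pad H and W up to the next multiple of `divisor` (hypothetical)."""
    def __init__(self, dims, divisor=32):
        h, w = dims[-2:]
        pad_h = (divisor - h % divisor) % divisor
        pad_w = (divisor - w % divisor) % divisor
        # Center-symmetric: split the padding evenly between the two sides.
        self._pad = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]

    def pad(self, *inputs):
        # 'replicate' repeats border pixels instead of zero-filling.
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        h, w = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:h - bottom, left:w - right]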
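"""Vendored from MCG-NJU/EMA-VFI, with device-awareness fixes: tensors are
created on the inputs' device instead of via hardcoded .cuda() calls, and
imports are package-relative."""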
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp
from .refine import Unet

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

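# Head predicts, at one pyramid scale, a residual bidirectional flow
# (4 channels) and a fusion mask (1 channel) from the upsampled inter-frame
# motion features, the input images, and the previous stage's flow estimate.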
class Head(nn.Module):
    def __init__(self, in_planes, scale, c, in_else=17):
        super().__init__()
        self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
        self.scale = scale
        self.conv = nn.Sequential(
            conv(in_planes * 2 // (4 * 4) + in_else, c),
            conv(c, c),
            conv(c, 5),
        )

    def forward(self, motion_feature, x, flow):  # motion_feature at /16, /8, /4
        motion_feature = self.upsample(motion_feature)  # -> /4, /2, /1
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=4. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            if self.scale != 4:
                flow = F.interpolate(flow, scale_factor=4. / self.scale, mode="bilinear",
                                     align_corners=False) * 4. / self.scale
            x = torch.cat((x, flow), 1)
        x = self.conv(torch.cat([motion_feature, x], 1))
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=self.scale // 4, mode="bilinear", align_corners=False)
            flow = x[:, :4] * (self.scale // 4)
        else:
            flow = x[:, :4]
        mask = x[:, 4:5]
        return flow, mask

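# MultiScaleFlow is EMA-VFI's coarse-to-fine interpolation network: a shared
# backbone extracts appearance and motion features for both frames, a stack of
# Head blocks refines bidirectional flow and a fusion mask across scales, and
# a Unet adds a final residual on top of the warped-and-blended frames.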
class MultiScaleFlow(nn.Module):
    def __init__(self, backbone, **kwargs):
        super().__init__()
        self.flow_num_stage = len(kwargs['hidden_dims'])
        self.feature_bone = backbone
        # One Head per stage; the first stage sees only the two input images
        # (6 channels), later stages also receive the warped images and the
        # current mask (17 channels).
        self.block = nn.ModuleList([
            Head(kwargs['motion_dims'][-1 - i] * kwargs['depths'][-1 - i] + kwargs['embed_dims'][-1 - i],
                 kwargs['scales'][-1 - i],
                 kwargs['hidden_dims'][-1 - i],
                 6 if i == 0 else 17)
            for i in range(self.flow_num_stage)
        ])
        self.unet = Unet(kwargs['c'] * 2)

    def warp_features(self, xs, flow):
        y0 = []
        y1 = []
        B = xs[0].size(0) // 2
        for x in xs:
            y0.append(warp(x[:B], flow[:, 0:2]))
            y1.append(warp(x[B:], flow[:, 2:4]))
            # Halve the flow for the next, coarser feature level.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                                 align_corners=False, recompute_scale_factor=False) * 0.5
        return y0, y1

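    # Coarse-to-fine flow estimation: each stage predicts a residual update to
    # the bidirectional flow and the fusion mask, conditioned on the timestep.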
    def calculate_flow(self, imgs, timestep, af=None, mf=None):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        # appearance features & motion features
        if (af is None) or (mf is None):
            af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            # Timestep map on the same device as the inputs (no hardcoded .cuda()).
            t = torch.full(mf[-1 - i][:B].shape, timestep, dtype=torch.float, device=imgs.device)
            if flow is not None:
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
                flow_, mask_ = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow
                )
                flow = flow + flow_
                mask = mask + mask_
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None
                )
        return flow, mask

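    # Blend the two warped frames with the final mask and let the Unet add a
    # learned residual (the method name's spelling is kept from upstream).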
    def coraseWarp_and_Refine(self, imgs, af, flow, mask):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred

    # Equivalent to running 'calculate_flow' followed by 'coraseWarp_and_Refine'.
    def forward(self, x, timestep=0.5):
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        # appearance features & motion features
        af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1 - i][:B].shape, timestep, dtype=torch.float, device=x.device)
            if flow is not None:
                flow_d, mask_d = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow
                )
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None
                )
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))

        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        pred = torch.clamp(merged[-1] + res, 0, 1)
        return flow_list, mask_list, merged, pred
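# A minimal, hypothetical usage sketch (assumes `model` is a MultiScaleFlow
# built with the EMA-VFI config, and `img0`/`img1` are [B, 3, H, W] tensors
# in [0, 1] on the same device as the model):
#
#     with torch.no_grad():
#         flow_list, mask_list, merged, pred = model(torch.cat((img0, img1), 1), timestep=0.5)
#     # `pred` is the interpolated frame at the given timestep.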