Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x). - Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes - Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB) - Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors) - Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate - single_pass toggle: True=arbitrary timestep (default), False=recursive like other models - ds_factor parameter for high-res input downscaling Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
199 lines
7.2 KiB
Python
199 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------
|
|
# References:
|
|
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
|
|
# --------------------------------------------------------
|
|
|
|
import einops
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from omegaconf import OmegaConf
|
|
|
|
from ..configs import HypoNetConfig
|
|
from .utils import create_params_with_init, create_activation
|
|
|
|
|
|
class HypoNet(nn.Module):
|
|
r"""
|
|
The Hyponetwork with a coordinate-based MLP to be modulated.
|
|
"""
|
|
|
|
def __init__(self, config: HypoNetConfig, add_coord_dim=32):
|
|
super().__init__()
|
|
self.config = config
|
|
self.use_bias = config.use_bias
|
|
self.init_config = config.initialization
|
|
self.num_layer = config.n_layer
|
|
self.hidden_dims = config.hidden_dim
|
|
self.add_coord_dim = add_coord_dim
|
|
|
|
if len(self.hidden_dims) == 1:
|
|
self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
|
|
self.num_layer - 1
|
|
) # exclude output layer
|
|
else:
|
|
assert len(self.hidden_dims) == self.num_layer - 1
|
|
|
|
if self.config.activation.type == "siren":
|
|
assert self.init_config.weight_init_type == "siren"
|
|
assert self.init_config.bias_init_type == "siren"
|
|
|
|
# after computes the shape of trainable parameters, initialize them
|
|
self.params_dict = None
|
|
self.params_shape_dict = self.compute_params_shape()
|
|
self.activation = create_activation(self.config.activation)
|
|
self.build_base_params_dict(self.config.initialization)
|
|
self.output_bias = config.output_bias
|
|
|
|
self.normalize_weight = config.normalize_weight
|
|
|
|
self.ignore_base_param_dict = {name: False for name in self.params_dict}
|
|
|
|
@staticmethod
|
|
def subsample_coords(coords, subcoord_idx=None):
|
|
if subcoord_idx is None:
|
|
return coords
|
|
|
|
batch_size = coords.shape[0]
|
|
sub_coords = []
|
|
coords = coords.view(batch_size, -1, coords.shape[-1])
|
|
for idx in range(batch_size):
|
|
sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
|
|
sub_coords = torch.cat(sub_coords, dim=0)
|
|
return sub_coords
|
|
|
|
def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
|
|
sub_idx = None
|
|
if isinstance(coord, tuple):
|
|
coord, sub_idx = coord[0], coord[1]
|
|
|
|
if modulation_params_dict is not None:
|
|
self.check_valid_param_keys(modulation_params_dict)
|
|
|
|
batch_size, coord_shape, input_dim = (
|
|
coord.shape[0],
|
|
coord.shape[1:-1],
|
|
coord.shape[-1],
|
|
)
|
|
coord = coord.view(batch_size, -1, input_dim) # flatten the coordinates
|
|
assert pixel_latent is not None
|
|
pixel_latent = F.interpolate(
|
|
pixel_latent.permute(0, 3, 1, 2),
|
|
size=(coord_shape[1], coord_shape[2]),
|
|
mode="bilinear",
|
|
).permute(0, 2, 3, 1)
|
|
pixel_latent_dim = pixel_latent.shape[-1]
|
|
pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)
|
|
hidden = coord
|
|
|
|
hidden = torch.cat([pixel_latent, hidden], dim=-1)
|
|
|
|
hidden = self.subsample_coords(hidden, sub_idx)
|
|
|
|
for idx in range(self.config.n_layer):
|
|
param_key = f"linear_wb{idx}"
|
|
base_param = einops.repeat(
|
|
self.params_dict[param_key], "n m -> b n m", b=batch_size
|
|
)
|
|
|
|
if (modulation_params_dict is not None) and (
|
|
param_key in modulation_params_dict.keys()
|
|
):
|
|
modulation_param = modulation_params_dict[param_key]
|
|
else:
|
|
if self.config.use_bias:
|
|
modulation_param = torch.ones_like(base_param[:, :-1])
|
|
else:
|
|
modulation_param = torch.ones_like(base_param)
|
|
|
|
if self.config.use_bias:
|
|
ones = torch.ones(*hidden.shape[:-1], 1, device=hidden.device)
|
|
hidden = torch.cat([hidden, ones], dim=-1)
|
|
|
|
base_param_w, base_param_b = (
|
|
base_param[:, :-1, :],
|
|
base_param[:, -1:, :],
|
|
)
|
|
|
|
if self.ignore_base_param_dict[param_key]:
|
|
base_param_w = 1.0
|
|
param_w = base_param_w * modulation_param
|
|
if self.normalize_weight:
|
|
param_w = F.normalize(param_w, dim=1)
|
|
modulated_param = torch.cat([param_w, base_param_b], dim=1)
|
|
else:
|
|
if self.ignore_base_param_dict[param_key]:
|
|
base_param = 1.0
|
|
if self.normalize_weight:
|
|
modulated_param = F.normalize(base_param * modulation_param, dim=1)
|
|
else:
|
|
modulated_param = base_param * modulation_param
|
|
# print([param_key,hidden.shape,modulated_param.shape])
|
|
hidden = torch.bmm(hidden, modulated_param)
|
|
|
|
if idx < (self.config.n_layer - 1):
|
|
hidden = self.activation(hidden)
|
|
|
|
outputs = hidden + self.output_bias
|
|
if sub_idx is None:
|
|
outputs = outputs.view(batch_size, *coord_shape, -1)
|
|
return outputs
|
|
|
|
def compute_params_shape(self):
|
|
"""
|
|
Computes the shape of MLP parameters.
|
|
The computed shapes are used to build the initial weights by `build_base_params_dict`.
|
|
"""
|
|
config = self.config
|
|
use_bias = self.use_bias
|
|
|
|
param_shape_dict = dict()
|
|
|
|
fan_in = config.input_dim
|
|
add_dim = self.add_coord_dim
|
|
fan_in = fan_in + add_dim
|
|
fan_in = fan_in + 1 if use_bias else fan_in
|
|
|
|
for i in range(config.n_layer - 1):
|
|
fan_out = self.hidden_dims[i]
|
|
param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
|
|
fan_in = fan_out + 1 if use_bias else fan_out
|
|
|
|
param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
|
|
return param_shape_dict
|
|
|
|
def build_base_params_dict(self, init_config):
|
|
assert self.params_shape_dict
|
|
params_dict = nn.ParameterDict()
|
|
for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
|
|
is_first = idx == 0
|
|
params = create_params_with_init(
|
|
shape,
|
|
init_type=init_config.weight_init_type,
|
|
include_bias=self.use_bias,
|
|
bias_init_type=init_config.bias_init_type,
|
|
is_first=is_first,
|
|
siren_w0=self.config.activation.siren_w0, # valid only for siren
|
|
)
|
|
params = nn.Parameter(params)
|
|
params_dict[name] = params
|
|
self.set_params_dict(params_dict)
|
|
|
|
def check_valid_param_keys(self, params_dict):
|
|
predefined_params_keys = self.params_shape_dict.keys()
|
|
for param_key in params_dict.keys():
|
|
if param_key in predefined_params_keys:
|
|
continue
|
|
else:
|
|
raise KeyError
|
|
|
|
def set_params_dict(self, params_dict):
|
|
self.check_valid_param_keys(params_dict)
|
|
self.params_dict = params_dict
|