Files
ComfyUI-Tween/gimm_vfi_arch/generalizable_INR/modules/hyponet.py
Ethanfel d642255e70 Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 13:11:45 +01:00

199 lines
7.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from ..configs import HypoNetConfig
from .utils import create_params_with_init, create_activation
class HypoNet(nn.Module):
    r"""A coordinate-based MLP (hyponetwork) whose weights can be modulated.

    Each layer is stored as a single parameter tensor ``linear_wb{i}`` of
    shape ``(fan_in + 1 if use_bias else fan_in, fan_out)``: when
    ``use_bias`` is enabled, the last row acts as the bias and the layer
    input is augmented with a column of ones before the matmul.  An external
    modulator may supply per-sample multiplicative factors for the weight
    part of each layer through ``modulation_params_dict`` in :meth:`forward`.
    """

    def __init__(self, config: HypoNetConfig, add_coord_dim=32):
        """
        Args:
            config: hyponet hyper-parameters (layer count, hidden widths,
                activation, initialization, bias/normalization switches).
            add_coord_dim: number of extra channels concatenated to the
                coordinates (the per-pixel latent supplied to ``forward``).
        """
        super().__init__()
        self.config = config
        self.use_bias = config.use_bias
        self.init_config = config.initialization
        self.num_layer = config.n_layer
        self.hidden_dims = config.hidden_dim
        self.add_coord_dim = add_coord_dim

        # A single configured width is broadcast to every hidden layer
        # (the output layer is excluded, hence n_layer - 1 entries).
        if len(self.hidden_dims) == 1:
            self.hidden_dims = OmegaConf.to_object(self.hidden_dims) * (
                self.num_layer - 1
            )  # exclude output layer
        else:
            assert len(self.hidden_dims) == self.num_layer - 1

        # SIREN activations require their dedicated initialization scheme.
        if self.config.activation.type == "siren":
            assert self.init_config.weight_init_type == "siren"
            assert self.init_config.bias_init_type == "siren"

        # After computing the shapes of the trainable parameters, build them.
        self.params_dict = None
        self.params_shape_dict = self.compute_params_shape()
        self.activation = create_activation(self.config.activation)
        self.build_base_params_dict(self.config.initialization)

        self.output_bias = config.output_bias
        self.normalize_weight = config.normalize_weight
        # Per-layer switch: when True the base weight is replaced by 1.0 so
        # the modulation alone defines that layer's weight.
        self.ignore_base_param_dict = {name: False for name in self.params_dict}

    @staticmethod
    def subsample_coords(coords, subcoord_idx=None):
        """Select a per-sample subset of flattened coordinates.

        Args:
            coords: (B, ..., D) coordinate tensor; flattened to (B, N, D).
            subcoord_idx: optional length-B sequence of per-sample indices
                into the N axis; ``None`` returns ``coords`` unchanged.

        Returns:
            (B, M, D) tensor of the selected coordinates.
        """
        if subcoord_idx is None:
            return coords
        batch_size = coords.shape[0]
        sub_coords = []
        coords = coords.view(batch_size, -1, coords.shape[-1])
        for idx in range(batch_size):
            # Slice with idx:idx+1 to keep the batch dim so cat yields (B, M, D).
            sub_coords.append(coords[idx : idx + 1, subcoord_idx[idx]])
        sub_coords = torch.cat(sub_coords, dim=0)
        return sub_coords

    def forward(self, coord, modulation_params_dict=None, pixel_latent=None):
        """Evaluate the modulated MLP at the given coordinates.

        Args:
            coord: (B, *spatial, input_dim) coordinates, or a tuple
                ``(coord, sub_idx)`` where ``sub_idx`` selects a subset of
                the flattened coordinates per sample.
            modulation_params_dict: optional mapping ``linear_wb{i}`` ->
                per-sample multiplicative modulation of that layer's weight
                part; layers without an entry use identity (all-ones).
            pixel_latent: required (B, H', W', C) latent map, bilinearly
                resized to the coordinate grid and concatenated to coords.
                NOTE(review): coord is assumed 5-D (B, T, H, W, D) here --
                coord_shape[1] and coord_shape[2] are taken as the (H, W)
                resize target; confirm against callers.

        Returns:
            (B, *spatial, output_dim) outputs, or (B, M, output_dim) when
            ``sub_idx`` is given (no reshape back to the grid).
        """
        sub_idx = None
        if isinstance(coord, tuple):
            coord, sub_idx = coord[0], coord[1]

        if modulation_params_dict is not None:
            self.check_valid_param_keys(modulation_params_dict)

        batch_size, coord_shape, input_dim = (
            coord.shape[0],
            coord.shape[1:-1],
            coord.shape[-1],
        )
        coord = coord.view(batch_size, -1, input_dim)  # flatten the coordinates

        assert pixel_latent is not None
        # Resize the latent map to the coordinate grid (channels-last in/out).
        pixel_latent = F.interpolate(
            pixel_latent.permute(0, 3, 1, 2),
            size=(coord_shape[1], coord_shape[2]),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        pixel_latent_dim = pixel_latent.shape[-1]
        pixel_latent = pixel_latent.view(batch_size, -1, pixel_latent_dim)

        hidden = coord
        hidden = torch.cat([pixel_latent, hidden], dim=-1)
        hidden = self.subsample_coords(hidden, sub_idx)

        for idx in range(self.config.n_layer):
            param_key = f"linear_wb{idx}"
            # Broadcast the shared base weight across the batch.
            base_param = einops.repeat(
                self.params_dict[param_key], "n m -> b n m", b=batch_size
            )

            if (modulation_params_dict is not None) and (
                param_key in modulation_params_dict.keys()
            ):
                modulation_param = modulation_params_dict[param_key]
            else:
                # Identity modulation; the bias row (if any) is never modulated.
                if self.config.use_bias:
                    modulation_param = torch.ones_like(base_param[:, :-1])
                else:
                    modulation_param = torch.ones_like(base_param)

            if self.config.use_bias:
                # Augment the input with ones so the last weight row acts as
                # the bias.  dtype must match hidden's, otherwise torch.cat
                # fails under mixed precision (original omitted dtype).
                ones = torch.ones(
                    *hidden.shape[:-1], 1, device=hidden.device, dtype=hidden.dtype
                )
                hidden = torch.cat([hidden, ones], dim=-1)

                base_param_w, base_param_b = (
                    base_param[:, :-1, :],
                    base_param[:, -1:, :],
                )

                if self.ignore_base_param_dict[param_key]:
                    base_param_w = 1.0
                param_w = base_param_w * modulation_param
                if self.normalize_weight:
                    # L2-normalize each output neuron's weight column.
                    param_w = F.normalize(param_w, dim=1)
                modulated_param = torch.cat([param_w, base_param_b], dim=1)
            else:
                if self.ignore_base_param_dict[param_key]:
                    base_param = 1.0
                if self.normalize_weight:
                    modulated_param = F.normalize(base_param * modulation_param, dim=1)
                else:
                    modulated_param = base_param * modulation_param

            hidden = torch.bmm(hidden, modulated_param)
            # No activation after the final (output) layer.
            if idx < (self.config.n_layer - 1):
                hidden = self.activation(hidden)

        outputs = hidden + self.output_bias
        if sub_idx is None:
            outputs = outputs.view(batch_size, *coord_shape, -1)
        return outputs

    def compute_params_shape(self):
        """Compute the shape of every MLP parameter tensor.

        Returns:
            dict mapping ``linear_wb{i}`` to ``(fan_in, fan_out)``, where
            ``fan_in`` already includes the +1 bias row when ``use_bias`` is
            on.  Consumed by :meth:`build_base_params_dict`.
        """
        config = self.config
        use_bias = self.use_bias

        param_shape_dict = dict()

        # First layer input = coordinates plus the concatenated pixel latent.
        fan_in = config.input_dim
        add_dim = self.add_coord_dim
        fan_in = fan_in + add_dim
        fan_in = fan_in + 1 if use_bias else fan_in

        for i in range(config.n_layer - 1):
            fan_out = self.hidden_dims[i]
            param_shape_dict[f"linear_wb{i}"] = (fan_in, fan_out)
            fan_in = fan_out + 1 if use_bias else fan_out

        param_shape_dict[f"linear_wb{config.n_layer-1}"] = (fan_in, config.output_dim)
        return param_shape_dict

    def build_base_params_dict(self, init_config):
        """Create and register the base (unmodulated) MLP parameters."""
        assert self.params_shape_dict
        params_dict = nn.ParameterDict()
        for idx, (name, shape) in enumerate(self.params_shape_dict.items()):
            is_first = idx == 0  # first layer gets special init under SIREN
            params = create_params_with_init(
                shape,
                init_type=init_config.weight_init_type,
                include_bias=self.use_bias,
                bias_init_type=init_config.bias_init_type,
                is_first=is_first,
                siren_w0=self.config.activation.siren_w0,  # valid only for siren
            )
            params = nn.Parameter(params)
            params_dict[name] = params
        self.set_params_dict(params_dict)

    def check_valid_param_keys(self, params_dict):
        """Raise ``KeyError`` for any key not matching a known layer name."""
        predefined_params_keys = self.params_shape_dict.keys()
        for param_key in params_dict.keys():
            if param_key not in predefined_params_keys:
                # Include the offending key so the failure is diagnosable
                # (previously a bare KeyError carried no information).
                raise KeyError(param_key)

    def set_params_dict(self, params_dict):
        """Validate and install ``params_dict`` as the base parameters."""
        self.check_valid_param_keys(params_dict)
        self.params_dict = params_dict