Files
ComfyUI-Tween/gimm_vfi_arch/generalizable_INR/modules/utils.py
Ethanfel d642255e70 Add GIMM-VFI support (NeurIPS 2024) with single-pass arbitrary-timestep interpolation
Integrates GIMM-VFI alongside existing BIM/EMA/SGM models. Key feature: generates
all intermediate frames in one forward pass (no recursive 2x passes needed for 4x/8x).

- Vendor gimm_vfi_arch/ from kijai/ComfyUI-GIMM-VFI with device fixes
- Two variants: RAFT-based (~80MB) and FlowFormer-based (~123MB)
- Auto-download checkpoints from HuggingFace (Kijai/GIMM-VFI_safetensors)
- Three new nodes: Load GIMM-VFI Model, GIMM-VFI Interpolate, GIMM-VFI Segment Interpolate
- single_pass toggle: True=arbitrary timestep (default), False=recursive like other models
- ds_factor parameter for high-res input downscaling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 13:11:45 +01:00

77 lines
2.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ginr-ipc: https://github.com/kakaobrain/ginr-ipc
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
from .layers import Sine, Damping
def convert_int_to_list(size, len_list=2):
    """Broadcast an int to a list of ``len_list`` copies, or pass a sequence through.

    Args:
        size: an ``int``, or a sized sequence whose length must equal ``len_list``.
        len_list: number of entries the result should have.

    Returns:
        ``len_list`` copies of ``size`` in a new list when ``size`` is an int;
        otherwise ``size`` itself, unchanged (it is not copied or converted).

    Raises:
        AssertionError: ``size`` is not an int and its length differs from
            ``len_list`` (only when assertions are enabled).
    """
    if not isinstance(size, int):
        assert len(size) == len_list
        return size
    return [size for _ in range(len_list)]
def initialize_params(params, init_type, **kwargs):
    """Initialize ``params`` in place according to ``init_type``.

    ``params.shape[0]`` is used as the fan-in for the ``"uniform_fan_in"`` and
    ``"siren"`` schemes.

    Args:
        params: tensor to initialize; modified in place.
        init_type: one of ``None`` / ``"normal"`` (standard normal),
            ``"kaiming_uniform"``, ``"uniform_fan_in"`` (uniform in
            ``±1/sqrt(fan_in)``), ``"zero"``, or ``"siren"``.
        **kwargs: for ``"siren"``, must contain ``siren_w0`` and ``is_first``.
            When ``is_first`` is truthy the bound is ``1 / fan_in``; otherwise
            it is ``sqrt(6 / fan_in) / siren_w0``.

    Raises:
        AssertionError: ``"siren"`` requested without ``siren_w0``/``is_first``
            (only when assertions are enabled).
        NotImplementedError: ``init_type`` is not a supported scheme.
    """
    fan_in = params.shape[0]
    if init_type is None or init_type == "normal":
        nn.init.normal_(params)
    elif init_type == "kaiming_uniform":
        # NOTE(review): nn.init.kaiming_uniform_ computes its own fan-in from
        # dim 1 (PyTorch's (out, in) weight convention), not from the
        # shape[0]-based fan_in used by the other schemes here — confirm intended.
        nn.init.kaiming_uniform_(params, a=math.sqrt(5))
    elif init_type == "uniform_fan_in":
        # Guard against division by zero for empty tensors.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(params, -bound, bound)
    elif init_type == "zero":
        nn.init.zeros_(params)
    elif init_type == "siren":
        assert "siren_w0" in kwargs and "is_first" in kwargs, (
            "siren init requires 'siren_w0' and 'is_first' keyword arguments"
        )
        w0 = kwargs["siren_w0"]
        if kwargs["is_first"]:
            w_std = 1 / fan_in
        else:
            w_std = math.sqrt(6.0 / fan_in) / w0
        nn.init.uniform_(params, -w_std, w_std)
    else:
        raise NotImplementedError(f"Unsupported init_type: {init_type!r}")
def create_params_with_init(
    shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
):
    """Allocate a new ``(shape[0], shape[1])`` tensor and initialize it.

    When ``include_bias`` is truthy, the last row is reserved for a bias: the
    first ``shape[0] - 1`` rows are initialized with ``init_type``, a separate
    single row is initialized with ``bias_init_type``, and the two are
    concatenated along dim 0, so the result still has ``shape[0]`` rows.

    Args:
        shape: sequence whose first two entries give (rows, columns).
        init_type: scheme passed to ``initialize_params`` for the weight rows.
        include_bias: whether to append a separately-initialized bias row.
        bias_init_type: scheme passed to ``initialize_params`` for the bias row.
        **kwargs: forwarded unchanged to every ``initialize_params`` call.

    Returns:
        The initialized tensor of shape ``(shape[0], shape[1])``.
    """
    n_rows, n_cols = shape[0], shape[1]
    if include_bias:
        weight = torch.empty([n_rows - 1, n_cols])
        bias_row = torch.empty([1, n_cols])
        initialize_params(weight, init_type, **kwargs)
        initialize_params(bias_row, bias_init_type, **kwargs)
        return torch.cat([weight, bias_row], dim=0)

    weight = torch.empty([n_rows, n_cols])
    initialize_params(weight, init_type, **kwargs)
    return weight
def create_activation(config):
    """Construct the activation selected by ``config.type``.

    Supported values: ``"relu"`` (``nn.ReLU``), ``"siren"`` (``Sine``),
    ``"silu"`` (``nn.SiLU``) and ``"damp"`` (``Damping``). Both ``Sine`` and
    ``Damping`` are constructed with ``config.siren_w0``; that attribute is read
    only for those two types.

    Args:
        config: object exposing ``type`` and, for ``"siren"``/``"damp"``,
            ``siren_w0``.

    Returns:
        A newly constructed activation object.

    Raises:
        NotImplementedError: ``config.type`` is not a supported value.
    """
    act_type = config.type
    if act_type == "relu":
        return nn.ReLU()
    if act_type == "siren":
        return Sine(config.siren_w0)
    if act_type == "silu":
        return nn.SiLU()
    if act_type == "damp":
        return Damping(config.siren_w0)
    # Name the rejected value so misconfigured models are easy to diagnose.
    raise NotImplementedError(f"Unsupported activation type: {act_type!r}")