- Use generate_draft_block_mask_refined for sparse attention mask (matches naxci1's generate_draft_block_mask_sage with proper half-block key scoring) - Remove spurious repeat_interleave(2, dim=-1) from generate_draft_block_mask that doubled the key dimension incorrectly - Add torch.clamp(0, 1) to _to_frames output (matches naxci1's tensor2video) - Add .to(self.device) on LQ video slices in streaming loop for all pipelines Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
872 lines
32 KiB
Python
872 lines
32 KiB
Python
import logging
|
|
import math
|
|
import os
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .bim_vfi_arch import BiMVFI
|
|
from .ema_vfi_arch import feature_extractor as ema_feature_extractor
|
|
from .ema_vfi_arch import MultiScaleFlow as EMAMultiScaleFlow
|
|
from .sgm_vfi_arch import feature_extractor as sgm_feature_extractor
|
|
from .sgm_vfi_arch import MultiScaleFlow as SGMMultiScaleFlow
|
|
from .utils.padder import InputPadder
|
|
|
|
logger = logging.getLogger("Tween")
|
|
|
|
|
|
class BiMVFIModel:
|
|
"""Clean inference wrapper around BiMVFI for ComfyUI integration."""
|
|
|
|
def __init__(self, checkpoint_path, pyr_level=3, auto_pyr_level=True, device="cpu"):
|
|
self.pyr_level = pyr_level
|
|
self.auto_pyr_level = auto_pyr_level
|
|
self.device = device
|
|
|
|
self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
|
|
self._load_checkpoint(checkpoint_path)
|
|
self.model.eval()
|
|
self.model.to(device)
|
|
|
|
def _load_checkpoint(self, checkpoint_path):
|
|
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
|
|
|
# Handle different checkpoint formats
|
|
if "model" in checkpoint:
|
|
state_dict = checkpoint["model"]
|
|
elif "state_dict" in checkpoint:
|
|
state_dict = checkpoint["state_dict"]
|
|
else:
|
|
state_dict = checkpoint
|
|
|
|
# Strip common prefixes (e.g. "module." from DDP or "model." from wrapper)
|
|
cleaned = {}
|
|
for k, v in state_dict.items():
|
|
key = k
|
|
if key.startswith("module."):
|
|
key = key[len("module."):]
|
|
if key.startswith("model."):
|
|
key = key[len("model."):]
|
|
cleaned[key] = v
|
|
|
|
self.model.load_state_dict(cleaned)
|
|
|
|
def to(self, device):
|
|
self.device = device
|
|
self.model.to(device)
|
|
return self
|
|
|
|
def _get_pyr_level(self, h):
|
|
if self.auto_pyr_level:
|
|
if h >= 2160:
|
|
return 7
|
|
elif h >= 1080:
|
|
return 6
|
|
elif h >= 540:
|
|
return 5
|
|
else:
|
|
return 3
|
|
return self.pyr_level
|
|
|
|
@torch.no_grad()
|
|
def interpolate_pair(self, frame0, frame1, time_step=0.5):
|
|
"""Interpolate a single frame between two input frames.
|
|
|
|
Args:
|
|
frame0: [1, C, H, W] tensor, float32, range [0, 1]
|
|
frame1: [1, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1), temporal position of interpolated frame
|
|
|
|
Returns:
|
|
Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frame0.to(device)
|
|
img1 = frame1.to(device)
|
|
|
|
pyr_level = self._get_pyr_level(img0.shape[2])
|
|
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
|
|
|
result_dict = self.model(
|
|
img0=img0, img1=img1,
|
|
time_step=time_step_tensor,
|
|
pyr_level=pyr_level,
|
|
)
|
|
|
|
interp = result_dict["imgt_pred"]
|
|
interp = torch.clamp(interp, 0, 1)
|
|
return interp
|
|
|
|
@torch.no_grad()
|
|
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
|
"""Interpolate multiple frame pairs at once.
|
|
|
|
Args:
|
|
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
|
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1), temporal position of interpolated frames
|
|
|
|
Returns:
|
|
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frames0.to(device)
|
|
img1 = frames1.to(device)
|
|
|
|
pyr_level = self._get_pyr_level(img0.shape[2])
|
|
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
|
|
|
result_dict = self.model(
|
|
img0=img0, img1=img1,
|
|
time_step=time_step_tensor,
|
|
pyr_level=pyr_level,
|
|
)
|
|
|
|
interp = result_dict["imgt_pred"]
|
|
interp = torch.clamp(interp, 0, 1)
|
|
return interp
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EMA-VFI model wrapper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _ema_init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]):
|
|
"""Build EMA-VFI model config dicts (backbone + multiscale)."""
|
|
return {
|
|
'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
|
|
'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
|
|
'num_heads': [8*F//32, 16*F//32],
|
|
'mlp_ratios': [4, 4],
|
|
'qkv_bias': True,
|
|
'norm_layer': partial(nn.LayerNorm, eps=1e-6),
|
|
'depths': depth,
|
|
'window_sizes': [W, W]
|
|
}, {
|
|
'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
|
|
'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
|
|
'depths': depth,
|
|
'num_heads': [8*F//32, 16*F//32],
|
|
'window_sizes': [W, W],
|
|
'scales': [4, 8, 16],
|
|
'hidden_dims': [4*F, 4*F],
|
|
'c': F
|
|
}
|
|
|
|
|
|
def _ema_detect_variant(filename):
|
|
"""Auto-detect model variant and timestep support from filename.
|
|
|
|
Returns (F, depth, supports_arbitrary_t).
|
|
"""
|
|
name = filename.lower()
|
|
is_small = "small" in name
|
|
supports_t = "_t." in name or "_t_" in name or name.endswith("_t")
|
|
|
|
if is_small:
|
|
return 16, [2, 2, 2, 2, 2], supports_t
|
|
else:
|
|
return 32, [2, 2, 2, 4, 4], supports_t
|
|
|
|
|
|
class EMAVFIModel:
|
|
"""Clean inference wrapper around EMA-VFI for ComfyUI integration."""
|
|
|
|
def __init__(self, checkpoint_path, variant="auto", tta=False, device="cpu"):
|
|
import os
|
|
filename = os.path.basename(checkpoint_path)
|
|
|
|
if variant == "auto":
|
|
F_dim, depth, self.supports_arbitrary_t = _ema_detect_variant(filename)
|
|
elif variant == "small":
|
|
F_dim, depth = 16, [2, 2, 2, 2, 2]
|
|
self.supports_arbitrary_t = "_t." in filename.lower() or "_t_" in filename.lower()
|
|
else: # large
|
|
F_dim, depth = 32, [2, 2, 2, 4, 4]
|
|
self.supports_arbitrary_t = "_t." in filename.lower() or "_t_" in filename.lower()
|
|
|
|
self.tta = tta
|
|
self.device = device
|
|
self.variant_name = "small" if F_dim == 16 else "large"
|
|
|
|
backbone_cfg, multiscale_cfg = _ema_init_model_config(F=F_dim, depth=depth)
|
|
backbone = ema_feature_extractor(**backbone_cfg)
|
|
self.model = EMAMultiScaleFlow(backbone, **multiscale_cfg)
|
|
self._load_checkpoint(checkpoint_path)
|
|
self.model.eval()
|
|
self.model.to(device)
|
|
|
|
def _load_checkpoint(self, checkpoint_path):
|
|
"""Load checkpoint with module prefix stripping and buffer filtering."""
|
|
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
|
|
|
# Handle wrapped checkpoint formats
|
|
if isinstance(state_dict, dict):
|
|
if "model" in state_dict:
|
|
state_dict = state_dict["model"]
|
|
elif "state_dict" in state_dict:
|
|
state_dict = state_dict["state_dict"]
|
|
|
|
# Strip "module." prefix and filter out attn_mask/HW buffers
|
|
cleaned = {}
|
|
for k, v in state_dict.items():
|
|
if "attn_mask" in k or k.endswith(".HW"):
|
|
continue
|
|
key = k
|
|
if key.startswith("module."):
|
|
key = key[len("module."):]
|
|
cleaned[key] = v
|
|
|
|
self.model.load_state_dict(cleaned)
|
|
|
|
def to(self, device):
|
|
"""Move model to device (returns self for chaining)."""
|
|
self.device = device
|
|
self.model.to(device)
|
|
return self
|
|
|
|
@torch.no_grad()
|
|
def _inference(self, img0, img1, timestep=0.5):
|
|
"""Run single inference pass. Inputs already padded, on device."""
|
|
B = img0.shape[0]
|
|
imgs = torch.cat((img0, img1), 1)
|
|
|
|
if self.tta:
|
|
imgs_ = imgs.flip(2).flip(3)
|
|
input_batch = torch.cat((imgs, imgs_), 0)
|
|
_, _, _, preds = self.model(input_batch, timestep=timestep)
|
|
return (preds[:B] + preds[B:].flip(2).flip(3)) / 2.
|
|
else:
|
|
_, _, _, pred = self.model(imgs, timestep=timestep)
|
|
return pred
|
|
|
|
@torch.no_grad()
|
|
def interpolate_pair(self, frame0, frame1, time_step=0.5):
|
|
"""Interpolate a single frame between two input frames.
|
|
|
|
Args:
|
|
frame0: [1, C, H, W] tensor, float32, range [0, 1]
|
|
frame1: [1, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1)
|
|
|
|
Returns:
|
|
Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frame0.to(device)
|
|
img1 = frame1.to(device)
|
|
|
|
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
|
|
img0, img1 = padder.pad(img0, img1)
|
|
|
|
pred = self._inference(img0, img1, timestep=time_step)
|
|
pred = padder.unpad(pred)
|
|
return torch.clamp(pred, 0, 1)
|
|
|
|
@torch.no_grad()
|
|
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
|
"""Interpolate multiple frame pairs at once.
|
|
|
|
Args:
|
|
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
|
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1)
|
|
|
|
Returns:
|
|
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frames0.to(device)
|
|
img1 = frames1.to(device)
|
|
|
|
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
|
|
img0, img1 = padder.pad(img0, img1)
|
|
|
|
pred = self._inference(img0, img1, timestep=time_step)
|
|
pred = padder.unpad(pred)
|
|
return torch.clamp(pred, 0, 1)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SGM-VFI model wrapper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _sgm_init_model_config(F=16, W=7, depth=[2, 2, 2, 4], num_key_points=0.5):
|
|
"""Build SGM-VFI model config dicts (backbone + multiscale)."""
|
|
return {
|
|
'embed_dims': [F, 2*F, 4*F, 8*F],
|
|
'num_heads': [8*F//32],
|
|
'mlp_ratios': [4],
|
|
'qkv_bias': True,
|
|
'norm_layer': partial(nn.LayerNorm, eps=1e-6),
|
|
'depths': depth,
|
|
'window_sizes': [W]
|
|
}, {
|
|
'embed_dims': [F, 2*F, 4*F, 8*F],
|
|
'motion_dims': [0, 0, 0, 8*F//depth[-1]],
|
|
'depths': depth,
|
|
'scales': [8],
|
|
'hidden_dims': [4*F],
|
|
'c': F,
|
|
'num_key_points': num_key_points,
|
|
}
|
|
|
|
|
|
def _sgm_detect_variant(filename):
|
|
"""Auto-detect SGM-VFI model variant from filename.
|
|
|
|
Returns (F, depth).
|
|
Default is small (F=16) since the primary checkpoint (ours-1-2-points)
|
|
is a small model. Only detect base when "base" is in the filename.
|
|
"""
|
|
name = filename.lower()
|
|
is_base = "base" in name
|
|
if is_base:
|
|
return 32, [2, 2, 2, 6]
|
|
else:
|
|
return 16, [2, 2, 2, 4]
|
|
|
|
|
|
class SGMVFIModel:
|
|
"""Clean inference wrapper around SGM-VFI for ComfyUI integration."""
|
|
|
|
def __init__(self, checkpoint_path, variant="auto", num_key_points=0.5, tta=False, device="cpu"):
|
|
import os
|
|
filename = os.path.basename(checkpoint_path)
|
|
|
|
if variant == "auto":
|
|
F_dim, depth = _sgm_detect_variant(filename)
|
|
elif variant == "small":
|
|
F_dim, depth = 16, [2, 2, 2, 4]
|
|
else: # base
|
|
F_dim, depth = 32, [2, 2, 2, 6]
|
|
|
|
self.tta = tta
|
|
self.device = device
|
|
self.variant_name = "small" if F_dim == 16 else "base"
|
|
|
|
backbone_cfg, multiscale_cfg = _sgm_init_model_config(
|
|
F=F_dim, depth=depth, num_key_points=num_key_points)
|
|
backbone = sgm_feature_extractor(**backbone_cfg)
|
|
self.model = SGMMultiScaleFlow(backbone, **multiscale_cfg)
|
|
self._load_checkpoint(checkpoint_path)
|
|
self.model.eval()
|
|
self.model.to(device)
|
|
|
|
def _load_checkpoint(self, checkpoint_path):
|
|
"""Load checkpoint with module prefix stripping and buffer filtering."""
|
|
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
|
|
|
# Handle wrapped checkpoint formats
|
|
if isinstance(state_dict, dict):
|
|
if "model" in state_dict:
|
|
state_dict = state_dict["model"]
|
|
elif "state_dict" in state_dict:
|
|
state_dict = state_dict["state_dict"]
|
|
|
|
# Strip "module." prefix and filter out attn_mask/HW buffers
|
|
cleaned = {}
|
|
for k, v in state_dict.items():
|
|
if "attn_mask" in k or k.endswith(".HW"):
|
|
continue
|
|
key = k
|
|
if key.startswith("module."):
|
|
key = key[len("module."):]
|
|
cleaned[key] = v
|
|
|
|
self.model.load_state_dict(cleaned, strict=False)
|
|
|
|
def to(self, device):
|
|
"""Move model to device (returns self for chaining)."""
|
|
self.device = device
|
|
self.model.to(device)
|
|
return self
|
|
|
|
@torch.no_grad()
|
|
def _inference(self, img0, img1, timestep=0.5):
|
|
"""Run single inference pass. Inputs already padded, on device."""
|
|
B = img0.shape[0]
|
|
imgs = torch.cat((img0, img1), 1)
|
|
|
|
if self.tta:
|
|
imgs_ = imgs.flip(2).flip(3)
|
|
input_batch = torch.cat((imgs, imgs_), 0)
|
|
_, _, _, preds, _ = self.model(input_batch, timestep=timestep)
|
|
return (preds[:B] + preds[B:].flip(2).flip(3)) / 2.
|
|
else:
|
|
_, _, _, pred, _ = self.model(imgs, timestep=timestep)
|
|
return pred
|
|
|
|
@torch.no_grad()
|
|
def interpolate_pair(self, frame0, frame1, time_step=0.5):
|
|
"""Interpolate a single frame between two input frames.
|
|
|
|
Args:
|
|
frame0: [1, C, H, W] tensor, float32, range [0, 1]
|
|
frame1: [1, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1)
|
|
|
|
Returns:
|
|
Interpolated frame as [1, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frame0.to(device)
|
|
img1 = frame1.to(device)
|
|
|
|
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
|
|
img0, img1 = padder.pad(img0, img1)
|
|
|
|
pred = self._inference(img0, img1, timestep=time_step)
|
|
pred = padder.unpad(pred)
|
|
return torch.clamp(pred, 0, 1)
|
|
|
|
@torch.no_grad()
|
|
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
|
"""Interpolate multiple frame pairs at once.
|
|
|
|
Args:
|
|
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
|
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1)
|
|
|
|
Returns:
|
|
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
img0 = frames0.to(device)
|
|
img1 = frames1.to(device)
|
|
|
|
padder = InputPadder(img0.shape, divisor=32, mode='replicate', center=True)
|
|
img0, img1 = padder.pad(img0, img1)
|
|
|
|
pred = self._inference(img0, img1, timestep=time_step)
|
|
pred = padder.unpad(pred)
|
|
return torch.clamp(pred, 0, 1)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GIMM-VFI model wrapper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class GIMMVFIModel:
|
|
"""Clean inference wrapper around GIMM-VFI for ComfyUI integration.
|
|
|
|
Supports two modes:
|
|
- interpolate_batch(): standard single-midpoint interface (compatible with
|
|
recursive _interpolate_frames machinery used by other models)
|
|
- interpolate_multi(): GIMM-VFI's unique single-pass mode, generates all
|
|
N-1 intermediate frames between each pair in one forward pass
|
|
"""
|
|
|
|
def __init__(self, checkpoint_path, flow_checkpoint_path, variant="auto",
|
|
ds_factor=1.0, device="cpu"):
|
|
import os
|
|
import yaml
|
|
from omegaconf import OmegaConf
|
|
from .gimm_vfi_arch import (
|
|
GIMMVFI_R, GIMMVFI_F, GIMMVFIConfig,
|
|
GIMM_RAFT, GIMM_FlowFormer, gimm_get_flowformer_cfg,
|
|
GIMMInputPadder, GIMMRaftArgs, easydict_to_dict,
|
|
)
|
|
import comfy.utils
|
|
|
|
self.ds_factor = ds_factor
|
|
self.device = device
|
|
self._InputPadder = GIMMInputPadder
|
|
|
|
filename = os.path.basename(checkpoint_path).lower()
|
|
|
|
# Detect variant from filename
|
|
if variant == "auto":
|
|
self.is_flowformer = "gimmvfi_f" in filename
|
|
else:
|
|
self.is_flowformer = (variant == "flowformer")
|
|
|
|
self.variant_name = "flowformer" if self.is_flowformer else "raft"
|
|
|
|
# Load config
|
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
if self.is_flowformer:
|
|
config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_f_arb.yaml")
|
|
else:
|
|
config_path = os.path.join(script_dir, "gimm_vfi_arch", "configs", "gimmvfi_r_arb.yaml")
|
|
|
|
with open(config_path) as f:
|
|
config = yaml.load(f, Loader=yaml.FullLoader)
|
|
config = easydict_to_dict(config)
|
|
config = OmegaConf.create(config)
|
|
arch_defaults = GIMMVFIConfig.create(config.arch)
|
|
config = OmegaConf.merge(arch_defaults, config.arch)
|
|
|
|
# Build model + flow estimator
|
|
dtype = torch.float32
|
|
|
|
if self.is_flowformer:
|
|
self.model = GIMMVFI_F(dtype, config)
|
|
cfg = gimm_get_flowformer_cfg()
|
|
flow_estimator = GIMM_FlowFormer(cfg.latentcostformer)
|
|
flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
|
|
flow_estimator.load_state_dict(flow_sd, strict=True)
|
|
else:
|
|
self.model = GIMMVFI_R(dtype, config)
|
|
raft_args = GIMMRaftArgs(small=False, mixed_precision=False, alternate_corr=False)
|
|
flow_estimator = GIMM_RAFT(raft_args)
|
|
flow_sd = comfy.utils.load_torch_file(flow_checkpoint_path)
|
|
flow_estimator.load_state_dict(flow_sd, strict=True)
|
|
|
|
# Load main model weights
|
|
sd = comfy.utils.load_torch_file(checkpoint_path)
|
|
self.model.load_state_dict(sd, strict=False)
|
|
|
|
self.model.flow_estimator = flow_estimator
|
|
self.model.eval()
|
|
|
|
def to(self, device):
|
|
"""Move model to device (returns self for chaining)."""
|
|
self.device = device if isinstance(device, str) else str(device)
|
|
self.model.to(device)
|
|
return self
|
|
|
|
@torch.no_grad()
|
|
def interpolate_batch(self, frames0, frames1, time_step=0.5):
|
|
"""Interpolate a single midpoint frame per pair (standard interface).
|
|
|
|
Args:
|
|
frames0: [B, C, H, W] tensor, float32, range [0, 1]
|
|
frames1: [B, C, H, W] tensor, float32, range [0, 1]
|
|
time_step: float in (0, 1)
|
|
|
|
Returns:
|
|
Interpolated frames as [B, C, H, W] tensor, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
results = []
|
|
|
|
for i in range(frames0.shape[0]):
|
|
I0 = frames0[i:i+1].to(device)
|
|
I2 = frames1[i:i+1].to(device)
|
|
|
|
padder = self._InputPadder(I0.shape, 32)
|
|
I0_p, I2_p = padder.pad(I0, I2)
|
|
|
|
xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
|
|
batch_size = xs.shape[0]
|
|
s_shape = xs.shape[-2:]
|
|
|
|
coord_inputs = [(
|
|
self.model.sample_coord_input(
|
|
batch_size, s_shape, [time_step],
|
|
device=xs.device, upsample_ratio=self.ds_factor,
|
|
),
|
|
None,
|
|
)]
|
|
timesteps = [
|
|
time_step * torch.ones(xs.shape[0]).to(xs.device)
|
|
]
|
|
|
|
all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
|
|
pred = padder.unpad(all_outputs["imgt_pred"][0])
|
|
results.append(torch.clamp(pred, 0, 1))
|
|
|
|
return torch.cat(results, dim=0)
|
|
|
|
@torch.no_grad()
|
|
def interpolate_multi(self, frame0, frame1, num_intermediates):
|
|
"""Generate all intermediate frames between a pair in one forward pass.
|
|
|
|
This is GIMM-VFI's unique capability -- arbitrary timestep interpolation
|
|
without recursive 2x passes.
|
|
|
|
Args:
|
|
frame0: [1, C, H, W] tensor, float32, range [0, 1]
|
|
frame1: [1, C, H, W] tensor, float32, range [0, 1]
|
|
num_intermediates: int, number of intermediate frames to generate
|
|
|
|
Returns:
|
|
List of [1, C, H, W] tensors, float32, clamped to [0, 1]
|
|
"""
|
|
device = next(self.model.parameters()).device
|
|
I0 = frame0.to(device)
|
|
I2 = frame1.to(device)
|
|
|
|
padder = self._InputPadder(I0.shape, 32)
|
|
I0_p, I2_p = padder.pad(I0, I2)
|
|
|
|
xs = torch.cat((I0_p.unsqueeze(2), I2_p.unsqueeze(2)), dim=2)
|
|
batch_size = xs.shape[0]
|
|
s_shape = xs.shape[-2:]
|
|
interp_factor = num_intermediates + 1
|
|
|
|
coord_inputs = [
|
|
(
|
|
self.model.sample_coord_input(
|
|
batch_size, s_shape,
|
|
[1.0 / interp_factor * i],
|
|
device=xs.device,
|
|
upsample_ratio=self.ds_factor,
|
|
),
|
|
None,
|
|
)
|
|
for i in range(1, interp_factor)
|
|
]
|
|
timesteps = [
|
|
i * 1.0 / interp_factor * torch.ones(xs.shape[0]).to(xs.device)
|
|
for i in range(1, interp_factor)
|
|
]
|
|
|
|
all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.ds_factor)
|
|
|
|
results = []
|
|
for pred in all_outputs["imgt_pred"]:
|
|
unpadded = padder.unpad(pred)
|
|
results.append(torch.clamp(unpadded, 0, 1))
|
|
|
|
return results
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FlashVSR model wrapper (4x video super-resolution)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class FlashVSRModel:
|
|
"""Inference wrapper for FlashVSR diffusion-based video super-resolution.
|
|
|
|
Supports three pipeline modes:
|
|
- full: Standard VAE decode, highest quality
|
|
- tiny: TCDecoder decode, faster
|
|
- tiny-long: Streaming TCDecoder decode, lowest VRAM for long videos
|
|
"""
|
|
|
|
# Minimum input frame count required by the pipeline
|
|
MIN_FRAMES = 21
|
|
|
|
def __init__(self, model_dir, mode="tiny", device="cuda:0", dtype=torch.bfloat16):
|
|
from safetensors.torch import load_file
|
|
from .flashvsr_arch import (
|
|
ModelManager, FlashVSRFullPipeline,
|
|
FlashVSRTinyPipeline, FlashVSRTinyLongPipeline,
|
|
)
|
|
from .flashvsr_arch.models.utils import Causal_LQ4x_Proj
|
|
from .flashvsr_arch.models.TCDecoder import build_tcdecoder
|
|
|
|
self.mode = mode
|
|
self.device = device
|
|
self.dtype = dtype
|
|
|
|
dit_path = os.path.join(model_dir, "FlashVSR1_1.safetensors")
|
|
vae_path = os.path.join(model_dir, "Wan2.1_VAE.safetensors")
|
|
lq_path = os.path.join(model_dir, "LQ_proj_in.safetensors")
|
|
tcd_path = os.path.join(model_dir, "TCDecoder.safetensors")
|
|
prompt_path = os.path.join(model_dir, "Prompt.safetensors")
|
|
|
|
mm = ModelManager(torch_dtype=dtype, device="cpu")
|
|
|
|
if mode == "full":
|
|
mm.load_models([dit_path, vae_path])
|
|
self.pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
|
|
self.pipe.vae.model.encoder = None
|
|
self.pipe.vae.model.conv1 = None
|
|
else:
|
|
mm.load_models([dit_path])
|
|
Pipeline = FlashVSRTinyLongPipeline if mode == "tiny-long" else FlashVSRTinyPipeline
|
|
self.pipe = Pipeline.from_model_manager(mm, device=device)
|
|
|
|
# TCDecoder for ALL modes (streaming per-chunk decode with LQ conditioning)
|
|
self.pipe.TCDecoder = build_tcdecoder(
|
|
[512, 256, 128, 128], device, dtype, 16 + 768,
|
|
)
|
|
self.pipe.TCDecoder.load_state_dict(
|
|
load_file(tcd_path, device=device), strict=False,
|
|
)
|
|
self.pipe.TCDecoder.clean_mem()
|
|
|
|
# LQ frame projection — Causal variant for FlashVSR v1.1
|
|
self.pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(3, 1536, 1).to(device, dtype)
|
|
if os.path.exists(lq_path):
|
|
lq_sd = load_file(lq_path, device="cpu")
|
|
cleaned = {}
|
|
for k, v in lq_sd.items():
|
|
cleaned[k.removeprefix("LQ_proj_in.")] = v
|
|
self.pipe.denoising_model().LQ_proj_in.load_state_dict(cleaned, strict=True)
|
|
self.pipe.denoising_model().LQ_proj_in.to(device)
|
|
|
|
self.pipe.to(device, dtype)
|
|
self.pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
|
self.pipe.init_cross_kv(prompt_path=prompt_path)
|
|
self.pipe.load_models_to_device([]) # offload to CPU
|
|
|
|
def to(self, device):
|
|
self.device = device
|
|
self.pipe.device = device
|
|
return self
|
|
|
|
def load_to_device(self):
|
|
"""Load models to the compute device for inference."""
|
|
names = ["dit", "vae"] if self.mode == "full" else ["dit"]
|
|
self.pipe.load_models_to_device(names)
|
|
|
|
def offload(self):
|
|
"""Offload models to CPU."""
|
|
self.pipe.load_models_to_device([])
|
|
|
|
def clear_caches(self):
|
|
if hasattr(self.pipe.denoising_model(), "LQ_proj_in"):
|
|
self.pipe.denoising_model().LQ_proj_in.clear_cache()
|
|
if hasattr(self.pipe, "vae") and self.pipe.vae is not None:
|
|
self.pipe.vae.clear_cache()
|
|
if hasattr(self.pipe, "TCDecoder") and self.pipe.TCDecoder is not None:
|
|
self.pipe.TCDecoder.clean_mem()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Frame preprocessing / postprocessing helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
@staticmethod
|
|
def _compute_dims(w, h, scale, align=128):
|
|
sw, sh = w * scale, h * scale
|
|
tw = math.ceil(sw / align) * align
|
|
th = math.ceil(sh / align) * align
|
|
return sw, sh, tw, th
|
|
|
|
@staticmethod
|
|
def _restore_video_sequence(result, expected):
|
|
"""Trim pipeline output to the expected frame count."""
|
|
if result.shape[0] > expected:
|
|
result = result[:expected]
|
|
elif result.shape[0] < expected:
|
|
pad = result[-1:].expand(expected - result.shape[0], *result.shape[1:])
|
|
result = torch.cat([result, pad], dim=0)
|
|
return result
|
|
|
|
@staticmethod
|
|
def _next_8n5(n, minimum=21):
|
|
"""Next integer >= n of the form 8k+5 (minimum 21)."""
|
|
if n < minimum:
|
|
return minimum
|
|
return ((n - 5 + 7) // 8) * 8 + 5
|
|
|
|
def _prepare_video(self, frames, scale):
|
|
"""Convert [F, H, W, C] [0,1] frames to padded [1, C, F_padded, H, W] [-1,1].
|
|
|
|
Matches naxci1/ComfyUI-FlashVSR_Stable two-stage temporal padding:
|
|
1. Bicubic-upscale each frame to target resolution
|
|
2. Centered symmetric padding to 128-pixel alignment (reflect mode)
|
|
3. Normalize to [-1, 1]
|
|
4. Stage 1: Pad frame count to next 8n+5 (min 21) by repeating last frame
|
|
5. Stage 2: Add 4 → result is always 8k+1 (since 8n+5+4 = 8(n+1)+1)
|
|
|
|
Returns:
|
|
video: [1, C, F_padded, H, W] tensor
|
|
th, tw: padded spatial dimensions
|
|
nf: padded frame count
|
|
sh, sw: actual (unpadded) spatial dimensions
|
|
pad_top, pad_left: spatial padding offsets for output cropping
|
|
"""
|
|
N, H, W, C = frames.shape
|
|
sw, sh, tw, th = self._compute_dims(W, H, scale)
|
|
|
|
# Stage 1: pad frame count to next 8n+5 (matches naxci1 process_chunk)
|
|
N_padded = self._next_8n5(N)
|
|
|
|
# Stage 2: add 4 → gives 8(n+1)+1, always a valid 8k+1
|
|
target = N_padded + 4
|
|
|
|
# Centered spatial padding offsets
|
|
pad_top = (th - sh) // 2
|
|
pad_bottom = th - sh - pad_top
|
|
pad_left = (tw - sw) // 2
|
|
pad_right = tw - sw - pad_left
|
|
|
|
processed = []
|
|
for i in range(target):
|
|
frame_idx = min(i, N - 1) # clamp to last real frame
|
|
frame = frames[frame_idx].permute(2, 0, 1).unsqueeze(0) # [1, C, H, W]
|
|
upscaled = F.interpolate(frame, size=(sh, sw), mode='bicubic', align_corners=False)
|
|
if pad_top > 0 or pad_bottom > 0 or pad_left > 0 or pad_right > 0:
|
|
# Centered reflect padding (matches naxci1 reference)
|
|
try:
|
|
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')
|
|
except RuntimeError:
|
|
# Reflect requires pad < input size; fall back to replicate
|
|
upscaled = F.pad(upscaled, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
|
|
normalized = upscaled * 2.0 - 1.0
|
|
processed.append(normalized.squeeze(0).cpu().to(self.dtype))
|
|
|
|
video = torch.stack(processed, 0).permute(1, 0, 2, 3).unsqueeze(0)
|
|
nf = video.shape[2]
|
|
|
|
return video, th, tw, nf, sh, sw, pad_top, pad_left
|
|
|
|
@staticmethod
|
|
def _to_frames(video):
|
|
"""Convert [C, F, H, W] [-1,1] pipeline output to [F, H, W, C] [0,1]."""
|
|
from einops import rearrange
|
|
v = video.squeeze(0) if video.dim() == 5 else video
|
|
v = rearrange(v, "C F H W -> F H W C")
|
|
return torch.clamp((v.float() + 1.0) / 2.0, 0.0, 1.0)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Main upscale method
|
|
# ------------------------------------------------------------------
|
|
|
|
@torch.no_grad()
|
|
def upscale(self, frames, scale=4, tiled=True, tile_size=(60, 104),
|
|
topk_ratio=2.0, kv_ratio=3.0, local_range=11,
|
|
color_fix=True, unload_dit=False, seed=1,
|
|
progress_bar_cmd=None):
|
|
"""Upscale video frames with FlashVSR.
|
|
|
|
Args:
|
|
frames: [F, H, W, C] float32 [0, 1] with F >= 21
|
|
scale: Upscaling factor (2 or 4)
|
|
tiled: Enable VAE tiled decode (saves VRAM)
|
|
tile_size: (H, W) tile size for VAE tiling
|
|
topk_ratio: Sparse attention ratio (higher = faster, less detail)
|
|
kv_ratio: KV cache ratio (higher = more quality, more VRAM)
|
|
local_range: Local attention window (9=sharp, 11=stable)
|
|
color_fix: Apply wavelet color correction
|
|
unload_dit: Offload DiT before VAE decode (saves VRAM)
|
|
seed: Random seed
|
|
progress_bar_cmd: Callable wrapping an iterable for progress display
|
|
|
|
Returns:
|
|
[F, H*scale, W*scale, C] float32 [0, 1]
|
|
"""
|
|
if progress_bar_cmd is None:
|
|
from tqdm import tqdm
|
|
progress_bar_cmd = tqdm
|
|
|
|
original_count = frames.shape[0]
|
|
|
|
# Prepare video tensor (bicubic upscale + centered pad)
|
|
video, th, tw, nf, sh, sw, pad_top, pad_left = self._prepare_video(frames, scale)
|
|
|
|
# Move LQ video to compute device (except for "long" mode which streams)
|
|
if "long" not in self.pipe.__class__.__name__.lower():
|
|
video = video.to(self.pipe.device)
|
|
|
|
# Run pipeline
|
|
out = self.pipe(
|
|
prompt="", negative_prompt="",
|
|
cfg_scale=1.0, num_inference_steps=1,
|
|
seed=seed, tiled=tiled, tile_size=tile_size,
|
|
progress_bar_cmd=progress_bar_cmd,
|
|
LQ_video=video,
|
|
num_frames=nf, height=th, width=tw,
|
|
is_full_block=False, if_buffer=True,
|
|
topk_ratio=topk_ratio * 768 * 1280 / (th * tw),
|
|
kv_ratio=kv_ratio, local_range=local_range,
|
|
color_fix=color_fix, unload_dit=unload_dit,
|
|
)
|
|
|
|
# Convert to ComfyUI format with centered spatial crop
|
|
result = self._to_frames(out).cpu()
|
|
result = result[:, pad_top:pad_top + sh, pad_left:pad_left + sw, :]
|
|
|
|
# Trim to original frame count
|
|
result = self._restore_video_sequence(result, original_count)
|
|
|
|
return result
|