Add BiM-VFI as interpolation method for direct transitions
Integrate KAIST VICLab's BiM-VFI (CVPR 2025) as a third option alongside RIFE and FILM for AI frame interpolation between sequences. - Add BIM_VFI enum value to DirectInterpolationMethod - Create core/bim_vfi_worker.py subprocess worker following FILM pattern - Add BimVfiEnv class managing repo clone, deps install, and checkpoint download (via gdown from Google Drive) - Add batch and single-frame dispatch in TransitionGenerator - Add bim_vfi_blend() to ImageBlender with FILM fallback - Update DirectTransitionDialog UI with BiM-VFI option and setup flow BiM-VFI setup auto-clones the repo, installs deps (cupy, basicsr-fixed, etc.) into the shared PyTorch venv, and downloads the checkpoint. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+2
-1
@@ -23,7 +23,7 @@ from .models import (
|
||||
DatabaseError,
|
||||
)
|
||||
from .database import DatabaseManager
|
||||
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, OPTICAL_FLOW_PRESETS
|
||||
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, BimVfiEnv, OPTICAL_FLOW_PRESETS
|
||||
from .manager import SymlinkManager
|
||||
from .video import encode_image_sequence, encode_from_file_list, find_ffmpeg
|
||||
|
||||
@@ -54,6 +54,7 @@ __all__ = [
|
||||
'RifeDownloader',
|
||||
'PracticalRifeEnv',
|
||||
'FilmEnv',
|
||||
'BimVfiEnv',
|
||||
'SymlinkManager',
|
||||
'OPTICAL_FLOW_PRESETS',
|
||||
'encode_image_sequence',
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python
|
||||
"""BiM-VFI interpolation worker - runs in isolated venv with PyTorch.
|
||||
|
||||
This script is executed via subprocess from the main application.
|
||||
It handles frame interpolation using KAIST VICLab's BiM-VFI model
|
||||
(Bidirectional Motion Field-Guided Frame Interpolation).
|
||||
|
||||
BiM-VFI is designed for non-uniform motions and produces high-quality
|
||||
results, especially for complex scenes (CVPR 2025).
|
||||
|
||||
Supports two modes:
|
||||
1. Single frame: --output with --timestep
|
||||
2. Batch mode: --output-dir with --frame-count (generates all frames at once)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
# Checkpoint filename
|
||||
BIM_VFI_CHECKPOINT = "bim_vfi.pth"
|
||||
|
||||
|
||||
def load_image(path: Path, device: torch.device) -> torch.Tensor:
|
||||
"""Load image as tensor.
|
||||
|
||||
Args:
|
||||
path: Path to image file.
|
||||
device: Device to load tensor to.
|
||||
|
||||
Returns:
|
||||
Image tensor (1, 3, H, W) normalized to [0, 1].
|
||||
"""
|
||||
img = Image.open(path).convert('RGB')
|
||||
arr = np.array(img).astype(np.float32) / 255.0
|
||||
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
|
||||
return tensor.to(device)
|
||||
|
||||
|
||||
def save_image(tensor: torch.Tensor, path: Path) -> None:
|
||||
"""Save tensor as image.
|
||||
|
||||
Args:
|
||||
tensor: Image tensor (1, 3, H, W) or (3, H, W) normalized to [0, 1].
|
||||
path: Output path.
|
||||
"""
|
||||
if tensor.dim() == 4:
|
||||
tensor = tensor.squeeze(0)
|
||||
arr = tensor.permute(1, 2, 0).cpu().numpy()
|
||||
arr = (arr * 255).clip(0, 255).astype(np.uint8)
|
||||
Image.fromarray(arr).save(path)
|
||||
|
||||
|
||||
# Global model cache
|
||||
_model_cache: dict = {}
|
||||
|
||||
|
||||
def get_model(repo_dir: Path, model_dir: Path, device: torch.device):
|
||||
"""Get or load BiM-VFI model (cached).
|
||||
|
||||
Args:
|
||||
repo_dir: Path to the cloned BiM-VFI repository.
|
||||
model_dir: Directory containing the checkpoint.
|
||||
device: Device to run on.
|
||||
|
||||
Returns:
|
||||
BiM-VFI model instance.
|
||||
"""
|
||||
cache_key = f"bim_vfi_{device}"
|
||||
if cache_key not in _model_cache:
|
||||
# Add repo to sys.path so we can import BiM-VFI modules
|
||||
repo_str = str(repo_dir)
|
||||
if repo_str not in sys.path:
|
||||
sys.path.insert(0, repo_str)
|
||||
|
||||
checkpoint_path = model_dir / BIM_VFI_CHECKPOINT
|
||||
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"BiM-VFI checkpoint not found at {checkpoint_path}. "
|
||||
"Please download it from Google Drive and place it there."
|
||||
)
|
||||
|
||||
print(f"Loading BiM-VFI model from {checkpoint_path}...", file=sys.stderr)
|
||||
|
||||
# Import BiM-VFI's component registry
|
||||
from modules.components.components import make_components
|
||||
|
||||
# Create model with default config
|
||||
cfg = {'name': 'bim_vfi', 'args': {'pyr_level': 3, 'feat_channels': 32}}
|
||||
model = make_components(cfg)
|
||||
|
||||
# Load checkpoint
|
||||
ckpt = torch.load(str(checkpoint_path), map_location='cpu', weights_only=False)
|
||||
model.load_state_dict(ckpt['model'])
|
||||
|
||||
model.eval()
|
||||
model.to(device)
|
||||
_model_cache[cache_key] = model
|
||||
print("Model loaded.", file=sys.stderr)
|
||||
|
||||
return _model_cache[cache_key]
|
||||
|
||||
|
||||
def get_pyr_level(height: int) -> int:
|
||||
"""Get appropriate pyramid level based on image height.
|
||||
|
||||
Args:
|
||||
height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
Recommended pyramid level.
|
||||
"""
|
||||
if height >= 2160:
|
||||
return 7
|
||||
elif height >= 1080:
|
||||
return 6
|
||||
else:
|
||||
return 5
|
||||
|
||||
|
||||
def get_scale_factor(height: int) -> float:
|
||||
"""Get appropriate scale factor based on image height.
|
||||
|
||||
Args:
|
||||
height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
Recommended scale factor.
|
||||
"""
|
||||
if height >= 2160:
|
||||
return 0.25
|
||||
elif height >= 1080:
|
||||
return 0.5
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_single(
|
||||
model, img0: torch.Tensor, img1: torch.Tensor, t: float
|
||||
) -> torch.Tensor:
|
||||
"""Perform single frame interpolation using BiM-VFI.
|
||||
|
||||
Args:
|
||||
model: BiM-VFI model instance.
|
||||
img0: First frame tensor (1, 3, H, W) normalized to [0, 1].
|
||||
img1: Second frame tensor (1, 3, H, W) normalized to [0, 1].
|
||||
t: Interpolation timestep (0.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
Interpolated frame tensor.
|
||||
"""
|
||||
h = img0.shape[2]
|
||||
pyr_level = get_pyr_level(h)
|
||||
scale_factor = get_scale_factor(h)
|
||||
|
||||
time_step = torch.tensor([t]).view(1, 1, 1, 1).to(img0.device)
|
||||
|
||||
# Distance weights for better quality
|
||||
dis0 = torch.ones((1, 1, h, img0.shape[3]), device=img0.device) * t
|
||||
dis1 = 1 - dis0
|
||||
|
||||
results_dict = model(
|
||||
img0=img0, img1=img1,
|
||||
time_step=time_step,
|
||||
dis0=dis0, dis1=dis1,
|
||||
scale_factor=scale_factor,
|
||||
ratio=(1.0 / scale_factor),
|
||||
pyr_level=pyr_level,
|
||||
nr_lvl_skipped=0
|
||||
)
|
||||
|
||||
return results_dict['imgt_pred'].clamp(0, 1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_batch(
|
||||
model, img0: torch.Tensor, img1: torch.Tensor, frame_count: int
|
||||
) -> list[torch.Tensor]:
|
||||
"""Generate multiple interpolated frames using BiM-VFI.
|
||||
|
||||
Args:
|
||||
model: BiM-VFI model instance.
|
||||
img0: First frame tensor (1, 3, H, W) normalized to [0, 1].
|
||||
img1: Second frame tensor (1, 3, H, W) normalized to [0, 1].
|
||||
frame_count: Number of frames to generate between img0 and img1.
|
||||
|
||||
Returns:
|
||||
List of interpolated frame tensors in order.
|
||||
"""
|
||||
frames = []
|
||||
for i in range(frame_count):
|
||||
t = (i + 1) / (frame_count + 1)
|
||||
frame = interpolate_single(model, img0, img1, t)
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='BiM-VFI frame interpolation worker')
|
||||
parser.add_argument('--input0', required=True, help='Path to first input image')
|
||||
parser.add_argument('--input1', required=True, help='Path to second input image')
|
||||
parser.add_argument('--output', help='Path to output image (single frame mode)')
|
||||
parser.add_argument('--output-dir', help='Output directory (batch mode)')
|
||||
parser.add_argument('--output-pattern', default='frame_{:04d}.png',
|
||||
help='Output filename pattern for batch mode')
|
||||
parser.add_argument('--timestep', type=float, default=0.5,
|
||||
help='Interpolation timestep 0-1 (single frame mode)')
|
||||
parser.add_argument('--frame-count', type=int,
|
||||
help='Number of frames to generate (batch mode)')
|
||||
parser.add_argument('--repo-dir', required=True, help='Path to BiM-VFI repo clone')
|
||||
parser.add_argument('--model-dir', required=True, help='Model cache directory')
|
||||
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate arguments
|
||||
batch_mode = args.output_dir is not None and args.frame_count is not None
|
||||
single_mode = args.output is not None
|
||||
|
||||
if not batch_mode and not single_mode:
|
||||
print("Error: Must specify either --output (single) or --output-dir + --frame-count (batch)",
|
||||
file=sys.stderr)
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Select device
|
||||
if args.device == 'cuda' and torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
# Load model
|
||||
repo_dir = Path(args.repo_dir)
|
||||
model_dir = Path(args.model_dir)
|
||||
model = get_model(repo_dir, model_dir, device)
|
||||
|
||||
# Load images
|
||||
img0 = load_image(Path(args.input0), device)
|
||||
img1 = load_image(Path(args.input1), device)
|
||||
|
||||
if batch_mode:
|
||||
# Batch mode - generate all frames at once
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Generating {args.frame_count} frames...", file=sys.stderr)
|
||||
frames = interpolate_batch(model, img0, img1, args.frame_count)
|
||||
|
||||
for i, frame in enumerate(frames):
|
||||
output_path = output_dir / args.output_pattern.format(i)
|
||||
save_image(frame, output_path)
|
||||
print(f"Saved {output_path.name}", file=sys.stderr)
|
||||
|
||||
print(f"Success: Generated {len(frames)} frames", file=sys.stderr)
|
||||
else:
|
||||
# Single frame mode
|
||||
result = interpolate_single(model, img0, img1, args.timestep)
|
||||
save_image(result, Path(args.output))
|
||||
print("Success", file=sys.stderr)
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
+459
-1
@@ -477,6 +477,302 @@ class FilmEnv:
|
||||
return False, str(e), []
|
||||
|
||||
|
||||
class BimVfiEnv:
|
||||
"""Manages BiM-VFI frame interpolation using shared venv with RIFE."""
|
||||
|
||||
VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE
|
||||
REPO_DIR = CACHE_DIR / 'BiM-VFI'
|
||||
MODEL_CACHE_DIR = CACHE_DIR / 'bim-vfi'
|
||||
CHECKPOINT_FILENAME = 'bim_vfi.pth'
|
||||
REPO_URL = 'https://github.com/KAIST-VICLab/BiM-VFI.git'
|
||||
# Google Drive file ID for the checkpoint
|
||||
GDRIVE_FILE_ID = '18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC'
|
||||
|
||||
# Extra pip packages needed beyond the base torch venv
|
||||
EXTRA_PACKAGES = [
|
||||
'basicsr-fixed', 'imageio', 'pyyaml', 'opencv-python',
|
||||
'lpips', 'ptflops',
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_venv_python(cls) -> Optional[Path]:
|
||||
"""Get path to venv Python executable."""
|
||||
if cls.VENV_DIR.exists():
|
||||
if sys.platform == 'win32':
|
||||
return cls.VENV_DIR / 'Scripts' / 'python.exe'
|
||||
return cls.VENV_DIR / 'bin' / 'python'
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_checkpoint_path(cls) -> Path:
|
||||
"""Get path to the BiM-VFI checkpoint."""
|
||||
return cls.MODEL_CACHE_DIR / cls.CHECKPOINT_FILENAME
|
||||
|
||||
@classmethod
|
||||
def is_setup(cls) -> bool:
|
||||
"""Check if venv exists, repo is cloned, and checkpoint is present."""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False
|
||||
if not cls.REPO_DIR.exists():
|
||||
return False
|
||||
return cls.get_checkpoint_path().exists()
|
||||
|
||||
@classmethod
|
||||
def setup_bim_vfi(cls, progress_callback=None, cancelled_check=None) -> bool:
|
||||
"""Clone repo, install deps, and download checkpoint.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(message, percent) for progress.
|
||||
cancelled_check: Optional callable that returns True if cancelled.
|
||||
|
||||
Returns:
|
||||
True if setup was successful.
|
||||
"""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
# Step 1: Clone repo if needed
|
||||
if not cls.REPO_DIR.exists():
|
||||
if progress_callback:
|
||||
progress_callback("Cloning BiM-VFI repository...", 10)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
result = subprocess.run(
|
||||
['git', 'clone', '--depth', '1', cls.REPO_URL, str(cls.REPO_DIR)],
|
||||
capture_output=True, text=True, timeout=300
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"[BiM-VFI] git clone failed: {result.stderr}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
# Step 2: Install extra packages
|
||||
if progress_callback:
|
||||
progress_callback("Installing BiM-VFI dependencies...", 30)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
pip = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'pip'
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet'] + cls.EXTRA_PACKAGES,
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"[BiM-VFI] pip install failed: {result.stderr}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
# Step 2b: Install cupy (needs CUDA-specific wheel)
|
||||
if progress_callback:
|
||||
progress_callback("Installing cupy (CUDA support)...", 50)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
# Try cupy-cuda12x first (CUDA 12), fall back to cupy-cuda11x
|
||||
cupy_installed = False
|
||||
for cupy_pkg in ['cupy-cuda12x', 'cupy-cuda11x']:
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet', cupy_pkg],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode == 0:
|
||||
cupy_installed = True
|
||||
break
|
||||
|
||||
if not cupy_installed:
|
||||
print("[BiM-VFI] Warning: cupy install failed, trying generic cupy", file=sys.stderr)
|
||||
subprocess.run(
|
||||
[str(pip), 'install', '--quiet', 'cupy'],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
|
||||
# Step 3: Download checkpoint
|
||||
checkpoint_path = cls.get_checkpoint_path()
|
||||
if not checkpoint_path.exists():
|
||||
if progress_callback:
|
||||
progress_callback("Downloading BiM-VFI checkpoint (~300MB)...", 60)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use gdown to download from Google Drive
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet', 'gdown'],
|
||||
capture_output=True, text=True, timeout=120
|
||||
)
|
||||
|
||||
gdown_bin = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'gdown'
|
||||
tmp_path = checkpoint_path.with_suffix('.tmp')
|
||||
result = subprocess.run(
|
||||
[str(gdown_bin), '--id', cls.GDRIVE_FILE_ID,
|
||||
'--output', str(tmp_path)],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode == 0 and tmp_path.exists():
|
||||
tmp_path.rename(checkpoint_path)
|
||||
else:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
error = result.stderr.strip() if result.stderr else "unknown error"
|
||||
print(f"[BiM-VFI] Download failed: {error}", file=sys.stderr)
|
||||
print("[BiM-VFI] Please download manually from Google Drive and place at "
|
||||
f"{checkpoint_path}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("BiM-VFI setup complete!", 100)
|
||||
|
||||
return cls.is_setup()
|
||||
|
||||
except Exception as e:
|
||||
print(f"[BiM-VFI] Setup error: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_worker_script(cls) -> Path:
|
||||
"""Get path to the BiM-VFI worker script."""
|
||||
return Path(__file__).parent / 'bim_vfi_worker.py'
|
||||
|
||||
@classmethod
|
||||
def run_interpolation(
|
||||
cls,
|
||||
img_a_path: Path,
|
||||
img_b_path: Path,
|
||||
output_path: Path,
|
||||
t: float
|
||||
) -> tuple[bool, str]:
|
||||
"""Run BiM-VFI interpolation via subprocess in venv.
|
||||
|
||||
Args:
|
||||
img_a_path: Path to first input image.
|
||||
img_b_path: Path to second input image.
|
||||
output_path: Path to output image.
|
||||
t: Timestep for interpolation (0.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_message).
|
||||
"""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False, "venv python not found"
|
||||
|
||||
script = cls.get_worker_script()
|
||||
if not script.exists():
|
||||
return False, f"worker script not found: {script}"
|
||||
|
||||
cmd = [
|
||||
str(python), str(script),
|
||||
'--input0', str(img_a_path),
|
||||
'--input1', str(img_b_path),
|
||||
'--output', str(output_path),
|
||||
'--timestep', str(t),
|
||||
'--repo-dir', str(cls.REPO_DIR),
|
||||
'--model-dir', str(cls.MODEL_CACHE_DIR)
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300 # 5 minute timeout per frame (BiM-VFI can be slow)
|
||||
)
|
||||
if result.returncode == 0 and output_path.exists():
|
||||
return True, ""
|
||||
else:
|
||||
error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}"
|
||||
return False, error
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "timeout (300s)"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def run_batch_interpolation(
|
||||
cls,
|
||||
img_a_path: Path,
|
||||
img_b_path: Path,
|
||||
output_dir: Path,
|
||||
frame_count: int,
|
||||
output_pattern: str = 'frame_{:04d}.png'
|
||||
) -> tuple[bool, str, list[Path]]:
|
||||
"""Run BiM-VFI batch interpolation via subprocess in venv.
|
||||
|
||||
Args:
|
||||
img_a_path: Path to first input image.
|
||||
img_b_path: Path to second input image.
|
||||
output_dir: Directory to save output frames.
|
||||
frame_count: Number of frames to generate.
|
||||
output_pattern: Filename pattern for output frames.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_message, list_of_output_paths).
|
||||
"""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False, "venv python not found", []
|
||||
|
||||
script = cls.get_worker_script()
|
||||
if not script.exists():
|
||||
return False, f"worker script not found: {script}", []
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cmd = [
|
||||
str(python), str(script),
|
||||
'--input0', str(img_a_path),
|
||||
'--input1', str(img_b_path),
|
||||
'--output-dir', str(output_dir),
|
||||
'--frame-count', str(frame_count),
|
||||
'--output-pattern', output_pattern,
|
||||
'--repo-dir', str(cls.REPO_DIR),
|
||||
'--model-dir', str(cls.MODEL_CACHE_DIR)
|
||||
]
|
||||
|
||||
try:
|
||||
timeout = max(300, frame_count * 45) # At least 5 min, +45s per frame
|
||||
|
||||
print(f"[BiM-VFI] Running batch interpolation: {frame_count} frames", file=sys.stderr)
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
output_paths = [
|
||||
output_dir / output_pattern.format(i)
|
||||
for i in range(frame_count)
|
||||
]
|
||||
existing_paths = [p for p in output_paths if p.exists()]
|
||||
|
||||
if result.returncode == 0 and len(existing_paths) == frame_count:
|
||||
print(f"[BiM-VFI] Success: generated {len(existing_paths)} frames", file=sys.stderr)
|
||||
return True, "", output_paths
|
||||
else:
|
||||
error_parts = []
|
||||
if result.returncode != 0:
|
||||
error_parts.append(f"returncode={result.returncode}")
|
||||
if result.stderr and result.stderr.strip():
|
||||
error_parts.append(f"stderr: {result.stderr.strip()}")
|
||||
if len(existing_paths) != frame_count:
|
||||
error_parts.append(f"expected {frame_count} frames, got {len(existing_paths)}")
|
||||
|
||||
error = "; ".join(error_parts) if error_parts else "unknown error"
|
||||
print(f"[BiM-VFI] Failed: {error}", file=sys.stderr)
|
||||
return False, error, existing_paths
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"[BiM-VFI] Timeout after {timeout}s", file=sys.stderr)
|
||||
return False, f"timeout ({timeout}s)", []
|
||||
except Exception as e:
|
||||
print(f"[BiM-VFI] Exception: {e}", file=sys.stderr)
|
||||
return False, str(e), []
|
||||
|
||||
|
||||
class RifeDownloader:
|
||||
"""Handles automatic download and caching of rife-ncnn-vulkan binary."""
|
||||
|
||||
@@ -1100,6 +1396,53 @@ class ImageBlender:
|
||||
# Fall back to Practical-RIFE
|
||||
return ImageBlender.practical_rife_blend(img_a, img_b, t)
|
||||
|
||||
@staticmethod
|
||||
def bim_vfi_blend(
|
||||
img_a: Image.Image,
|
||||
img_b: Image.Image,
|
||||
t: float
|
||||
) -> Image.Image:
|
||||
"""Blend using BiM-VFI for high-quality interpolation.
|
||||
|
||||
BiM-VFI (Bidirectional Motion Field-Guided Frame Interpolation)
|
||||
handles non-uniform motions well (CVPR 2025).
|
||||
|
||||
Args:
|
||||
img_a: First PIL Image (source frame).
|
||||
img_b: Second PIL Image (target frame).
|
||||
t: Interpolation factor 0.0 (100% A) to 1.0 (100% B).
|
||||
|
||||
Returns:
|
||||
AI-interpolated blended PIL Image.
|
||||
"""
|
||||
if not BimVfiEnv.is_setup():
|
||||
print("[BiM-VFI] Not set up, falling back to FILM", file=sys.stderr)
|
||||
return ImageBlender.film_blend(img_a, img_b, t)
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp = Path(tmpdir)
|
||||
input_a = tmp / 'a.png'
|
||||
input_b = tmp / 'b.png'
|
||||
output_file = tmp / 'out.png'
|
||||
|
||||
img_a.convert('RGB').save(input_a)
|
||||
img_b.convert('RGB').save(input_b)
|
||||
|
||||
success, error_msg = BimVfiEnv.run_interpolation(
|
||||
input_a, input_b, output_file, t
|
||||
)
|
||||
|
||||
if success and output_file.exists():
|
||||
return Image.open(output_file).copy()
|
||||
else:
|
||||
print(f"[BiM-VFI] Interpolation failed: {error_msg}, falling back to FILM", file=sys.stderr)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[BiM-VFI] Exception: {e}, falling back to FILM", file=sys.stderr)
|
||||
|
||||
return ImageBlender.film_blend(img_a, img_b, t)
|
||||
|
||||
@staticmethod
|
||||
def blend_images(
|
||||
img_a_path: Path,
|
||||
@@ -1652,7 +1995,13 @@ class TransitionGenerator:
|
||||
img_a_path, img_b_path, frame_count, dest, base_seq_num
|
||||
)
|
||||
|
||||
# For RIFE (or FILM fallback), generate frames one at a time
|
||||
# For BiM-VFI, use batch mode
|
||||
if method == DirectInterpolationMethod.BIM_VFI and BimVfiEnv.is_setup():
|
||||
return self._generate_bim_vfi_frames_batch(
|
||||
img_a_path, img_b_path, frame_count, dest, base_seq_num
|
||||
)
|
||||
|
||||
# For RIFE (or fallback), generate frames one at a time
|
||||
# Load source images
|
||||
img_a = Image.open(img_a_path)
|
||||
img_b = Image.open(img_b_path)
|
||||
@@ -1674,6 +2023,8 @@ class TransitionGenerator:
|
||||
# Generate interpolated frame
|
||||
if method == DirectInterpolationMethod.FILM:
|
||||
blended = ImageBlender.film_blend(img_a, img_b, t)
|
||||
elif method == DirectInterpolationMethod.BIM_VFI:
|
||||
blended = ImageBlender.bim_vfi_blend(img_a, img_b, t)
|
||||
else: # RIFE
|
||||
blended = ImageBlender.practical_rife_blend(
|
||||
img_a, img_b, t,
|
||||
@@ -1844,3 +2195,110 @@ class TransitionGenerator:
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
def _generate_bim_vfi_frames_batch(
|
||||
self,
|
||||
img_a_path: Path,
|
||||
img_b_path: Path,
|
||||
frame_count: int,
|
||||
dest: Path,
|
||||
base_seq_num: int
|
||||
) -> list[BlendResult]:
|
||||
"""Generate BiM-VFI frames using batch mode.
|
||||
|
||||
Args:
|
||||
img_a_path: Path to last frame of first sequence.
|
||||
img_b_path: Path to first frame of second sequence.
|
||||
frame_count: Number of interpolated frames to generate.
|
||||
dest: Destination directory for generated frames.
|
||||
base_seq_num: Starting sequence number for continuous naming.
|
||||
|
||||
Returns:
|
||||
List of BlendResult objects.
|
||||
"""
|
||||
results = []
|
||||
|
||||
temp_pattern = 'bimvfi_temp_{:04d}.png'
|
||||
|
||||
success, error, temp_paths = BimVfiEnv.run_batch_interpolation(
|
||||
img_a_path,
|
||||
img_b_path,
|
||||
dest,
|
||||
frame_count,
|
||||
temp_pattern
|
||||
)
|
||||
|
||||
if not success:
|
||||
for i in range(frame_count):
|
||||
t = (i + 1) / (frame_count + 1)
|
||||
ext = f".{self.settings.output_format.lower()}"
|
||||
seq_num = base_seq_num + i
|
||||
output_name = f"seq_{seq_num:05d}{ext}"
|
||||
output_path = dest / output_name
|
||||
|
||||
results.append(BlendResult(
|
||||
output_path=output_path,
|
||||
source_a=img_a_path,
|
||||
source_b=img_b_path,
|
||||
blend_factor=t,
|
||||
success=False,
|
||||
error=error
|
||||
))
|
||||
return results
|
||||
|
||||
for i, temp_path in enumerate(temp_paths):
|
||||
t = (i + 1) / (frame_count + 1)
|
||||
ext = f".{self.settings.output_format.lower()}"
|
||||
seq_num = base_seq_num + i
|
||||
output_name = f"seq_{seq_num:05d}{ext}"
|
||||
output_path = dest / output_name
|
||||
|
||||
try:
|
||||
if temp_path.exists():
|
||||
frame = Image.open(temp_path)
|
||||
|
||||
if self.settings.output_format.lower() in ('jpg', 'jpeg'):
|
||||
frame = frame.convert('RGB')
|
||||
|
||||
save_kwargs = {}
|
||||
if self.settings.output_format.lower() in ('jpg', 'jpeg'):
|
||||
save_kwargs['quality'] = self.settings.output_quality
|
||||
elif self.settings.output_format.lower() == 'webp':
|
||||
save_kwargs['lossless'] = True
|
||||
save_kwargs['method'] = self.settings.webp_method
|
||||
elif self.settings.output_format.lower() == 'png':
|
||||
save_kwargs['compress_level'] = 6
|
||||
|
||||
frame.save(output_path, **save_kwargs)
|
||||
frame.close()
|
||||
|
||||
if temp_path != output_path:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
results.append(BlendResult(
|
||||
output_path=output_path,
|
||||
source_a=img_a_path,
|
||||
source_b=img_b_path,
|
||||
blend_factor=t,
|
||||
success=True
|
||||
))
|
||||
else:
|
||||
results.append(BlendResult(
|
||||
output_path=output_path,
|
||||
source_a=img_a_path,
|
||||
source_b=img_b_path,
|
||||
blend_factor=t,
|
||||
success=False,
|
||||
error=f"Temp file not found: {temp_path}"
|
||||
))
|
||||
except Exception as e:
|
||||
results.append(BlendResult(
|
||||
output_path=output_path,
|
||||
source_a=img_a_path,
|
||||
source_b=img_b_path,
|
||||
blend_factor=t,
|
||||
success=False,
|
||||
error=str(e)
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
@@ -36,6 +36,7 @@ class DirectInterpolationMethod(Enum):
|
||||
"""Method for direct frame interpolation between sequences."""
|
||||
RIFE = 'rife'
|
||||
FILM = 'film'
|
||||
BIM_VFI = 'bim_vfi'
|
||||
|
||||
|
||||
# --- Data Classes ---
|
||||
|
||||
+59
-5
@@ -67,6 +67,7 @@ from core import (
|
||||
find_ffmpeg,
|
||||
PracticalRifeEnv,
|
||||
FilmEnv,
|
||||
BimVfiEnv,
|
||||
SymlinkManager,
|
||||
OPTICAL_FLOW_PRESETS,
|
||||
)
|
||||
@@ -236,8 +237,12 @@ class DirectTransitionDialog(QDialog):
|
||||
self.method_combo = QComboBox()
|
||||
self.method_combo.addItem("RIFE (Fast, small motion)", DirectInterpolationMethod.RIFE)
|
||||
self.method_combo.addItem("FILM (Slow, large motion)", DirectInterpolationMethod.FILM)
|
||||
if method == DirectInterpolationMethod.FILM:
|
||||
self.method_combo.setCurrentIndex(1)
|
||||
self.method_combo.addItem("BiM-VFI (Best quality, slowest)", DirectInterpolationMethod.BIM_VFI)
|
||||
# Set current method
|
||||
for i in range(self.method_combo.count()):
|
||||
if self.method_combo.itemData(i) == method:
|
||||
self.method_combo.setCurrentIndex(i)
|
||||
break
|
||||
form_layout.addRow("Method:", self.method_combo)
|
||||
|
||||
# Frame count
|
||||
@@ -268,7 +273,8 @@ class DirectTransitionDialog(QDialog):
|
||||
# Explanation
|
||||
explain = QLabel(
|
||||
"RIFE: Fast AI interpolation, best for small motion and color shifts.\n"
|
||||
"FILM: Google Research model, better for large motion and scene gaps.\n\n"
|
||||
"FILM: Google Research model, better for large motion and scene gaps.\n"
|
||||
"BiM-VFI: CVPR 2025, best quality for non-uniform/complex motions.\n\n"
|
||||
"Generated frames bridge the gap between the last frame of this\n"
|
||||
"sequence and the first frame of the next MAIN sequence."
|
||||
)
|
||||
@@ -305,6 +311,7 @@ class DirectTransitionDialog(QDialog):
|
||||
|
||||
rife_ready = PracticalRifeEnv.is_setup()
|
||||
film_ready = FilmEnv.is_setup() if rife_ready else False
|
||||
bim_ready = BimVfiEnv.is_setup() if rife_ready else False
|
||||
|
||||
if method == DirectInterpolationMethod.RIFE:
|
||||
if rife_ready:
|
||||
@@ -316,13 +323,13 @@ class DirectTransitionDialog(QDialog):
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Setup PyTorch Environment")
|
||||
else: # FILM
|
||||
elif method == DirectInterpolationMethod.FILM:
|
||||
if film_ready:
|
||||
self.status_label.setText("FILM: Ready")
|
||||
self.status_label.setStyleSheet("color: green; font-size: 10px;")
|
||||
self.setup_btn.setVisible(False)
|
||||
elif rife_ready:
|
||||
self.status_label.setText("FILM: Package not installed")
|
||||
self.status_label.setText("FILM: Model not downloaded")
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Install FILM Package")
|
||||
@@ -331,6 +338,21 @@ class DirectTransitionDialog(QDialog):
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Setup PyTorch Environment")
|
||||
else: # BIM_VFI
|
||||
if bim_ready:
|
||||
self.status_label.setText("BiM-VFI: Ready")
|
||||
self.status_label.setStyleSheet("color: green; font-size: 10px;")
|
||||
self.setup_btn.setVisible(False)
|
||||
elif rife_ready:
|
||||
self.status_label.setText("BiM-VFI: Not installed (repo + model needed)")
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Install BiM-VFI")
|
||||
else:
|
||||
self.status_label.setText("BiM-VFI: Not installed (PyTorch required first)")
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Setup PyTorch Environment")
|
||||
|
||||
def _on_setup(self) -> None:
|
||||
"""Handle setup button click."""
|
||||
@@ -397,6 +419,38 @@ class DirectTransitionDialog(QDialog):
|
||||
)
|
||||
return
|
||||
|
||||
# If BiM-VFI selected and we need to install it
|
||||
if method == DirectInterpolationMethod.BIM_VFI and not BimVfiEnv.is_setup():
|
||||
progress = QProgressDialog(
|
||||
"Setting up BiM-VFI...", "Cancel", 0, 100, self
|
||||
)
|
||||
progress.setWindowTitle("Setup")
|
||||
progress.setWindowModality(Qt.WindowModality.WindowModal)
|
||||
progress.setMinimumDuration(0)
|
||||
progress.setValue(0)
|
||||
|
||||
def progress_cb(msg, pct):
|
||||
progress.setLabelText(msg)
|
||||
progress.setValue(pct)
|
||||
|
||||
def cancelled_check():
|
||||
QApplication.processEvents()
|
||||
return progress.wasCanceled()
|
||||
|
||||
success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check)
|
||||
progress.close()
|
||||
|
||||
if not success:
|
||||
if not cancelled_check():
|
||||
QMessageBox.warning(
|
||||
self, "Setup Failed",
|
||||
"Failed to set up BiM-VFI.\n\n"
|
||||
"If the checkpoint download failed, you can download it "
|
||||
"manually from Google Drive and place it at:\n"
|
||||
f"{BimVfiEnv.get_checkpoint_path()}"
|
||||
)
|
||||
return
|
||||
|
||||
self._update_status()
|
||||
|
||||
def _on_remove(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user