Add BiM-VFI as interpolation method for direct transitions

Integrate KAIST VICLab's BiM-VFI (CVPR 2025) as a third option
alongside RIFE and FILM for AI frame interpolation between sequences.

- Add BIM_VFI enum value to DirectInterpolationMethod
- Create core/bim_vfi_worker.py subprocess worker following FILM pattern
- Add BimVfiEnv class managing repo clone, deps install, and checkpoint
  download (via gdown from Google Drive)
- Add batch and single-frame dispatch in TransitionGenerator
- Add bim_vfi_blend() to ImageBlender with FILM fallback
- Update DirectTransitionDialog UI with BiM-VFI option and setup flow

BiM-VFI setup auto-clones the repo, installs deps (cupy, basicsr-fixed,
etc.) into the shared PyTorch venv, and downloads the checkpoint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-12 21:35:54 +01:00
parent 546e93ceb1
commit 67422302da
5 changed files with 798 additions and 7 deletions
+2 -1
View File
@@ -23,7 +23,7 @@ from .models import (
DatabaseError,
)
from .database import DatabaseManager
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, OPTICAL_FLOW_PRESETS
from .blender import ImageBlender, TransitionGenerator, RifeDownloader, PracticalRifeEnv, FilmEnv, BimVfiEnv, OPTICAL_FLOW_PRESETS
from .manager import SymlinkManager
from .video import encode_image_sequence, encode_from_file_list, find_ffmpeg
@@ -54,6 +54,7 @@ __all__ = [
'RifeDownloader',
'PracticalRifeEnv',
'FilmEnv',
'BimVfiEnv',
'SymlinkManager',
'OPTICAL_FLOW_PRESETS',
'encode_image_sequence',
+277
View File
@@ -0,0 +1,277 @@
#!/usr/bin/env python
"""BiM-VFI interpolation worker - runs in isolated venv with PyTorch.
This script is executed via subprocess from the main application.
It handles frame interpolation using KAIST VICLab's BiM-VFI model
(Bidirectional Motion Field-Guided Frame Interpolation).
BiM-VFI is designed for non-uniform motions and produces high-quality
results, especially for complex scenes (CVPR 2025).
Supports two modes:
1. Single frame: --output with --timestep
2. Batch mode: --output-dir with --frame-count (generates all frames at once)
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
from PIL import Image
# Checkpoint filename
BIM_VFI_CHECKPOINT = "bim_vfi.pth"
def load_image(path: Path, device: torch.device) -> torch.Tensor:
"""Load image as tensor.
Args:
path: Path to image file.
device: Device to load tensor to.
Returns:
Image tensor (1, 3, H, W) normalized to [0, 1].
"""
img = Image.open(path).convert('RGB')
arr = np.array(img).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
return tensor.to(device)
def save_image(tensor: torch.Tensor, path: Path) -> None:
"""Save tensor as image.
Args:
tensor: Image tensor (1, 3, H, W) or (3, H, W) normalized to [0, 1].
path: Output path.
"""
if tensor.dim() == 4:
tensor = tensor.squeeze(0)
arr = tensor.permute(1, 2, 0).cpu().numpy()
arr = (arr * 255).clip(0, 255).astype(np.uint8)
Image.fromarray(arr).save(path)
# Global model cache
_model_cache: dict = {}
def get_model(repo_dir: Path, model_dir: Path, device: torch.device):
"""Get or load BiM-VFI model (cached).
Args:
repo_dir: Path to the cloned BiM-VFI repository.
model_dir: Directory containing the checkpoint.
device: Device to run on.
Returns:
BiM-VFI model instance.
"""
cache_key = f"bim_vfi_{device}"
if cache_key not in _model_cache:
# Add repo to sys.path so we can import BiM-VFI modules
repo_str = str(repo_dir)
if repo_str not in sys.path:
sys.path.insert(0, repo_str)
checkpoint_path = model_dir / BIM_VFI_CHECKPOINT
if not checkpoint_path.exists():
raise FileNotFoundError(
f"BiM-VFI checkpoint not found at {checkpoint_path}. "
"Please download it from Google Drive and place it there."
)
print(f"Loading BiM-VFI model from {checkpoint_path}...", file=sys.stderr)
# Import BiM-VFI's component registry
from modules.components.components import make_components
# Create model with default config
cfg = {'name': 'bim_vfi', 'args': {'pyr_level': 3, 'feat_channels': 32}}
model = make_components(cfg)
# Load checkpoint
ckpt = torch.load(str(checkpoint_path), map_location='cpu', weights_only=False)
model.load_state_dict(ckpt['model'])
model.eval()
model.to(device)
_model_cache[cache_key] = model
print("Model loaded.", file=sys.stderr)
return _model_cache[cache_key]
def get_pyr_level(height: int) -> int:
"""Get appropriate pyramid level based on image height.
Args:
height: Image height in pixels.
Returns:
Recommended pyramid level.
"""
if height >= 2160:
return 7
elif height >= 1080:
return 6
else:
return 5
def get_scale_factor(height: int) -> float:
"""Get appropriate scale factor based on image height.
Args:
height: Image height in pixels.
Returns:
Recommended scale factor.
"""
if height >= 2160:
return 0.25
elif height >= 1080:
return 0.5
else:
return 1.0
@torch.no_grad()
def interpolate_single(
model, img0: torch.Tensor, img1: torch.Tensor, t: float
) -> torch.Tensor:
"""Perform single frame interpolation using BiM-VFI.
Args:
model: BiM-VFI model instance.
img0: First frame tensor (1, 3, H, W) normalized to [0, 1].
img1: Second frame tensor (1, 3, H, W) normalized to [0, 1].
t: Interpolation timestep (0.0 to 1.0).
Returns:
Interpolated frame tensor.
"""
h = img0.shape[2]
pyr_level = get_pyr_level(h)
scale_factor = get_scale_factor(h)
time_step = torch.tensor([t]).view(1, 1, 1, 1).to(img0.device)
# Distance weights for better quality
dis0 = torch.ones((1, 1, h, img0.shape[3]), device=img0.device) * t
dis1 = 1 - dis0
results_dict = model(
img0=img0, img1=img1,
time_step=time_step,
dis0=dis0, dis1=dis1,
scale_factor=scale_factor,
ratio=(1.0 / scale_factor),
pyr_level=pyr_level,
nr_lvl_skipped=0
)
return results_dict['imgt_pred'].clamp(0, 1)
@torch.no_grad()
def interpolate_batch(
model, img0: torch.Tensor, img1: torch.Tensor, frame_count: int
) -> list[torch.Tensor]:
"""Generate multiple interpolated frames using BiM-VFI.
Args:
model: BiM-VFI model instance.
img0: First frame tensor (1, 3, H, W) normalized to [0, 1].
img1: Second frame tensor (1, 3, H, W) normalized to [0, 1].
frame_count: Number of frames to generate between img0 and img1.
Returns:
List of interpolated frame tensors in order.
"""
frames = []
for i in range(frame_count):
t = (i + 1) / (frame_count + 1)
frame = interpolate_single(model, img0, img1, t)
frames.append(frame)
return frames
def main():
parser = argparse.ArgumentParser(description='BiM-VFI frame interpolation worker')
parser.add_argument('--input0', required=True, help='Path to first input image')
parser.add_argument('--input1', required=True, help='Path to second input image')
parser.add_argument('--output', help='Path to output image (single frame mode)')
parser.add_argument('--output-dir', help='Output directory (batch mode)')
parser.add_argument('--output-pattern', default='frame_{:04d}.png',
help='Output filename pattern for batch mode')
parser.add_argument('--timestep', type=float, default=0.5,
help='Interpolation timestep 0-1 (single frame mode)')
parser.add_argument('--frame-count', type=int,
help='Number of frames to generate (batch mode)')
parser.add_argument('--repo-dir', required=True, help='Path to BiM-VFI repo clone')
parser.add_argument('--model-dir', required=True, help='Model cache directory')
parser.add_argument('--device', default='cuda', choices=['cuda', 'cpu'], help='Device to use')
args = parser.parse_args()
# Validate arguments
batch_mode = args.output_dir is not None and args.frame_count is not None
single_mode = args.output is not None
if not batch_mode and not single_mode:
print("Error: Must specify either --output (single) or --output-dir + --frame-count (batch)",
file=sys.stderr)
return 1
try:
# Select device
if args.device == 'cuda' and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
# Load model
repo_dir = Path(args.repo_dir)
model_dir = Path(args.model_dir)
model = get_model(repo_dir, model_dir, device)
# Load images
img0 = load_image(Path(args.input0), device)
img1 = load_image(Path(args.input1), device)
if batch_mode:
# Batch mode - generate all frames at once
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Generating {args.frame_count} frames...", file=sys.stderr)
frames = interpolate_batch(model, img0, img1, args.frame_count)
for i, frame in enumerate(frames):
output_path = output_dir / args.output_pattern.format(i)
save_image(frame, output_path)
print(f"Saved {output_path.name}", file=sys.stderr)
print(f"Success: Generated {len(frames)} frames", file=sys.stderr)
else:
# Single frame mode
result = interpolate_single(model, img0, img1, args.timestep)
save_image(result, Path(args.output))
print("Success", file=sys.stderr)
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
import traceback
traceback.print_exc(file=sys.stderr)
return 1
if __name__ == '__main__':
sys.exit(main())
+459 -1
View File
@@ -477,6 +477,302 @@ class FilmEnv:
return False, str(e), []
class BimVfiEnv:
"""Manages BiM-VFI frame interpolation using shared venv with RIFE."""
VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE
REPO_DIR = CACHE_DIR / 'BiM-VFI'
MODEL_CACHE_DIR = CACHE_DIR / 'bim-vfi'
CHECKPOINT_FILENAME = 'bim_vfi.pth'
REPO_URL = 'https://github.com/KAIST-VICLab/BiM-VFI.git'
# Google Drive file ID for the checkpoint
GDRIVE_FILE_ID = '18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC'
# Extra pip packages needed beyond the base torch venv
EXTRA_PACKAGES = [
'basicsr-fixed', 'imageio', 'pyyaml', 'opencv-python',
'lpips', 'ptflops',
]
@classmethod
def get_venv_python(cls) -> Optional[Path]:
"""Get path to venv Python executable."""
if cls.VENV_DIR.exists():
if sys.platform == 'win32':
return cls.VENV_DIR / 'Scripts' / 'python.exe'
return cls.VENV_DIR / 'bin' / 'python'
return None
@classmethod
def get_checkpoint_path(cls) -> Path:
"""Get path to the BiM-VFI checkpoint."""
return cls.MODEL_CACHE_DIR / cls.CHECKPOINT_FILENAME
@classmethod
def is_setup(cls) -> bool:
"""Check if venv exists, repo is cloned, and checkpoint is present."""
python = cls.get_venv_python()
if not python or not python.exists():
return False
if not cls.REPO_DIR.exists():
return False
return cls.get_checkpoint_path().exists()
@classmethod
def setup_bim_vfi(cls, progress_callback=None, cancelled_check=None) -> bool:
"""Clone repo, install deps, and download checkpoint.
Args:
progress_callback: Optional callback(message, percent) for progress.
cancelled_check: Optional callable that returns True if cancelled.
Returns:
True if setup was successful.
"""
python = cls.get_venv_python()
if not python or not python.exists():
return False
try:
# Step 1: Clone repo if needed
if not cls.REPO_DIR.exists():
if progress_callback:
progress_callback("Cloning BiM-VFI repository...", 10)
if cancelled_check and cancelled_check():
return False
result = subprocess.run(
['git', 'clone', '--depth', '1', cls.REPO_URL, str(cls.REPO_DIR)],
capture_output=True, text=True, timeout=300
)
if result.returncode != 0:
print(f"[BiM-VFI] git clone failed: {result.stderr}", file=sys.stderr)
return False
# Step 2: Install extra packages
if progress_callback:
progress_callback("Installing BiM-VFI dependencies...", 30)
if cancelled_check and cancelled_check():
return False
pip = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'pip'
result = subprocess.run(
[str(pip), 'install', '--quiet'] + cls.EXTRA_PACKAGES,
capture_output=True, text=True, timeout=600
)
if result.returncode != 0:
print(f"[BiM-VFI] pip install failed: {result.stderr}", file=sys.stderr)
return False
# Step 2b: Install cupy (needs CUDA-specific wheel)
if progress_callback:
progress_callback("Installing cupy (CUDA support)...", 50)
if cancelled_check and cancelled_check():
return False
# Try cupy-cuda12x first (CUDA 12), fall back to cupy-cuda11x
cupy_installed = False
for cupy_pkg in ['cupy-cuda12x', 'cupy-cuda11x']:
result = subprocess.run(
[str(pip), 'install', '--quiet', cupy_pkg],
capture_output=True, text=True, timeout=600
)
if result.returncode == 0:
cupy_installed = True
break
if not cupy_installed:
print("[BiM-VFI] Warning: cupy install failed, trying generic cupy", file=sys.stderr)
subprocess.run(
[str(pip), 'install', '--quiet', 'cupy'],
capture_output=True, text=True, timeout=600
)
# Step 3: Download checkpoint
checkpoint_path = cls.get_checkpoint_path()
if not checkpoint_path.exists():
if progress_callback:
progress_callback("Downloading BiM-VFI checkpoint (~300MB)...", 60)
if cancelled_check and cancelled_check():
return False
cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Use gdown to download from Google Drive
result = subprocess.run(
[str(pip), 'install', '--quiet', 'gdown'],
capture_output=True, text=True, timeout=120
)
gdown_bin = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'gdown'
tmp_path = checkpoint_path.with_suffix('.tmp')
result = subprocess.run(
[str(gdown_bin), '--id', cls.GDRIVE_FILE_ID,
'--output', str(tmp_path)],
capture_output=True, text=True, timeout=600
)
if result.returncode == 0 and tmp_path.exists():
tmp_path.rename(checkpoint_path)
else:
tmp_path.unlink(missing_ok=True)
error = result.stderr.strip() if result.stderr else "unknown error"
print(f"[BiM-VFI] Download failed: {error}", file=sys.stderr)
print("[BiM-VFI] Please download manually from Google Drive and place at "
f"{checkpoint_path}", file=sys.stderr)
return False
if progress_callback:
progress_callback("BiM-VFI setup complete!", 100)
return cls.is_setup()
except Exception as e:
print(f"[BiM-VFI] Setup error: {e}", file=sys.stderr)
return False
@classmethod
def get_worker_script(cls) -> Path:
"""Get path to the BiM-VFI worker script."""
return Path(__file__).parent / 'bim_vfi_worker.py'
@classmethod
def run_interpolation(
cls,
img_a_path: Path,
img_b_path: Path,
output_path: Path,
t: float
) -> tuple[bool, str]:
"""Run BiM-VFI interpolation via subprocess in venv.
Args:
img_a_path: Path to first input image.
img_b_path: Path to second input image.
output_path: Path to output image.
t: Timestep for interpolation (0.0 to 1.0).
Returns:
Tuple of (success, error_message).
"""
python = cls.get_venv_python()
if not python or not python.exists():
return False, "venv python not found"
script = cls.get_worker_script()
if not script.exists():
return False, f"worker script not found: {script}"
cmd = [
str(python), str(script),
'--input0', str(img_a_path),
'--input1', str(img_b_path),
'--output', str(output_path),
'--timestep', str(t),
'--repo-dir', str(cls.REPO_DIR),
'--model-dir', str(cls.MODEL_CACHE_DIR)
]
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5 minute timeout per frame (BiM-VFI can be slow)
)
if result.returncode == 0 and output_path.exists():
return True, ""
else:
error = result.stderr.strip() if result.stderr else f"returncode={result.returncode}"
return False, error
except subprocess.TimeoutExpired:
return False, "timeout (300s)"
except Exception as e:
return False, str(e)
@classmethod
def run_batch_interpolation(
cls,
img_a_path: Path,
img_b_path: Path,
output_dir: Path,
frame_count: int,
output_pattern: str = 'frame_{:04d}.png'
) -> tuple[bool, str, list[Path]]:
"""Run BiM-VFI batch interpolation via subprocess in venv.
Args:
img_a_path: Path to first input image.
img_b_path: Path to second input image.
output_dir: Directory to save output frames.
frame_count: Number of frames to generate.
output_pattern: Filename pattern for output frames.
Returns:
Tuple of (success, error_message, list_of_output_paths).
"""
python = cls.get_venv_python()
if not python or not python.exists():
return False, "venv python not found", []
script = cls.get_worker_script()
if not script.exists():
return False, f"worker script not found: {script}", []
output_dir.mkdir(parents=True, exist_ok=True)
cmd = [
str(python), str(script),
'--input0', str(img_a_path),
'--input1', str(img_b_path),
'--output-dir', str(output_dir),
'--frame-count', str(frame_count),
'--output-pattern', output_pattern,
'--repo-dir', str(cls.REPO_DIR),
'--model-dir', str(cls.MODEL_CACHE_DIR)
]
try:
timeout = max(300, frame_count * 45) # At least 5 min, +45s per frame
print(f"[BiM-VFI] Running batch interpolation: {frame_count} frames", file=sys.stderr)
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout
)
output_paths = [
output_dir / output_pattern.format(i)
for i in range(frame_count)
]
existing_paths = [p for p in output_paths if p.exists()]
if result.returncode == 0 and len(existing_paths) == frame_count:
print(f"[BiM-VFI] Success: generated {len(existing_paths)} frames", file=sys.stderr)
return True, "", output_paths
else:
error_parts = []
if result.returncode != 0:
error_parts.append(f"returncode={result.returncode}")
if result.stderr and result.stderr.strip():
error_parts.append(f"stderr: {result.stderr.strip()}")
if len(existing_paths) != frame_count:
error_parts.append(f"expected {frame_count} frames, got {len(existing_paths)}")
error = "; ".join(error_parts) if error_parts else "unknown error"
print(f"[BiM-VFI] Failed: {error}", file=sys.stderr)
return False, error, existing_paths
except subprocess.TimeoutExpired:
print(f"[BiM-VFI] Timeout after {timeout}s", file=sys.stderr)
return False, f"timeout ({timeout}s)", []
except Exception as e:
print(f"[BiM-VFI] Exception: {e}", file=sys.stderr)
return False, str(e), []
class RifeDownloader:
"""Handles automatic download and caching of rife-ncnn-vulkan binary."""
@@ -1100,6 +1396,53 @@ class ImageBlender:
# Fall back to Practical-RIFE
return ImageBlender.practical_rife_blend(img_a, img_b, t)
@staticmethod
def bim_vfi_blend(
img_a: Image.Image,
img_b: Image.Image,
t: float
) -> Image.Image:
"""Blend using BiM-VFI for high-quality interpolation.
BiM-VFI (Bidirectional Motion Field-Guided Frame Interpolation)
handles non-uniform motions well (CVPR 2025).
Args:
img_a: First PIL Image (source frame).
img_b: Second PIL Image (target frame).
t: Interpolation factor 0.0 (100% A) to 1.0 (100% B).
Returns:
AI-interpolated blended PIL Image.
"""
if not BimVfiEnv.is_setup():
print("[BiM-VFI] Not set up, falling back to FILM", file=sys.stderr)
return ImageBlender.film_blend(img_a, img_b, t)
try:
with tempfile.TemporaryDirectory() as tmpdir:
tmp = Path(tmpdir)
input_a = tmp / 'a.png'
input_b = tmp / 'b.png'
output_file = tmp / 'out.png'
img_a.convert('RGB').save(input_a)
img_b.convert('RGB').save(input_b)
success, error_msg = BimVfiEnv.run_interpolation(
input_a, input_b, output_file, t
)
if success and output_file.exists():
return Image.open(output_file).copy()
else:
print(f"[BiM-VFI] Interpolation failed: {error_msg}, falling back to FILM", file=sys.stderr)
except Exception as e:
print(f"[BiM-VFI] Exception: {e}, falling back to FILM", file=sys.stderr)
return ImageBlender.film_blend(img_a, img_b, t)
@staticmethod
def blend_images(
img_a_path: Path,
@@ -1652,7 +1995,13 @@ class TransitionGenerator:
img_a_path, img_b_path, frame_count, dest, base_seq_num
)
# For RIFE (or FILM fallback), generate frames one at a time
# For BiM-VFI, use batch mode
if method == DirectInterpolationMethod.BIM_VFI and BimVfiEnv.is_setup():
return self._generate_bim_vfi_frames_batch(
img_a_path, img_b_path, frame_count, dest, base_seq_num
)
# For RIFE (or fallback), generate frames one at a time
# Load source images
img_a = Image.open(img_a_path)
img_b = Image.open(img_b_path)
@@ -1674,6 +2023,8 @@ class TransitionGenerator:
# Generate interpolated frame
if method == DirectInterpolationMethod.FILM:
blended = ImageBlender.film_blend(img_a, img_b, t)
elif method == DirectInterpolationMethod.BIM_VFI:
blended = ImageBlender.bim_vfi_blend(img_a, img_b, t)
else: # RIFE
blended = ImageBlender.practical_rife_blend(
img_a, img_b, t,
@@ -1844,3 +2195,110 @@ class TransitionGenerator:
))
return results
def _generate_bim_vfi_frames_batch(
self,
img_a_path: Path,
img_b_path: Path,
frame_count: int,
dest: Path,
base_seq_num: int
) -> list[BlendResult]:
"""Generate BiM-VFI frames using batch mode.
Args:
img_a_path: Path to last frame of first sequence.
img_b_path: Path to first frame of second sequence.
frame_count: Number of interpolated frames to generate.
dest: Destination directory for generated frames.
base_seq_num: Starting sequence number for continuous naming.
Returns:
List of BlendResult objects.
"""
results = []
temp_pattern = 'bimvfi_temp_{:04d}.png'
success, error, temp_paths = BimVfiEnv.run_batch_interpolation(
img_a_path,
img_b_path,
dest,
frame_count,
temp_pattern
)
if not success:
for i in range(frame_count):
t = (i + 1) / (frame_count + 1)
ext = f".{self.settings.output_format.lower()}"
seq_num = base_seq_num + i
output_name = f"seq_{seq_num:05d}{ext}"
output_path = dest / output_name
results.append(BlendResult(
output_path=output_path,
source_a=img_a_path,
source_b=img_b_path,
blend_factor=t,
success=False,
error=error
))
return results
for i, temp_path in enumerate(temp_paths):
t = (i + 1) / (frame_count + 1)
ext = f".{self.settings.output_format.lower()}"
seq_num = base_seq_num + i
output_name = f"seq_{seq_num:05d}{ext}"
output_path = dest / output_name
try:
if temp_path.exists():
frame = Image.open(temp_path)
if self.settings.output_format.lower() in ('jpg', 'jpeg'):
frame = frame.convert('RGB')
save_kwargs = {}
if self.settings.output_format.lower() in ('jpg', 'jpeg'):
save_kwargs['quality'] = self.settings.output_quality
elif self.settings.output_format.lower() == 'webp':
save_kwargs['lossless'] = True
save_kwargs['method'] = self.settings.webp_method
elif self.settings.output_format.lower() == 'png':
save_kwargs['compress_level'] = 6
frame.save(output_path, **save_kwargs)
frame.close()
if temp_path != output_path:
temp_path.unlink(missing_ok=True)
results.append(BlendResult(
output_path=output_path,
source_a=img_a_path,
source_b=img_b_path,
blend_factor=t,
success=True
))
else:
results.append(BlendResult(
output_path=output_path,
source_a=img_a_path,
source_b=img_b_path,
blend_factor=t,
success=False,
error=f"Temp file not found: {temp_path}"
))
except Exception as e:
results.append(BlendResult(
output_path=output_path,
source_a=img_a_path,
source_b=img_b_path,
blend_factor=t,
success=False,
error=str(e)
))
return results
+1
View File
@@ -36,6 +36,7 @@ class DirectInterpolationMethod(Enum):
"""Method for direct frame interpolation between sequences."""
RIFE = 'rife'
FILM = 'film'
BIM_VFI = 'bim_vfi'
# --- Data Classes ---
+59 -5
View File
@@ -67,6 +67,7 @@ from core import (
find_ffmpeg,
PracticalRifeEnv,
FilmEnv,
BimVfiEnv,
SymlinkManager,
OPTICAL_FLOW_PRESETS,
)
@@ -236,8 +237,12 @@ class DirectTransitionDialog(QDialog):
self.method_combo = QComboBox()
self.method_combo.addItem("RIFE (Fast, small motion)", DirectInterpolationMethod.RIFE)
self.method_combo.addItem("FILM (Slow, large motion)", DirectInterpolationMethod.FILM)
if method == DirectInterpolationMethod.FILM:
self.method_combo.setCurrentIndex(1)
self.method_combo.addItem("BiM-VFI (Best quality, slowest)", DirectInterpolationMethod.BIM_VFI)
# Set current method
for i in range(self.method_combo.count()):
if self.method_combo.itemData(i) == method:
self.method_combo.setCurrentIndex(i)
break
form_layout.addRow("Method:", self.method_combo)
# Frame count
@@ -268,7 +273,8 @@ class DirectTransitionDialog(QDialog):
# Explanation
explain = QLabel(
"RIFE: Fast AI interpolation, best for small motion and color shifts.\n"
"FILM: Google Research model, better for large motion and scene gaps.\n\n"
"FILM: Google Research model, better for large motion and scene gaps.\n"
"BiM-VFI: CVPR 2025, best quality for non-uniform/complex motions.\n\n"
"Generated frames bridge the gap between the last frame of this\n"
"sequence and the first frame of the next MAIN sequence."
)
@@ -305,6 +311,7 @@ class DirectTransitionDialog(QDialog):
rife_ready = PracticalRifeEnv.is_setup()
film_ready = FilmEnv.is_setup() if rife_ready else False
bim_ready = BimVfiEnv.is_setup() if rife_ready else False
if method == DirectInterpolationMethod.RIFE:
if rife_ready:
@@ -316,13 +323,13 @@ class DirectTransitionDialog(QDialog):
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
self.setup_btn.setVisible(True)
self.setup_btn.setText("Setup PyTorch Environment")
else: # FILM
elif method == DirectInterpolationMethod.FILM:
if film_ready:
self.status_label.setText("FILM: Ready")
self.status_label.setStyleSheet("color: green; font-size: 10px;")
self.setup_btn.setVisible(False)
elif rife_ready:
self.status_label.setText("FILM: Package not installed")
self.status_label.setText("FILM: Model not downloaded")
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
self.setup_btn.setVisible(True)
self.setup_btn.setText("Install FILM Package")
@@ -331,6 +338,21 @@ class DirectTransitionDialog(QDialog):
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
self.setup_btn.setVisible(True)
self.setup_btn.setText("Setup PyTorch Environment")
else: # BIM_VFI
if bim_ready:
self.status_label.setText("BiM-VFI: Ready")
self.status_label.setStyleSheet("color: green; font-size: 10px;")
self.setup_btn.setVisible(False)
elif rife_ready:
self.status_label.setText("BiM-VFI: Not installed (repo + model needed)")
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
self.setup_btn.setVisible(True)
self.setup_btn.setText("Install BiM-VFI")
else:
self.status_label.setText("BiM-VFI: Not installed (PyTorch required first)")
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
self.setup_btn.setVisible(True)
self.setup_btn.setText("Setup PyTorch Environment")
def _on_setup(self) -> None:
"""Handle setup button click."""
@@ -397,6 +419,38 @@ class DirectTransitionDialog(QDialog):
)
return
# If BiM-VFI selected and we need to install it
if method == DirectInterpolationMethod.BIM_VFI and not BimVfiEnv.is_setup():
progress = QProgressDialog(
"Setting up BiM-VFI...", "Cancel", 0, 100, self
)
progress.setWindowTitle("Setup")
progress.setWindowModality(Qt.WindowModality.WindowModal)
progress.setMinimumDuration(0)
progress.setValue(0)
def progress_cb(msg, pct):
progress.setLabelText(msg)
progress.setValue(pct)
def cancelled_check():
QApplication.processEvents()
return progress.wasCanceled()
success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check)
progress.close()
if not success:
if not cancelled_check():
QMessageBox.warning(
self, "Setup Failed",
"Failed to set up BiM-VFI.\n\n"
"If the checkpoint download failed, you can download it "
"manually from Google Drive and place it at:\n"
f"{BimVfiEnv.get_checkpoint_path()}"
)
return
self._update_status()
def _on_remove(self) -> None: