Fix BiM-VFI integration bugs from code review
- Fix critical import: use `from modules.components import make_components` instead of `from modules.components.components` so the @register decorator fires and the model can be instantiated - Use separate venv (venv-bimvfi) to avoid dependency conflicts with RIFE/FILM — basicsr-fixed/cupy/lpips can break the shared venv - BimVfiEnv now creates its own venv with PyTorch, independent of RIFE - Remove dead code: dis0/dis1/scale_factor/ratio/nr_lvl_skipped kwargs are silently ignored by model's **kwargs, and get_scale_factor/ get_pyr_level were computing unused values - Use model's trained pyr_level=3 (set at construction) instead of overriding to 5-7 at inference which was untested - Use `python -m pip` and `python -m gdown` instead of direct binaries for cross-platform reliability - Add checkpoint size validation (>1MB) on download - Verify torch importable in is_setup() instead of just checking files - Update UI: BiM-VFI setup flow is independent, doesn't need RIFE first Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+6
-51
@@ -55,7 +55,6 @@ def save_image(tensor: torch.Tensor, path: Path) -> None:
|
||||
Image.fromarray(arr).save(path)
|
||||
|
||||
|
||||
# Global model cache
|
||||
_model_cache: dict = {}
|
||||
|
||||
|
||||
@@ -87,10 +86,10 @@ def get_model(repo_dir: Path, model_dir: Path, device: torch.device):
|
||||
|
||||
print(f"Loading BiM-VFI model from {checkpoint_path}...", file=sys.stderr)
|
||||
|
||||
# Import BiM-VFI's component registry
|
||||
from modules.components.components import make_components
|
||||
# Import from the package __init__ so the @register decorator fires
|
||||
from modules.components import make_components
|
||||
|
||||
# Create model with default config
|
||||
# Create model with the trained config (pyr_level=3, feat_channels=32)
|
||||
cfg = {'name': 'bim_vfi', 'args': {'pyr_level': 3, 'feat_channels': 32}}
|
||||
model = make_components(cfg)
|
||||
|
||||
@@ -106,46 +105,15 @@ def get_model(repo_dir: Path, model_dir: Path, device: torch.device):
|
||||
return _model_cache[cache_key]
|
||||
|
||||
|
||||
def get_pyr_level(height: int) -> int:
|
||||
"""Get appropriate pyramid level based on image height.
|
||||
|
||||
Args:
|
||||
height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
Recommended pyramid level.
|
||||
"""
|
||||
if height >= 2160:
|
||||
return 7
|
||||
elif height >= 1080:
|
||||
return 6
|
||||
else:
|
||||
return 5
|
||||
|
||||
|
||||
def get_scale_factor(height: int) -> float:
|
||||
"""Get appropriate scale factor based on image height.
|
||||
|
||||
Args:
|
||||
height: Image height in pixels.
|
||||
|
||||
Returns:
|
||||
Recommended scale factor.
|
||||
"""
|
||||
if height >= 2160:
|
||||
return 0.25
|
||||
elif height >= 1080:
|
||||
return 0.5
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_single(
|
||||
model, img0: torch.Tensor, img1: torch.Tensor, t: float
|
||||
) -> torch.Tensor:
|
||||
"""Perform single frame interpolation using BiM-VFI.
|
||||
|
||||
Uses the model's trained pyr_level (set at construction time).
|
||||
The model handles input padding internally via InputPadder.
|
||||
|
||||
Args:
|
||||
model: BiM-VFI model instance.
|
||||
img0: First frame tensor (1, 3, H, W) normalized to [0, 1].
|
||||
@@ -155,24 +123,11 @@ def interpolate_single(
|
||||
Returns:
|
||||
Interpolated frame tensor.
|
||||
"""
|
||||
h = img0.shape[2]
|
||||
pyr_level = get_pyr_level(h)
|
||||
scale_factor = get_scale_factor(h)
|
||||
|
||||
time_step = torch.tensor([t]).view(1, 1, 1, 1).to(img0.device)
|
||||
|
||||
# Distance weights for better quality
|
||||
dis0 = torch.ones((1, 1, h, img0.shape[3]), device=img0.device) * t
|
||||
dis1 = 1 - dis0
|
||||
|
||||
results_dict = model(
|
||||
img0=img0, img1=img1,
|
||||
time_step=time_step,
|
||||
dis0=dis0, dis1=dis1,
|
||||
scale_factor=scale_factor,
|
||||
ratio=(1.0 / scale_factor),
|
||||
pyr_level=pyr_level,
|
||||
nr_lvl_skipped=0
|
||||
)
|
||||
|
||||
return results_dict['imgt_pred'].clamp(0, 1)
|
||||
|
||||
+72
-30
@@ -478,9 +478,13 @@ class FilmEnv:
|
||||
|
||||
|
||||
class BimVfiEnv:
|
||||
"""Manages BiM-VFI frame interpolation using shared venv with RIFE."""
|
||||
"""Manages BiM-VFI frame interpolation in its own isolated venv.
|
||||
|
||||
VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE
|
||||
Uses a separate venv from RIFE/FILM to avoid dependency conflicts
|
||||
(basicsr-fixed, cupy, etc. can break the RIFE/FILM torch install).
|
||||
"""
|
||||
|
||||
VENV_DIR = CACHE_DIR / 'venv-bimvfi'
|
||||
REPO_DIR = CACHE_DIR / 'BiM-VFI'
|
||||
MODEL_CACHE_DIR = CACHE_DIR / 'bim-vfi'
|
||||
CHECKPOINT_FILENAME = 'bim_vfi.pth'
|
||||
@@ -488,10 +492,10 @@ class BimVfiEnv:
|
||||
# Google Drive file ID for the checkpoint
|
||||
GDRIVE_FILE_ID = '18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC'
|
||||
|
||||
# Extra pip packages needed beyond the base torch venv
|
||||
# All pip packages needed (torch/torchvision installed separately)
|
||||
EXTRA_PACKAGES = [
|
||||
'basicsr-fixed', 'imageio', 'pyyaml', 'opencv-python',
|
||||
'lpips', 'ptflops',
|
||||
'lpips', 'ptflops', 'gdown',
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -510,17 +514,62 @@ class BimVfiEnv:
|
||||
|
||||
@classmethod
|
||||
def is_setup(cls) -> bool:
|
||||
"""Check if venv exists, repo is cloned, and checkpoint is present."""
|
||||
"""Check if venv exists with deps, repo is cloned, and checkpoint is present."""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False
|
||||
if not cls.REPO_DIR.exists():
|
||||
return False
|
||||
return cls.get_checkpoint_path().exists()
|
||||
if not cls.get_checkpoint_path().exists():
|
||||
return False
|
||||
# Verify torch is importable in this venv
|
||||
result = subprocess.run(
|
||||
[str(python), '-c', 'import torch; print(torch.__version__)'],
|
||||
capture_output=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
@classmethod
|
||||
def _create_venv(cls, progress_callback=None, cancelled_check=None) -> bool:
|
||||
"""Create the BiM-VFI venv and install PyTorch + deps.
|
||||
|
||||
Returns:
|
||||
True if venv creation was successful.
|
||||
"""
|
||||
if progress_callback:
|
||||
progress_callback("Creating BiM-VFI Python environment...", 5)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
import venv
|
||||
cls.VENV_DIR.mkdir(parents=True, exist_ok=True)
|
||||
venv.create(str(cls.VENV_DIR), with_pip=True, clear=True)
|
||||
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
print("[BiM-VFI] Failed to create venv", file=sys.stderr)
|
||||
return False
|
||||
|
||||
# Install PyTorch (same approach as PracticalRifeEnv)
|
||||
if progress_callback:
|
||||
progress_callback("Installing PyTorch (this may take a while)...", 10)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
result = subprocess.run(
|
||||
[str(python), '-m', 'pip', 'install', '--quiet',
|
||||
'torch', 'torchvision', 'numpy'],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"[BiM-VFI] PyTorch install failed: {result.stderr}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def setup_bim_vfi(cls, progress_callback=None, cancelled_check=None) -> bool:
|
||||
"""Clone repo, install deps, and download checkpoint.
|
||||
"""Create venv, clone repo, install deps, and download checkpoint.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(message, percent) for progress.
|
||||
@@ -529,15 +578,18 @@ class BimVfiEnv:
|
||||
Returns:
|
||||
True if setup was successful.
|
||||
"""
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
# Step 0: Create venv if needed
|
||||
python = cls.get_venv_python()
|
||||
if not python or not python.exists():
|
||||
if not cls._create_venv(progress_callback, cancelled_check):
|
||||
return False
|
||||
python = cls.get_venv_python()
|
||||
|
||||
# Step 1: Clone repo if needed
|
||||
if not cls.REPO_DIR.exists():
|
||||
if progress_callback:
|
||||
progress_callback("Cloning BiM-VFI repository...", 10)
|
||||
progress_callback("Cloning BiM-VFI repository...", 20)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
@@ -551,13 +603,12 @@ class BimVfiEnv:
|
||||
|
||||
# Step 2: Install extra packages
|
||||
if progress_callback:
|
||||
progress_callback("Installing BiM-VFI dependencies...", 30)
|
||||
progress_callback("Installing BiM-VFI dependencies...", 35)
|
||||
if cancelled_check and cancelled_check():
|
||||
return False
|
||||
|
||||
pip = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'pip'
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet'] + cls.EXTRA_PACKAGES,
|
||||
[str(python), '-m', 'pip', 'install', '--quiet'] + cls.EXTRA_PACKAGES,
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode != 0:
|
||||
@@ -574,7 +625,7 @@ class BimVfiEnv:
|
||||
cupy_installed = False
|
||||
for cupy_pkg in ['cupy-cuda12x', 'cupy-cuda11x']:
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet', cupy_pkg],
|
||||
[str(python), '-m', 'pip', 'install', '--quiet', cupy_pkg],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode == 0:
|
||||
@@ -582,11 +633,8 @@ class BimVfiEnv:
|
||||
break
|
||||
|
||||
if not cupy_installed:
|
||||
print("[BiM-VFI] Warning: cupy install failed, trying generic cupy", file=sys.stderr)
|
||||
subprocess.run(
|
||||
[str(pip), 'install', '--quiet', 'cupy'],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
print("[BiM-VFI] Warning: cupy install failed, model may not work on GPU",
|
||||
file=sys.stderr)
|
||||
|
||||
# Step 3: Download checkpoint
|
||||
checkpoint_path = cls.get_checkpoint_path()
|
||||
@@ -598,20 +646,14 @@ class BimVfiEnv:
|
||||
|
||||
cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use gdown to download from Google Drive
|
||||
result = subprocess.run(
|
||||
[str(pip), 'install', '--quiet', 'gdown'],
|
||||
capture_output=True, text=True, timeout=120
|
||||
)
|
||||
|
||||
gdown_bin = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'gdown'
|
||||
# Use gdown via python -m to download from Google Drive
|
||||
tmp_path = checkpoint_path.with_suffix('.tmp')
|
||||
result = subprocess.run(
|
||||
[str(gdown_bin), '--id', cls.GDRIVE_FILE_ID,
|
||||
[str(python), '-m', 'gdown', '--id', cls.GDRIVE_FILE_ID,
|
||||
'--output', str(tmp_path)],
|
||||
capture_output=True, text=True, timeout=600
|
||||
)
|
||||
if result.returncode == 0 and tmp_path.exists():
|
||||
if result.returncode == 0 and tmp_path.exists() and tmp_path.stat().st_size > 1_000_000:
|
||||
tmp_path.rename(checkpoint_path)
|
||||
else:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
+43
-44
@@ -311,7 +311,7 @@ class DirectTransitionDialog(QDialog):
|
||||
|
||||
rife_ready = PracticalRifeEnv.is_setup()
|
||||
film_ready = FilmEnv.is_setup() if rife_ready else False
|
||||
bim_ready = BimVfiEnv.is_setup() if rife_ready else False
|
||||
bim_ready = BimVfiEnv.is_setup()
|
||||
|
||||
if method == DirectInterpolationMethod.RIFE:
|
||||
if rife_ready:
|
||||
@@ -338,29 +338,32 @@ class DirectTransitionDialog(QDialog):
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Setup PyTorch Environment")
|
||||
else: # BIM_VFI
|
||||
else: # BIM_VFI (own venv, independent of RIFE)
|
||||
if bim_ready:
|
||||
self.status_label.setText("BiM-VFI: Ready")
|
||||
self.status_label.setStyleSheet("color: green; font-size: 10px;")
|
||||
self.setup_btn.setVisible(False)
|
||||
elif rife_ready:
|
||||
self.status_label.setText("BiM-VFI: Not installed (repo + model needed)")
|
||||
else:
|
||||
self.status_label.setText("BiM-VFI: Not installed")
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Install BiM-VFI")
|
||||
else:
|
||||
self.status_label.setText("BiM-VFI: Not installed (PyTorch required first)")
|
||||
self.status_label.setStyleSheet("color: orange; font-size: 10px;")
|
||||
self.setup_btn.setVisible(True)
|
||||
self.setup_btn.setText("Setup PyTorch Environment")
|
||||
|
||||
def _on_setup(self) -> None:
|
||||
"""Handle setup button click."""
|
||||
method = self.method_combo.currentData()
|
||||
|
||||
# BiM-VFI has its own venv — handle independently
|
||||
if method == DirectInterpolationMethod.BIM_VFI:
|
||||
if not BimVfiEnv.is_setup():
|
||||
self._setup_bim_vfi()
|
||||
self._update_status()
|
||||
return
|
||||
|
||||
# RIFE and FILM share a venv — set up base PyTorch if needed
|
||||
rife_ready = PracticalRifeEnv.is_setup()
|
||||
|
||||
if not rife_ready:
|
||||
# Need to set up PyTorch venv first
|
||||
progress = QProgressDialog(
|
||||
"Setting up PyTorch environment...", "Cancel", 0, 100, self
|
||||
)
|
||||
@@ -369,8 +372,6 @@ class DirectTransitionDialog(QDialog):
|
||||
progress.setMinimumDuration(0)
|
||||
progress.setValue(0)
|
||||
|
||||
cancelled = [False]
|
||||
|
||||
def progress_cb(msg, pct):
|
||||
progress.setLabelText(msg)
|
||||
progress.setValue(pct)
|
||||
@@ -419,40 +420,38 @@ class DirectTransitionDialog(QDialog):
|
||||
)
|
||||
return
|
||||
|
||||
# If BiM-VFI selected and we need to install it
|
||||
if method == DirectInterpolationMethod.BIM_VFI and not BimVfiEnv.is_setup():
|
||||
progress = QProgressDialog(
|
||||
"Setting up BiM-VFI...", "Cancel", 0, 100, self
|
||||
)
|
||||
progress.setWindowTitle("Setup")
|
||||
progress.setWindowModality(Qt.WindowModality.WindowModal)
|
||||
progress.setMinimumDuration(0)
|
||||
progress.setValue(0)
|
||||
|
||||
def progress_cb(msg, pct):
|
||||
progress.setLabelText(msg)
|
||||
progress.setValue(pct)
|
||||
|
||||
def cancelled_check():
|
||||
QApplication.processEvents()
|
||||
return progress.wasCanceled()
|
||||
|
||||
success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check)
|
||||
progress.close()
|
||||
|
||||
if not success:
|
||||
if not cancelled_check():
|
||||
QMessageBox.warning(
|
||||
self, "Setup Failed",
|
||||
"Failed to set up BiM-VFI.\n\n"
|
||||
"If the checkpoint download failed, you can download it "
|
||||
"manually from Google Drive and place it at:\n"
|
||||
f"{BimVfiEnv.get_checkpoint_path()}"
|
||||
)
|
||||
return
|
||||
|
||||
self._update_status()
|
||||
|
||||
def _setup_bim_vfi(self) -> None:
|
||||
"""Run BiM-VFI setup with progress dialog."""
|
||||
progress = QProgressDialog(
|
||||
"Setting up BiM-VFI...", "Cancel", 0, 100, self
|
||||
)
|
||||
progress.setWindowTitle("Setup")
|
||||
progress.setWindowModality(Qt.WindowModality.WindowModal)
|
||||
progress.setMinimumDuration(0)
|
||||
progress.setValue(0)
|
||||
|
||||
def progress_cb(msg, pct):
|
||||
progress.setLabelText(msg)
|
||||
progress.setValue(pct)
|
||||
|
||||
def cancelled_check():
|
||||
QApplication.processEvents()
|
||||
return progress.wasCanceled()
|
||||
|
||||
success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check)
|
||||
progress.close()
|
||||
|
||||
if not success and not cancelled_check():
|
||||
QMessageBox.warning(
|
||||
self, "Setup Failed",
|
||||
"Failed to set up BiM-VFI.\n\n"
|
||||
"If the checkpoint download failed, you can download it "
|
||||
"manually from Google Drive and place it at:\n"
|
||||
f"{BimVfiEnv.get_checkpoint_path()}"
|
||||
)
|
||||
|
||||
def _on_remove(self) -> None:
|
||||
"""Handle remove button click."""
|
||||
self._removed = True
|
||||
|
||||
Reference in New Issue
Block a user