From 2db0cbb76aacb97051e13c36cd5ae0fcc1b270a0 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 12 Mar 2026 21:57:27 +0100 Subject: [PATCH] Fix BiM-VFI integration bugs from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix critical import: use `from modules.components import make_components` instead of `from modules.components.components` so the @register decorator fires and the model can be instantiated - Use separate venv (venv-bimvfi) to avoid dependency conflicts with RIFE/FILM — basicsr-fixed/cupy/lpips can break the shared venv - BimVfiEnv now creates its own venv with PyTorch, independent of RIFE - Remove dead code: dis0/dis1/scale_factor/ratio/nr_lvl_skipped kwargs are silently ignored by model's **kwargs, and get_scale_factor/ get_pyr_level were computing unused values - Use model's trained pyr_level=3 (set at construction) instead of overriding to 5-7 at inference which was untested - Use `python -m pip` and `python -m gdown` instead of direct binaries for cross-platform reliability - Add checkpoint size validation (>1MB) on download - Verify torch importable in is_setup() instead of just checking files - Update UI: BiM-VFI setup flow is independent, doesn't need RIFE first Co-Authored-By: Claude Opus 4.6 --- core/bim_vfi_worker.py | 57 +++-------------------- core/blender.py | 102 +++++++++++++++++++++++++++++------------ ui/main_window.py | 87 +++++++++++++++++------------------ 3 files changed, 121 insertions(+), 125 deletions(-) diff --git a/core/bim_vfi_worker.py b/core/bim_vfi_worker.py index bcdc998..f7d2034 100644 --- a/core/bim_vfi_worker.py +++ b/core/bim_vfi_worker.py @@ -55,7 +55,6 @@ def save_image(tensor: torch.Tensor, path: Path) -> None: Image.fromarray(arr).save(path) -# Global model cache _model_cache: dict = {} @@ -87,10 +86,10 @@ def get_model(repo_dir: Path, model_dir: Path, device: torch.device): print(f"Loading BiM-VFI model from {checkpoint_path}...", file=sys.stderr) - # Import BiM-VFI's component registry - from modules.components.components import make_components + # Import from the package __init__ so the @register decorator fires + from modules.components import make_components - # Create model with default config + # Create model with the trained config (pyr_level=3, feat_channels=32) cfg = {'name': 'bim_vfi', 'args': {'pyr_level': 3, 'feat_channels': 32}} model = make_components(cfg) @@ -106,46 +105,15 @@ def get_model(repo_dir: Path, model_dir: Path, device: torch.device): return _model_cache[cache_key] -def get_pyr_level(height: int) -> int: - """Get appropriate pyramid level based on image height. - - Args: - height: Image height in pixels. - - Returns: - Recommended pyramid level. - """ - if height >= 2160: - return 7 - elif height >= 1080: - return 6 - else: - return 5 - - -def get_scale_factor(height: int) -> float: - """Get appropriate scale factor based on image height. - - Args: - height: Image height in pixels. - - Returns: - Recommended scale factor. - """ - if height >= 2160: - return 0.25 - elif height >= 1080: - return 0.5 - else: - return 1.0 - - @torch.no_grad() def interpolate_single( model, img0: torch.Tensor, img1: torch.Tensor, t: float ) -> torch.Tensor: """Perform single frame interpolation using BiM-VFI. + Uses the model's trained pyr_level (set at construction time). + The model handles input padding internally via InputPadder. + Args: model: BiM-VFI model instance. img0: First frame tensor (1, 3, H, W) normalized to [0, 1]. @@ -155,24 +123,11 @@ def interpolate_single( Returns: Interpolated frame tensor. """ - h = img0.shape[2] - pyr_level = get_pyr_level(h) - scale_factor = get_scale_factor(h) - time_step = torch.tensor([t]).view(1, 1, 1, 1).to(img0.device) - # Distance weights for better quality - dis0 = torch.ones((1, 1, h, img0.shape[3]), device=img0.device) * t - dis1 = 1 - dis0 - results_dict = model( img0=img0, img1=img1, time_step=time_step, - dis0=dis0, dis1=dis1, - scale_factor=scale_factor, - ratio=(1.0 / scale_factor), - pyr_level=pyr_level, - nr_lvl_skipped=0 ) return results_dict['imgt_pred'].clamp(0, 1) diff --git a/core/blender.py b/core/blender.py index df91ec3..8cc2af8 100644 --- a/core/blender.py +++ b/core/blender.py @@ -478,9 +478,13 @@ class FilmEnv: class BimVfiEnv: - """Manages BiM-VFI frame interpolation using shared venv with RIFE.""" + """Manages BiM-VFI frame interpolation in its own isolated venv. - VENV_DIR = PRACTICAL_RIFE_VENV_DIR # Share venv with RIFE + Uses a separate venv from RIFE/FILM to avoid dependency conflicts + (basicsr-fixed, cupy, etc. can break the RIFE/FILM torch install). + """ + + VENV_DIR = CACHE_DIR / 'venv-bimvfi' REPO_DIR = CACHE_DIR / 'BiM-VFI' MODEL_CACHE_DIR = CACHE_DIR / 'bim-vfi' CHECKPOINT_FILENAME = 'bim_vfi.pth' @@ -488,10 +492,10 @@ class BimVfiEnv: # Google Drive file ID for the checkpoint GDRIVE_FILE_ID = '18Wre7XyRtu_wtFRzcsit6oNfHiFRt9vC' - # Extra pip packages needed beyond the base torch venv + # All pip packages needed (torch/torchvision installed separately) EXTRA_PACKAGES = [ 'basicsr-fixed', 'imageio', 'pyyaml', 'opencv-python', - 'lpips', 'ptflops', + 'lpips', 'ptflops', 'gdown', ] @classmethod @@ -510,17 +514,62 @@ class BimVfiEnv: @classmethod def is_setup(cls) -> bool: - """Check if venv exists, repo is cloned, and checkpoint is present.""" + """Check if venv exists with deps, repo is cloned, and checkpoint is present.""" python = cls.get_venv_python() if not python or not python.exists(): return False if not cls.REPO_DIR.exists(): return False - return cls.get_checkpoint_path().exists() + if not cls.get_checkpoint_path().exists(): + return False + # Verify torch is importable in this venv + result = subprocess.run( + [str(python), '-c', 'import torch; print(torch.__version__)'], + capture_output=True + ) + return result.returncode == 0 + + @classmethod + def _create_venv(cls, progress_callback=None, cancelled_check=None) -> bool: + """Create the BiM-VFI venv and install PyTorch + deps. + + Returns: + True if venv creation was successful. + """ + if progress_callback: + progress_callback("Creating BiM-VFI Python environment...", 5) + if cancelled_check and cancelled_check(): + return False + + import venv + cls.VENV_DIR.mkdir(parents=True, exist_ok=True) + venv.create(str(cls.VENV_DIR), with_pip=True, clear=True) + + python = cls.get_venv_python() + if not python or not python.exists(): + print("[BiM-VFI] Failed to create venv", file=sys.stderr) + return False + + # Install PyTorch (same approach as PracticalRifeEnv) + if progress_callback: + progress_callback("Installing PyTorch (this may take a while)...", 10) + if cancelled_check and cancelled_check(): + return False + + result = subprocess.run( + [str(python), '-m', 'pip', 'install', '--quiet', + 'torch', 'torchvision', 'numpy'], + capture_output=True, text=True, timeout=600 + ) + if result.returncode != 0: + print(f"[BiM-VFI] PyTorch install failed: {result.stderr}", file=sys.stderr) + return False + + return True @classmethod def setup_bim_vfi(cls, progress_callback=None, cancelled_check=None) -> bool: - """Clone repo, install deps, and download checkpoint. + """Create venv, clone repo, install deps, and download checkpoint. Args: progress_callback: Optional callback(message, percent) for progress. @@ -529,15 +578,18 @@ class BimVfiEnv: Returns: True if setup was successful. """ - python = cls.get_venv_python() - if not python or not python.exists(): - return False - try: + # Step 0: Create venv if needed + python = cls.get_venv_python() + if not python or not python.exists(): + if not cls._create_venv(progress_callback, cancelled_check): + return False + python = cls.get_venv_python() + # Step 1: Clone repo if needed if not cls.REPO_DIR.exists(): if progress_callback: - progress_callback("Cloning BiM-VFI repository...", 10) + progress_callback("Cloning BiM-VFI repository...", 20) if cancelled_check and cancelled_check(): return False @@ -551,13 +603,12 @@ class BimVfiEnv: # Step 2: Install extra packages if progress_callback: - progress_callback("Installing BiM-VFI dependencies...", 30) + progress_callback("Installing BiM-VFI dependencies...", 35) if cancelled_check and cancelled_check(): return False - pip = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'pip' result = subprocess.run( - [str(pip), 'install', '--quiet'] + cls.EXTRA_PACKAGES, + [str(python), '-m', 'pip', 'install', '--quiet'] + cls.EXTRA_PACKAGES, capture_output=True, text=True, timeout=600 ) if result.returncode != 0: @@ -574,7 +625,7 @@ class BimVfiEnv: cupy_installed = False for cupy_pkg in ['cupy-cuda12x', 'cupy-cuda11x']: result = subprocess.run( - [str(pip), 'install', '--quiet', cupy_pkg], + [str(python), '-m', 'pip', 'install', '--quiet', cupy_pkg], capture_output=True, text=True, timeout=600 ) if result.returncode == 0: @@ -582,11 +633,8 @@ class BimVfiEnv: break if not cupy_installed: - print("[BiM-VFI] Warning: cupy install failed, trying generic cupy", file=sys.stderr) - subprocess.run( - [str(pip), 'install', '--quiet', 'cupy'], - capture_output=True, text=True, timeout=600 - ) + print("[BiM-VFI] Warning: cupy install failed, model may not work on GPU", + file=sys.stderr) # Step 3: Download checkpoint checkpoint_path = cls.get_checkpoint_path() @@ -598,20 +646,14 @@ class BimVfiEnv: cls.MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True) - # Use gdown to download from Google Drive - result = subprocess.run( - [str(pip), 'install', '--quiet', 'gdown'], - capture_output=True, text=True, timeout=120 - ) - - gdown_bin = cls.VENV_DIR / ('Scripts' if sys.platform == 'win32' else 'bin') / 'gdown' + # Use gdown via python -m to download from Google Drive tmp_path = checkpoint_path.with_suffix('.tmp') result = subprocess.run( - [str(gdown_bin), '--id', cls.GDRIVE_FILE_ID, + [str(python), '-m', 'gdown', '--id', cls.GDRIVE_FILE_ID, '--output', str(tmp_path)], capture_output=True, text=True, timeout=600 ) - if result.returncode == 0 and tmp_path.exists(): + if result.returncode == 0 and tmp_path.exists() and tmp_path.stat().st_size > 1_000_000: tmp_path.rename(checkpoint_path) else: tmp_path.unlink(missing_ok=True) diff --git a/ui/main_window.py b/ui/main_window.py index e85c1ea..1f7ebc1 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -311,7 +311,7 @@ class DirectTransitionDialog(QDialog): rife_ready = PracticalRifeEnv.is_setup() film_ready = FilmEnv.is_setup() if rife_ready else False - bim_ready = BimVfiEnv.is_setup() if rife_ready else False + bim_ready = BimVfiEnv.is_setup() if method == DirectInterpolationMethod.RIFE: if rife_ready: @@ -338,29 +338,32 @@ class DirectTransitionDialog(QDialog): self.status_label.setStyleSheet("color: orange; font-size: 10px;") self.setup_btn.setVisible(True) self.setup_btn.setText("Setup PyTorch Environment") - else: # BIM_VFI + else: # BIM_VFI (own venv, independent of RIFE) if bim_ready: self.status_label.setText("BiM-VFI: Ready") self.status_label.setStyleSheet("color: green; font-size: 10px;") self.setup_btn.setVisible(False) - elif rife_ready: - self.status_label.setText("BiM-VFI: Not installed (repo + model needed)") + else: + self.status_label.setText("BiM-VFI: Not installed") self.status_label.setStyleSheet("color: orange; font-size: 10px;") self.setup_btn.setVisible(True) self.setup_btn.setText("Install BiM-VFI") - else: - self.status_label.setText("BiM-VFI: Not installed (PyTorch required first)") - self.status_label.setStyleSheet("color: orange; font-size: 10px;") - self.setup_btn.setVisible(True) - self.setup_btn.setText("Setup PyTorch Environment") def _on_setup(self) -> None: """Handle setup button click.""" method = self.method_combo.currentData() + + # BiM-VFI has its own venv — handle independently + if method == DirectInterpolationMethod.BIM_VFI: + if not BimVfiEnv.is_setup(): + self._setup_bim_vfi() + self._update_status() + return + + # RIFE and FILM share a venv — set up base PyTorch if needed rife_ready = PracticalRifeEnv.is_setup() if not rife_ready: - # Need to set up PyTorch venv first progress = QProgressDialog( "Setting up PyTorch environment...", "Cancel", 0, 100, self ) @@ -369,8 +372,6 @@ class DirectTransitionDialog(QDialog): progress.setMinimumDuration(0) progress.setValue(0) - cancelled = [False] - def progress_cb(msg, pct): progress.setLabelText(msg) progress.setValue(pct) @@ -419,40 +420,38 @@ class DirectTransitionDialog(QDialog): ) return - # If BiM-VFI selected and we need to install it - if method == DirectInterpolationMethod.BIM_VFI and not BimVfiEnv.is_setup(): - progress = QProgressDialog( - "Setting up BiM-VFI...", "Cancel", 0, 100, self - ) - progress.setWindowTitle("Setup") - progress.setWindowModality(Qt.WindowModality.WindowModal) - progress.setMinimumDuration(0) - progress.setValue(0) - - def progress_cb(msg, pct): - progress.setLabelText(msg) - progress.setValue(pct) - - def cancelled_check(): - QApplication.processEvents() - return progress.wasCanceled() - - success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check) - progress.close() - - if not success: - if not cancelled_check(): - QMessageBox.warning( - self, "Setup Failed", - "Failed to set up BiM-VFI.\n\n" - "If the checkpoint download failed, you can download it " - "manually from Google Drive and place it at:\n" - f"{BimVfiEnv.get_checkpoint_path()}" - ) - return - self._update_status() + def _setup_bim_vfi(self) -> None: + """Run BiM-VFI setup with progress dialog.""" + progress = QProgressDialog( + "Setting up BiM-VFI...", "Cancel", 0, 100, self + ) + progress.setWindowTitle("Setup") + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(0) + + def progress_cb(msg, pct): + progress.setLabelText(msg) + progress.setValue(pct) + + def cancelled_check(): + QApplication.processEvents() + return progress.wasCanceled() + + success = BimVfiEnv.setup_bim_vfi(progress_cb, cancelled_check) + progress.close() + + if not success and not cancelled_check(): + QMessageBox.warning( + self, "Setup Failed", + "Failed to set up BiM-VFI.\n\n" + "If the checkpoint download failed, you can download it " + "manually from Google Drive and place it at:\n" + f"{BimVfiEnv.get_checkpoint_path()}" + ) + def _on_remove(self) -> None: """Handle remove button click.""" self._removed = True