Add auto_pyr_level toggle to select pyramid level by resolution
When enabled (default), automatically picks the optimal pyr_level based on input height: <540p=3, 540p=5, 1080p=6, 4K=7. When disabled, uses the manual pyr_level value. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
18
inference.py
18
inference.py
@@ -5,8 +5,9 @@ from .bim_vfi_arch import BiMVFI
|
|||||||
class BiMVFIModel:
|
class BiMVFIModel:
|
||||||
"""Clean inference wrapper around BiMVFI for ComfyUI integration."""
|
"""Clean inference wrapper around BiMVFI for ComfyUI integration."""
|
||||||
|
|
||||||
def __init__(self, checkpoint_path, pyr_level=3, device="cpu"):
|
def __init__(self, checkpoint_path, pyr_level=3, auto_pyr_level=True, device="cpu"):
|
||||||
self.pyr_level = pyr_level
|
self.pyr_level = pyr_level
|
||||||
|
self.auto_pyr_level = auto_pyr_level
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
|
self.model = BiMVFI(pyr_level=pyr_level, feat_channels=32)
|
||||||
@@ -58,12 +59,25 @@ class BiMVFIModel:
|
|||||||
img0 = frame0.to(device)
|
img0 = frame0.to(device)
|
||||||
img1 = frame1.to(device)
|
img1 = frame1.to(device)
|
||||||
|
|
||||||
|
if self.auto_pyr_level:
|
||||||
|
_, _, h, _ = img0.shape
|
||||||
|
if h >= 2160:
|
||||||
|
pyr_level = 7
|
||||||
|
elif h >= 1080:
|
||||||
|
pyr_level = 6
|
||||||
|
elif h >= 540:
|
||||||
|
pyr_level = 5
|
||||||
|
else:
|
||||||
|
pyr_level = 3
|
||||||
|
else:
|
||||||
|
pyr_level = self.pyr_level
|
||||||
|
|
||||||
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
time_step_tensor = torch.tensor([time_step], device=device).view(1, 1, 1, 1)
|
||||||
|
|
||||||
result_dict = self.model(
|
result_dict = self.model(
|
||||||
img0=img0, img1=img1,
|
img0=img0, img1=img1,
|
||||||
time_step=time_step_tensor,
|
time_step=time_step_tensor,
|
||||||
pyr_level=self.pyr_level,
|
pyr_level=pyr_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
interp = result_dict["imgt_pred"]
|
interp = result_dict["imgt_pred"]
|
||||||
|
|||||||
12
nodes.py
12
nodes.py
@@ -57,9 +57,13 @@ class LoadBIMVFIModel:
|
|||||||
"default": MODEL_FILENAME,
|
"default": MODEL_FILENAME,
|
||||||
"tooltip": "Checkpoint file from models/bim-vfi/. Auto-downloads on first use if missing.",
|
"tooltip": "Checkpoint file from models/bim-vfi/. Auto-downloads on first use if missing.",
|
||||||
}),
|
}),
|
||||||
|
"auto_pyr_level": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Automatically select pyramid level based on input resolution: <540p=3, 540p=5, 1080p=6, 4K=7. Disable to use manual pyr_level.",
|
||||||
|
}),
|
||||||
"pyr_level": ("INT", {
|
"pyr_level": ("INT", {
|
||||||
"default": 3, "min": 3, "max": 7, "step": 1,
|
"default": 3, "min": 3, "max": 7, "step": 1,
|
||||||
"tooltip": "Pyramid levels for coarse-to-fine processing. More levels = captures larger motion but slower. Recommended: 3-5 for <540p, 5-6 for 1080p, 6-7 for 4K.",
|
"tooltip": "Manual pyramid levels for coarse-to-fine processing. Only used when auto_pyr_level is disabled. More levels = captures larger motion but slower.",
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -69,7 +73,7 @@ class LoadBIMVFIModel:
|
|||||||
FUNCTION = "load_model"
|
FUNCTION = "load_model"
|
||||||
CATEGORY = "video/BIM-VFI"
|
CATEGORY = "video/BIM-VFI"
|
||||||
|
|
||||||
def load_model(self, model_path, pyr_level):
|
def load_model(self, model_path, auto_pyr_level, pyr_level):
|
||||||
full_path = os.path.join(MODEL_DIR, model_path)
|
full_path = os.path.join(MODEL_DIR, model_path)
|
||||||
|
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
@@ -79,10 +83,12 @@ class LoadBIMVFIModel:
|
|||||||
wrapper = BiMVFIModel(
|
wrapper = BiMVFIModel(
|
||||||
checkpoint_path=full_path,
|
checkpoint_path=full_path,
|
||||||
pyr_level=pyr_level,
|
pyr_level=pyr_level,
|
||||||
|
auto_pyr_level=auto_pyr_level,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"BIM-VFI model loaded (pyr_level={pyr_level})")
|
mode = "auto" if auto_pyr_level else f"manual ({pyr_level})"
|
||||||
|
logger.info(f"BIM-VFI model loaded (pyr_level={mode})")
|
||||||
return (wrapper,)
|
return (wrapper,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user