Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c9550ce693 | |||
| f3cabcad90 |
@@ -35,6 +35,66 @@ def _resize_frames(frames, size):
|
|||||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
|
||||||
|
"""
|
||||||
|
Compute a bounding box around the union of all mask frames.
|
||||||
|
|
||||||
|
mask: [M, H', W'] float [0,1]
|
||||||
|
square: if True, expand bbox to a square and shift into frame bounds;
|
||||||
|
if False, apply margin independently on each axis (rect crop).
|
||||||
|
Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
|
||||||
|
"""
|
||||||
|
if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
|
||||||
|
m = F.interpolate(
|
||||||
|
mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
|
||||||
|
).squeeze(1)
|
||||||
|
else:
|
||||||
|
m = mask.float()
|
||||||
|
|
||||||
|
union = (m > 0.5).max(dim=0).values # [H, W] bool
|
||||||
|
|
||||||
|
if not union.any():
|
||||||
|
if square:
|
||||||
|
# Empty mask — center square crop
|
||||||
|
side = min(frame_h, frame_w)
|
||||||
|
cy, cx = frame_h // 2, frame_w // 2
|
||||||
|
y0 = max(0, cy - side // 2)
|
||||||
|
x0 = max(0, cx - side // 2)
|
||||||
|
return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
|
||||||
|
else:
|
||||||
|
# Empty mask — return full frame (no meaningful rect to crop to)
|
||||||
|
return 0, 0, frame_h, frame_w
|
||||||
|
|
||||||
|
ys = union.any(dim=1).nonzero(as_tuple=True)[0]
|
||||||
|
xs = union.any(dim=0).nonzero(as_tuple=True)[0]
|
||||||
|
y0, y1 = int(ys[0]), int(ys[-1]) + 1
|
||||||
|
x0, x1 = int(xs[0]), int(xs[-1]) + 1
|
||||||
|
|
||||||
|
if square:
|
||||||
|
side = max(y1 - y0, x1 - x0)
|
||||||
|
pad = int(side * margin)
|
||||||
|
side += 2 * pad
|
||||||
|
|
||||||
|
cy = (y0 + y1) // 2
|
||||||
|
cx = (x0 + x1) // 2
|
||||||
|
y0n = cy - side // 2
|
||||||
|
x0n = cx - side // 2
|
||||||
|
y1n = y0n + side
|
||||||
|
x1n = x0n + side
|
||||||
|
|
||||||
|
# Shift into frame bounds to preserve square shape
|
||||||
|
if y0n < 0: y1n -= y0n; y0n = 0
|
||||||
|
if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
|
||||||
|
if x0n < 0: x1n -= x0n; x0n = 0
|
||||||
|
if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w
|
||||||
|
|
||||||
|
return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
|
||||||
|
else:
|
||||||
|
pad_y = int(max(1, y1 - y0) * margin)
|
||||||
|
pad_x = int(max(1, x1 - x0) * margin)
|
||||||
|
return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
|
||||||
|
|
||||||
|
|
||||||
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
||||||
"""
|
"""
|
||||||
Apply a ComfyUI MASK to resized frames.
|
Apply a ComfyUI MASK to resized frames.
|
||||||
@@ -69,7 +129,8 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
|||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
mask_strength=1.0, mask_clip=True, mask_sync=True,
|
||||||
|
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
raw = video_tensor.cpu().numpy().tobytes()
|
raw = video_tensor.cpu().numpy().tobytes()
|
||||||
n = len(raw)
|
n = len(raw)
|
||||||
@@ -87,6 +148,10 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
|||||||
h.update(str(round(mask_strength, 4)).encode())
|
h.update(str(round(mask_strength, 4)).encode())
|
||||||
h.update(str(mask_clip).encode())
|
h.update(str(mask_clip).encode())
|
||||||
h.update(str(mask_sync).encode())
|
h.update(str(mask_sync).encode())
|
||||||
|
h.update(str(crop_to_mask).encode())
|
||||||
|
h.update(str(crop_rect).encode())
|
||||||
|
if crop_to_mask or crop_rect:
|
||||||
|
h.update(str(round(crop_margin, 4)).encode())
|
||||||
h.update(prompt.encode())
|
h.update(prompt.encode())
|
||||||
h.update(str(fps).encode())
|
h.update(str(fps).encode())
|
||||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||||
@@ -131,6 +196,18 @@ class SelvaFeatureExtractor:
|
|||||||
"default": True,
|
"default": True,
|
||||||
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
||||||
}),
|
}),
|
||||||
|
"crop_to_mask": ("BOOLEAN", {
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
|
||||||
|
}),
|
||||||
|
"crop_rect": ("BOOLEAN", {
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
|
||||||
|
}),
|
||||||
|
"crop_margin": ("FLOAT", {
|
||||||
|
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||||
|
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,7 +224,8 @@ class SelvaFeatureExtractor:
|
|||||||
|
|
||||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||||
duration=0.0, cache_dir="", mask=None,
|
duration=0.0, cache_dir="", mask=None,
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
mask_strength=1.0, mask_clip=True, mask_sync=True,
|
||||||
|
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
|
||||||
if video_info is not None:
|
if video_info is not None:
|
||||||
fps = video_info["loaded_fps"]
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
@@ -164,7 +242,8 @@ class SelvaFeatureExtractor:
|
|||||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
||||||
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
|
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
|
||||||
|
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||||
|
|
||||||
if os.path.exists(cached_path):
|
if os.path.exists(cached_path):
|
||||||
@@ -186,10 +265,24 @@ class SelvaFeatureExtractor:
|
|||||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||||
pbar = comfy.utils.ProgressBar(3)
|
pbar = comfy.utils.ProgressBar(3)
|
||||||
|
|
||||||
|
# Pre-compute crop bbox once from the original-resolution mask
|
||||||
|
crop_bbox = None
|
||||||
|
if mask is not None and (crop_to_mask or crop_rect):
|
||||||
|
H_vid, W_vid = video.shape[1], video.shape[2]
|
||||||
|
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
|
||||||
|
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
_mode = "square" if _square else "rect"
|
||||||
|
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
|
||||||
|
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||||
|
if crop_bbox is not None:
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
|
||||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||||
if mask is not None and mask_clip:
|
if mask is not None and mask_clip:
|
||||||
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
||||||
@@ -202,6 +295,9 @@ class SelvaFeatureExtractor:
|
|||||||
|
|
||||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||||
|
if crop_bbox is not None:
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
|
||||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||||
if mask is not None and mask_sync:
|
if mask is not None and mask_sync:
|
||||||
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
||||||
|
|||||||
Reference in New Issue
Block a user