diff --git a/sharp_node.py b/sharp_node.py index a50cbb1..8eba4bb 100644 --- a/sharp_node.py +++ b/sharp_node.py @@ -28,10 +28,9 @@ class SharpnessAnalyzer: score = cv2.Laplacian(gray, cv2.CV_64F).var() scores.append(score) - # We pass the list of scores to the next node return (scores,) -# --- NODE 2: SELECTOR (Uses scores to filter high-res images) --- +# --- NODE 2: SELECTOR (Filters High-Res images) --- class SharpFrameSelector: @classmethod def INPUT_TYPES(s): @@ -42,6 +41,8 @@ class SharpFrameSelector: "selection_method": (["batched", "best_n"],), "batch_size": ("INT", {"default": 24, "min": 1, "max": 10000, "step": 1}), "num_frames": ("INT", {"default": 10, "min": 1, "max": 10000, "step": 1}), + # NEW SETTING + "min_sharpness": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10000.0, "step": 0.1}), } } @@ -50,18 +51,17 @@ class SharpFrameSelector: FUNCTION = "select_frames" CATEGORY = "SharpFrames" - def select_frames(self, images, scores, selection_method, batch_size, num_frames): + def select_frames(self, images, scores, selection_method, batch_size, num_frames, min_sharpness): # Validation if len(images) != len(scores): print(f"[SharpSelector] WARNING: Frame count mismatch! Images: {len(images)}, Scores: {len(scores)}") - # If mismatch (e.g. latent optimization), we truncate to the shorter length min_len = min(len(images), len(scores)) images = images[:min_len] scores = scores[:min_len] selected_indices = [] - # --- SELECTION LOGIC (Same as before, but using pre-calculated scores) --- + # --- SELECTION LOGIC --- if selection_method == "batched": total_frames = len(scores) for i in range(0, total_frames, batch_size): @@ -70,14 +70,46 @@ class SharpFrameSelector: # Find best in batch best_in_chunk_idx = np.argmax(chunk_scores) - selected_indices.append(i + best_in_chunk_idx) + best_score = chunk_scores[best_in_chunk_idx] + + # Only keep if it passes the threshold + if best_score >= min_sharpness: + selected_indices.append(i + best_in_chunk_idx) elif selection_method == "best_n": - target_count = min(num_frames, len(scores)) - top_indices = np.argsort(scores)[-target_count:] - selected_indices = sorted(top_indices) + # 1. Filter out everything below threshold + valid_indices = [i for i, s in enumerate(scores) if s >= min_sharpness] + + # 2. Sort valid candidates by score (Low -> High) + # We use numpy array for easy indexing + valid_scores = np.array([scores[i] for i in valid_indices]) + + if len(valid_scores) > 0: + # How many can we take? + target_count = min(num_frames, len(valid_scores)) + + # Get indices of top N scores within the VALID list + top_local_indices = np.argsort(valid_scores)[-target_count:] + + # Map back to global indices + top_global_indices = [valid_indices[i] for i in top_local_indices] + + # Sort by time + selected_indices = sorted(top_global_indices) + else: + selected_indices = [] print(f"[SharpSelector] Selected {len(selected_indices)} frames.") + + # --- EMPTY RESULT SAFETY NET --- + if len(selected_indices) == 0: + print("[SharpSelector] Warning: No frames met criteria. Returning 1 black frame to prevent crash.") + # Create 1 black pixel frame with same dimensions as input + # This keeps the workflow alive + h, w = images[0].shape[0], images[0].shape[1] + empty_frame = torch.zeros((1, h, w, 3), dtype=images.dtype, device=images.device) + return (empty_frame, 0) + result_images = images[selected_indices] return (result_images, len(selected_indices)) \ No newline at end of file