Update sharp_node.py
This commit is contained in:
@@ -28,10 +28,9 @@ class SharpnessAnalyzer:
|
||||
score = cv2.Laplacian(gray, cv2.CV_64F).var()
|
||||
scores.append(score)
|
||||
|
||||
# We pass the list of scores to the next node
|
||||
return (scores,)
|
||||
|
||||
# --- NODE 2: SELECTOR (Uses scores to filter high-res images) ---
|
||||
# --- NODE 2: SELECTOR (Filters High-Res images) ---
|
||||
class SharpFrameSelector:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -42,6 +41,8 @@ class SharpFrameSelector:
|
||||
"selection_method": (["batched", "best_n"],),
|
||||
"batch_size": ("INT", {"default": 24, "min": 1, "max": 10000, "step": 1}),
|
||||
"num_frames": ("INT", {"default": 10, "min": 1, "max": 10000, "step": 1}),
|
||||
# NEW SETTING
|
||||
"min_sharpness": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10000.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,18 +51,17 @@ class SharpFrameSelector:
|
||||
FUNCTION = "select_frames"
|
||||
CATEGORY = "SharpFrames"
|
||||
|
||||
def select_frames(self, images, scores, selection_method, batch_size, num_frames):
|
||||
def select_frames(self, images, scores, selection_method, batch_size, num_frames, min_sharpness):
|
||||
# Validation
|
||||
if len(images) != len(scores):
|
||||
print(f"[SharpSelector] WARNING: Frame count mismatch! Images: {len(images)}, Scores: {len(scores)}")
|
||||
# If mismatch (e.g. latent optimization), we truncate to the shorter length
|
||||
min_len = min(len(images), len(scores))
|
||||
images = images[:min_len]
|
||||
scores = scores[:min_len]
|
||||
|
||||
selected_indices = []
|
||||
|
||||
# --- SELECTION LOGIC (Same as before, but using pre-calculated scores) ---
|
||||
# --- SELECTION LOGIC ---
|
||||
if selection_method == "batched":
|
||||
total_frames = len(scores)
|
||||
for i in range(0, total_frames, batch_size):
|
||||
@@ -70,14 +70,46 @@ class SharpFrameSelector:
|
||||
|
||||
# Find best in batch
|
||||
best_in_chunk_idx = np.argmax(chunk_scores)
|
||||
best_score = chunk_scores[best_in_chunk_idx]
|
||||
|
||||
# Only keep if it passes the threshold
|
||||
if best_score >= min_sharpness:
|
||||
selected_indices.append(i + best_in_chunk_idx)
|
||||
|
||||
elif selection_method == "best_n":
|
||||
target_count = min(num_frames, len(scores))
|
||||
top_indices = np.argsort(scores)[-target_count:]
|
||||
selected_indices = sorted(top_indices)
|
||||
# 1. Filter out everything below threshold
|
||||
valid_indices = [i for i, s in enumerate(scores) if s >= min_sharpness]
|
||||
|
||||
# 2. Sort valid candidates by score (Low -> High)
|
||||
# We use numpy array for easy indexing
|
||||
valid_scores = np.array([scores[i] for i in valid_indices])
|
||||
|
||||
if len(valid_scores) > 0:
|
||||
# How many can we take?
|
||||
target_count = min(num_frames, len(valid_scores))
|
||||
|
||||
# Get indices of top N scores within the VALID list
|
||||
top_local_indices = np.argsort(valid_scores)[-target_count:]
|
||||
|
||||
# Map back to global indices
|
||||
top_global_indices = [valid_indices[i] for i in top_local_indices]
|
||||
|
||||
# Sort by time
|
||||
selected_indices = sorted(top_global_indices)
|
||||
else:
|
||||
selected_indices = []
|
||||
|
||||
print(f"[SharpSelector] Selected {len(selected_indices)} frames.")
|
||||
|
||||
# --- EMPTY RESULT SAFETY NET ---
|
||||
if len(selected_indices) == 0:
|
||||
print("[SharpSelector] Warning: No frames met criteria. Returning 1 black frame to prevent crash.")
|
||||
# Create 1 black pixel frame with same dimensions as input
|
||||
# This keeps the workflow alive
|
||||
h, w = images[0].shape[0], images[0].shape[1]
|
||||
empty_frame = torch.zeros((1, h, w, 3), dtype=images.dtype, device=images.device)
|
||||
return (empty_frame, 0)
|
||||
|
||||
result_images = images[selected_indices]
|
||||
|
||||
return (result_images, len(selected_indices))
|
||||
Reference in New Issue
Block a user