Update parallel_loader.py

This commit is contained in:
2026-01-19 12:00:54 +01:00
parent 689c67cd7f
commit 610c52c0b8

View File

@@ -19,14 +19,13 @@ class ParallelSharpnessLoader:
"frame_scan_step": ("INT", {"default": 5, "min": 1, "step": 1, "label": "Analyze Every Nth Frame"}), "frame_scan_step": ("INT", {"default": 5, "min": 1, "step": 1, "label": "Analyze Every Nth Frame"}),
"return_count": ("INT", {"default": 4, "min": 1, "max": 1024, "step": 1, "label": "Best Frames to Return"}), "return_count": ("INT", {"default": 4, "min": 1, "max": 1024, "step": 1, "label": "Best Frames to Return"}),
"min_distance": ("INT", {"default": 24, "min": 0, "max": 10000, "step": 1, "label": "Min Distance (Frames)"}), "min_distance": ("INT", {"default": 24, "min": 0, "max": 10000, "step": 1, "label": "Min Distance (Frames)"}),
# MANUAL OFFSET (Optional: e.g. skip the first 2000 frames always)
"manual_skip_start": ("INT", {"default": 0, "min": 0, "max": 10000000, "step": 1, "label": "Global Start Offset"}), "manual_skip_start": ("INT", {"default": 0, "min": 0, "max": 10000000, "step": 1, "label": "Global Start Offset"}),
}, },
} }
RETURN_TYPES = ("IMAGE", "STRING", "INT") # Added a 4th output: STRING (The status sentence)
RETURN_NAMES = ("images", "scores_info", "current_batch_index") RETURN_TYPES = ("IMAGE", "STRING", "INT", "STRING")
RETURN_NAMES = ("images", "scores_info", "batch_int", "batch_status")
FUNCTION = "load_video" FUNCTION = "load_video"
CATEGORY = "BetaHelper/Video" CATEGORY = "BetaHelper/Video"
@@ -42,19 +41,21 @@ class ParallelSharpnessLoader:
if not os.path.exists(video_path): if not os.path.exists(video_path):
raise FileNotFoundError(f"Video not found: {video_path}") raise FileNotFoundError(f"Video not found: {video_path}")
# 2. Calculate Actual Start Frame # 2. Calculate Offsets
# Formula: (Batch Number * Frames Per Batch) + Global Offset
current_skip = (batch_index * scan_limit) + manual_skip_start current_skip = (batch_index * scan_limit) + manual_skip_start
range_end = current_skip + scan_limit
print(f"xx- Parallel Loader | Batch: {batch_index} | Start Frame: {current_skip} | Range: {current_skip} -> {current_skip + scan_limit}") # Create the Status String
status_msg = f"Batch {batch_index}: Skipped {current_skip} frames. Scanning range {current_skip} -> {range_end}."
print(f"xx- Parallel Loader | {status_msg}")
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if current_skip >= total_frames: if current_skip >= total_frames:
print("xx- End of video reached.") print("xx- End of video reached.")
# Return a black frame to prevent crashing, or handle as you wish empty_img = torch.zeros((1, 64, 64, 3))
return (torch.zeros((1, 64, 64, 3)), "End of Video", batch_index) return (empty_img, "End of Video", batch_index, "End of Video Reached")
# 3. Scanning (Pass 1) # 3. Scanning (Pass 1)
if current_skip > 0: if current_skip > 0:
@@ -68,14 +69,12 @@ class ParallelSharpnessLoader:
futures = [] futures = []
while True: while True:
# Stop if we hit the batch limit
if scanned_count >= scan_limit: if scanned_count >= scan_limit:
break break
ret, frame = cap.read() ret, frame = cap.read()
if not ret: break # End of file if not ret: break
# Submit to thread
future = executor.submit(self.calculate_sharpness, frame) future = executor.submit(self.calculate_sharpness, frame)
futures.append((current_frame, future)) futures.append((current_frame, future))
scanned_count += 1 scanned_count += 1
@@ -95,7 +94,7 @@ class ParallelSharpnessLoader:
# 4. Selection # 4. Selection
if not frame_scores: if not frame_scores:
return (torch.zeros((1, 64, 64, 3)), "No frames in batch", batch_index) return (torch.zeros((1, 64, 64, 3)), "No frames in batch", batch_index, status_msg + " (No frames found)")
frame_scores.sort(key=lambda x: x[1], reverse=True) frame_scores.sort(key=lambda x: x[1], reverse=True)
selected = [] selected = []
@@ -106,7 +105,6 @@ class ParallelSharpnessLoader:
selected.append((idx, score)) selected.append((idx, score))
selected.sort(key=lambda x: x[0]) selected.sort(key=lambda x: x[0])
print(f"xx- Selected Frames: {[f[0] for f in selected]}")
# 5. Extraction (Pass 2) # 5. Extraction (Pass 2)
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
@@ -125,6 +123,7 @@ class ParallelSharpnessLoader:
cap.release() cap.release()
if not output_tensors: if not output_tensors:
return (torch.zeros((1, 64, 64, 3)), "Extraction Failed", batch_index) return (torch.zeros((1, 64, 64, 3)), "Extraction Failed", batch_index, status_msg)
return (torch.stack(output_tensors), ", ".join(info_log), batch_index) # Return all 4 outputs
return (torch.stack(output_tensors), ", ".join(info_log), batch_index, status_msg)