chore: vendor selva_core from jnwnlee/selva@d7d40a9
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from av import AudioFrame
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@dataclass
|
||||
class VideoInfo:
|
||||
duration_sec: float
|
||||
fps: Fraction
|
||||
clip_frames: torch.Tensor
|
||||
sync_frames: torch.Tensor
|
||||
all_frames: Optional[list[np.ndarray]]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.all_frames[0].shape[0]
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.all_frames[0].shape[1]
|
||||
|
||||
@classmethod
|
||||
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
||||
fps: Fraction) -> 'VideoInfo':
|
||||
num_frames = int(duration_sec * fps)
|
||||
all_frames = [image_info.original_frame] * num_frames
|
||||
return cls(duration_sec=duration_sec,
|
||||
fps=fps,
|
||||
clip_frames=image_info.clip_frames,
|
||||
sync_frames=image_info.sync_frames,
|
||||
all_frames=all_frames)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageInfo:
|
||||
clip_frames: torch.Tensor
|
||||
sync_frames: torch.Tensor
|
||||
original_frame: Optional[np.ndarray]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.original_frame.shape[0]
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.original_frame.shape[1]
|
||||
|
||||
|
||||
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
|
||||
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
|
||||
output_frames = [[] for _ in list_of_fps]
|
||||
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
||||
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
||||
all_frames = []
|
||||
|
||||
# container = av.open(video_path)
|
||||
with av.open(video_path) as container:
|
||||
stream = container.streams.video[0]
|
||||
fps = stream.guessed_rate
|
||||
stream.thread_type = 'AUTO'
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
frame_time = frame.time
|
||||
if frame_time < start_sec:
|
||||
continue
|
||||
if frame_time > end_sec:
|
||||
break
|
||||
|
||||
frame_np = None
|
||||
if need_all_frames:
|
||||
frame_np = frame.to_ndarray(format='rgb24')
|
||||
all_frames.append(frame_np)
|
||||
|
||||
for i, _ in enumerate(list_of_fps):
|
||||
this_time = frame_time
|
||||
while this_time >= next_frame_time_for_each_fps[i]:
|
||||
if frame_np is None:
|
||||
frame_np = frame.to_ndarray(format='rgb24')
|
||||
|
||||
output_frames[i].append(frame_np)
|
||||
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
|
||||
|
||||
output_frames = [np.stack(frames) for frames in output_frames]
|
||||
return output_frames, all_frames, fps
|
||||
|
||||
|
||||
def normalize_video_chunk(video_chunk: torch.Tensor,
|
||||
expected_length: int,
|
||||
*,
|
||||
n_tolerance_frame: int = 1,
|
||||
desc: str = "") \
|
||||
-> torch.Tensor:
|
||||
# video_chunk: [T, H, W, C]
|
||||
if video_chunk.shape[0] < expected_length:
|
||||
if expected_length - video_chunk.shape[0] <= n_tolerance_frame:
|
||||
# copy the last frame to make it the right length
|
||||
log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
|
||||
video_chunk = torch.cat([video_chunk, video_chunk[-1:].repeat(expected_length - video_chunk.shape[0], 1, 1, 1)])
|
||||
assert video_chunk.shape[0] == expected_length
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
|
||||
)
|
||||
video_chunk = video_chunk[:expected_length]
|
||||
if video_chunk.shape[0] != expected_length:
|
||||
raise RuntimeError(f'Video wrong length {desc}, '
|
||||
f'expected {expected_length}, '
|
||||
f'got {video_chunk.shape[0]}')
|
||||
|
||||
return video_chunk
|
||||
|
||||
|
||||
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
|
||||
sampling_rate: int):
|
||||
container = av.open(output_path, 'w')
|
||||
output_video_stream = container.add_stream('h264', video_info.fps)
|
||||
output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
|
||||
output_video_stream.width = video_info.width
|
||||
output_video_stream.height = video_info.height
|
||||
output_video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
output_audio_stream = container.add_stream('aac', sampling_rate)
|
||||
|
||||
# encode video
|
||||
for image in video_info.all_frames:
|
||||
image = av.VideoFrame.from_ndarray(image)
|
||||
packet = output_video_stream.encode(image)
|
||||
container.mux(packet)
|
||||
|
||||
for packet in output_video_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
# convert float tensor audio to numpy array
|
||||
audio_np = audio.numpy().astype(np.float32)
|
||||
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||
audio_frame.sample_rate = sampling_rate
|
||||
|
||||
for packet in output_audio_stream.encode(audio_frame):
|
||||
container.mux(packet)
|
||||
|
||||
for packet in output_audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
container.close()
|
||||
|
||||
|
||||
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
||||
"""
|
||||
NOTE: I don't think we can get the exact video duration right without re-encoding
|
||||
so we are not using this but keeping it here for reference
|
||||
"""
|
||||
video = av.open(video_path)
|
||||
output = av.open(output_path, 'w')
|
||||
input_video_stream = video.streams.video[0]
|
||||
output_video_stream = output.add_stream(template=input_video_stream)
|
||||
output_audio_stream = output.add_stream('aac', sampling_rate)
|
||||
|
||||
duration_sec = audio.shape[-1] / sampling_rate
|
||||
|
||||
for packet in video.demux(input_video_stream):
|
||||
# We need to skip the "flushing" packets that `demux` generates.
|
||||
if packet.dts is None:
|
||||
continue
|
||||
# We need to assign the packet to the new stream.
|
||||
packet.stream = output_video_stream
|
||||
output.mux(packet)
|
||||
|
||||
# convert float tensor audio to numpy array
|
||||
audio_np = audio.numpy().astype(np.float32)
|
||||
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||
audio_frame.sample_rate = sampling_rate
|
||||
|
||||
for packet in output_audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
for packet in output_audio_stream.encode():
|
||||
output.mux(packet)
|
||||
|
||||
video.close()
|
||||
output.close()
|
||||
|
||||
output.close()
|
||||
Reference in New Issue
Block a user