6bc3fd6443
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes. Imports rewritten from selva.* to selva_core.*. mel_converter.py: replaced librosa.filters.mel with pure-numpy implementation to avoid librosa→numba→NumPy version incompatibility in some ComfyUI environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
191 lines
6.5 KiB
Python
191 lines
6.5 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from fractions import Fraction
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import av
|
|
import numpy as np
|
|
import torch
|
|
from av import AudioFrame
|
|
|
|
# Logger shared by the helpers in this module. getLogger() is called without a
# name, so this is the root logger rather than a module-specific one.
log = logging.getLogger()
|
|
|
|
@dataclass
class VideoInfo:
    """Bundle of decoded video data plus the metadata needed to write it back out.

    The frame size is derived from the first entry of ``all_frames``, so the
    ``height`` and ``width`` properties only work when ``all_frames`` is a
    non-empty list.
    """
    # Length of the clip in seconds.
    duration_sec: float
    # Frame rate of the clip, kept as an exact fraction.
    fps: Fraction
    # Frame tensors; presumably the input for a CLIP-style encoder -- TODO confirm.
    clip_frames: torch.Tensor
    # Frame tensors; presumably the input for a sync encoder -- TODO confirm.
    sync_frames: torch.Tensor
    # Every frame of the clip as an array, or None when frames were not kept.
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        """Size of the first axis of the first stored frame."""
        first_frame = self.all_frames[0]
        return first_frame.shape[0]

    @property
    def width(self):
        """Size of the second axis of the first stored frame."""
        first_frame = self.all_frames[0]
        return first_frame.shape[1]

    @classmethod
    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
                        fps: Fraction) -> 'VideoInfo':
        """Build a still-image "video" by repeating one frame for the whole duration.

        Args:
            image_info: Source of ``clip_frames``, ``sync_frames`` and
                ``original_frame``.
            duration_sec: Desired clip length in seconds.
            fps: Desired frame rate.

        Returns:
            A ``VideoInfo`` whose ``all_frames`` holds ``int(duration_sec * fps)``
            references to the same ``original_frame`` object (no copies are made).
        """
        frame_count = int(duration_sec * fps)
        still = image_info.original_frame
        repeated_frames = [still for _ in range(frame_count)]
        return cls(
            duration_sec=duration_sec,
            fps=fps,
            clip_frames=image_info.clip_frames,
            sync_frames=image_info.sync_frames,
            all_frames=repeated_frames,
        )
|
|
|
|
|
|
@dataclass
class ImageInfo:
    """Bundle of tensors prepared from a single image plus the image itself.

    The ``height`` and ``width`` properties read the shape of
    ``original_frame`` and therefore require it to be set.
    """
    # Frame tensor; presumably the input for a CLIP-style encoder -- TODO confirm.
    clip_frames: torch.Tensor
    # Frame tensor; presumably the input for a sync encoder -- TODO confirm.
    sync_frames: torch.Tensor
    # The source image as an array, or None when it was not kept.
    original_frame: Optional[np.ndarray]

    @property
    def height(self):
        """Size of the first axis of ``original_frame``."""
        frame = self.original_frame
        return frame.shape[0]

    @property
    def width(self):
        """Size of the second axis of ``original_frame``."""
        frame = self.original_frame
        return frame.shape[1]
|
|
|
|
|
|
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode a time window of a video once and resample it to several frame rates.

    Frames whose timestamp lies in ``[start_sec, end_sec]`` are decoded from the
    first video stream. For every rate in ``list_of_fps`` a sampling clock
    advances in steps of ``1 / fps``; whenever a decoded frame's timestamp has
    reached the clock, that frame is emitted (repeatedly, if the source rate is
    lower than the target rate) and the clock moves on.

    Args:
        video_path: Path of the video file to open.
        list_of_fps: Target frame rates; one output array is produced per entry.
        start_sec: Frames with an earlier timestamp are skipped.
        end_sec: Decoding stops at the first frame with a later timestamp.
        need_all_frames: When true, every decoded frame inside the window is
            also collected at the source rate.

    Returns:
        A tuple ``(output_frames, all_frames, fps)`` where ``output_frames`` is
        a list with one stacked RGB frame array per entry of ``list_of_fps``,
        ``all_frames`` is the list of every in-window frame (empty unless
        ``need_all_frames`` is set), and ``fps`` is the stream's guessed rate.

    Raises:
        ValueError: From ``np.stack`` when no frame was selected for one of the
            requested rates (for example when the window contains no frames).
    """
    output_frames = [[] for _ in list_of_fps]
    # Each sampling clock starts at the beginning of the requested window.
    # Starting at 0 would make the first kept frame repeat until the clock
    # caught up with start_sec whenever the window does not begin at 0.
    next_frame_time_for_each_fps = [start_sec for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        reached_end = False
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    # Assumes decoded frames arrive in presentation order, so
                    # nothing after this point can fall back into the window.
                    reached_end = True
                    break

                # Convert lazily: the array is only built when some consumer
                # actually needs this frame.
                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)

                for i, _ in enumerate(list_of_fps):
                    while frame_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            frame_np = frame.to_ndarray(format='rgb24')

                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]

            if reached_end:
                # Stop demuxing too; otherwise the rest of the file would be
                # decoded only to be thrown away frame by frame.
                break

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps
|
|
|
|
|
|
def normalize_video_chunk(video_chunk: torch.Tensor,
                          expected_length: int,
                          *,
                          n_tolerance_frame: int = 1,
                          desc: str = "") \
        -> torch.Tensor:
    """Force a chunk of video frames to exactly ``expected_length`` frames.

    A chunk that is too long is truncated. A chunk that is short by at most
    ``n_tolerance_frame`` frames is padded by repeating its last frame (a
    warning is logged). Anything shorter than that is rejected.

    Args:
        video_chunk: Frames with time as the first of four axes ([T, H, W, C]).
        expected_length: Required number of frames.
        n_tolerance_frame: Largest number of missing frames that is padded
            instead of raising.
        desc: Free-form text inserted into log and error messages.

    Returns:
        A tensor whose first dimension equals ``expected_length``.

    Raises:
        RuntimeError: If the chunk is too short to be padded (more frames are
            missing than tolerated, or the chunk is empty and so has no last
            frame to repeat), or if the final length still does not match.
    """
    # video_chunk: [T, H, W, C]
    num_frames = video_chunk.shape[0]
    if num_frames < expected_length:
        num_missing = expected_length - num_frames
        # An empty chunk cannot be padded: there is no last frame to copy.
        # Report it as "too short" instead of relying on an assert, which is
        # stripped when Python runs with -O.
        if num_missing > n_tolerance_frame or num_frames == 0:
            raise RuntimeError(
                f'Video too short {desc}, expected {expected_length}, got {num_frames}'
            )

        # copy the last frame to make it the right length
        log.warning(f'Video too short {desc}, padding {num_missing} frames with the last frame')
        padding = video_chunk[-1:].repeat(num_missing, 1, 1, 1)
        video_chunk = torch.cat([video_chunk, padding])

    video_chunk = video_chunk[:expected_length]
    if video_chunk.shape[0] != expected_length:
        raise RuntimeError(f'Video wrong length {desc}, '
                           f'expected {expected_length}, '
                           f'got {video_chunk.shape[0]}')

    return video_chunk
|
|
|
|
|
|
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Write the frames of ``video_info`` and an audio track to a new video file.

    The frames are re-encoded as H.264 (yuv420p, 10 Mbps target bitrate) at
    ``video_info.fps`` and the audio is encoded as mono AAC.

    Args:
        video_info: Source of the frames, frame rate and frame size. Its
            ``all_frames`` must be populated, because both the frame size and
            the encoded pictures are taken from it.
        output_path: Destination file; opened for writing.
        audio: Float audio samples. It is handed to ``AudioFrame.from_ndarray``
            with a packed-float mono layout -- presumably shaped
            (1, num_samples); TODO confirm against the caller.
        sampling_rate: Sample rate of ``audio`` in Hz.
    """
    # The context manager closes the container even if encoding fails, so a
    # failed run does not leak the open output file handle.
    with av.open(output_path, 'w') as container:
        output_video_stream = container.add_stream('h264', video_info.fps)
        # Integer literal: avoids relying on an implicit float -> int conversion.
        output_video_stream.codec_context.bit_rate = 10_000_000  # 10 Mbps
        output_video_stream.width = video_info.width
        output_video_stream.height = video_info.height
        output_video_stream.pix_fmt = 'yuv420p'

        output_audio_stream = container.add_stream('aac', sampling_rate)

        # encode video
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            container.mux(output_video_stream.encode(video_frame))

        # Calling encode() without a frame flushes the encoder.
        for packet in output_video_stream.encode():
            container.mux(packet)

        # convert float tensor audio to numpy array; detach/cpu make this work
        # for tensors that track gradients or live on an accelerator as well
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)

        for packet in output_audio_stream.encode():
            container.mux(packet)
|
|
|
|
|
|
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """Copy the video stream of a file unchanged and add a new AAC audio track.

    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference

    Args:
        video_path: Existing video whose first video stream is copied packet
            by packet (no re-encoding).
        audio: Float audio samples. It is handed to ``AudioFrame.from_ndarray``
            with a packed-float mono layout -- presumably shaped
            (1, num_samples); TODO confirm against the caller.
        output_path: Destination file; opened for writing.
        sampling_rate: Sample rate of ``audio`` in Hz.
    """
    # Both containers are closed by the context managers, also when muxing or
    # encoding raises part-way through.
    with av.open(video_path) as video, av.open(output_path, 'w') as output:
        input_video_stream = video.streams.video[0]
        # NOTE(review): the template-based add_stream call may differ between
        # PyAV releases -- confirm against the installed version.
        output_video_stream = output.add_stream(template=input_video_stream)
        output_audio_stream = output.add_stream('aac', sampling_rate)

        for packet in video.demux(input_video_stream):
            # We need to skip the "flushing" packets that `demux` generates.
            if packet.dts is None:
                continue
            # We need to assign the packet to the new stream.
            packet.stream = output_video_stream
            output.mux(packet)

        # convert float tensor audio to numpy array; detach/cpu make this work
        # for tensors that track gradients or live on an accelerator as well
        audio_np = audio.detach().cpu().numpy().astype(np.float32)
        audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate

        for packet in output_audio_stream.encode(audio_frame):
            output.mux(packet)

        # Calling encode() without a frame flushes the encoder.
        for packet in output_audio_stream.encode():
            output.mux(packet)
|