perf: replace MP4 encode/decode with lossless .npy frame transfer
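Saves frames as a uint8 .npy file instead of encoding them to H.264 MP4, eliminating the lossy codec round trip. extract_features.py now loads the .npy file directly and skips decord when --video points to a .npy file. The caller also passes --source_fps, because a .npy array carries no frame-rate metadata and the extractor needs the original fps for correct temporal sampling. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>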
This commit is contained in:
+11
-25
@@ -78,29 +78,14 @@ def _hash_inputs(video_tensor, cot_text):
|
|||||||
return h.hexdigest()[:16]
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30):
|
def _save_frames_to_npy(video_tensor, output_path):
|
||||||
"""Save ComfyUI IMAGE tensor [T,H,W,C] to MP4 by piping raw RGB to ffmpeg.
|
"""Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.
|
||||||
|
|
||||||
Avoids intermediate PNG files — frames are streamed directly to ffmpeg stdin.
|
Lossless — avoids H.264 encode/decode roundtrip.
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
||||||
T, H, W, C = frames_np.shape
|
np.save(output_path, frames_np)
|
||||||
|
|
||||||
result = subprocess.run(
|
|
||||||
[
|
|
||||||
"ffmpeg", "-y",
|
|
||||||
"-f", "rawvideo", "-vcodec", "rawvideo",
|
|
||||||
"-s", f"{W}x{H}", "-pix_fmt", "rgb24",
|
|
||||||
"-r", str(fps),
|
|
||||||
"-i", "pipe:0",
|
|
||||||
"-c:v", "libx264", "-pix_fmt", "yuv420p",
|
|
||||||
output_path,
|
|
||||||
],
|
|
||||||
input=frames_np.tobytes(),
|
|
||||||
capture_output=True,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(f"[PrismAudio] ffmpeg failed:\n{result.stderr.decode()}")
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioFeatureExtractor:
|
class PrismAudioFeatureExtractor:
|
||||||
@@ -143,15 +128,15 @@ class PrismAudioFeatureExtractor:
|
|||||||
loader = PrismAudioFeatureLoader()
|
loader = PrismAudioFeatureLoader()
|
||||||
return loader.load_features(cached_path)
|
return loader.load_features(cached_path)
|
||||||
|
|
||||||
# Save video to temp file
|
# Save frames to temp file (lossless .npy, no codec roundtrip)
|
||||||
import time
|
import time
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
frames = video.shape[0]
|
frames = video.shape[0]
|
||||||
print(f"[PrismAudio] Converting {frames} frames to MP4 (fps={fps})...", flush=True)
|
print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
|
||||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
|
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
|
||||||
tmp_video = tmp.name
|
tmp_video = tmp.name
|
||||||
_save_video_tensor_to_mp4(video, tmp_video, fps=fps)
|
_save_frames_to_npy(video, tmp_video)
|
||||||
print(f"[PrismAudio] MP4 ready in {time.perf_counter() - t0:.1f}s ({tmp_video})", flush=True)
|
print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
||||||
|
|
||||||
# Build subprocess command
|
# Build subprocess command
|
||||||
script_path = os.path.join(
|
script_path = os.path.join(
|
||||||
@@ -165,6 +150,7 @@ class PrismAudioFeatureExtractor:
|
|||||||
"--video", tmp_video,
|
"--video", tmp_video,
|
||||||
"--cot_text", caption_cot,
|
"--cot_text", caption_cot,
|
||||||
"--output", cached_path,
|
"--output", cached_path,
|
||||||
|
"--source_fps", str(fps),
|
||||||
]
|
]
|
||||||
# Auto-resolve synchformer checkpoint from the prismaudio models dir
|
# Auto-resolve synchformer checkpoint from the prismaudio models dir
|
||||||
if not synchformer_ckpt:
|
if not synchformer_ckpt:
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ def main():
|
|||||||
parser.add_argument("--output", required=True, help="Output .npz path")
|
parser.add_argument("--output", required=True, help="Output .npz path")
|
||||||
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
||||||
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
||||||
|
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
|
||||||
parser.add_argument("--clip_fps", type=float, default=4.0)
|
parser.add_argument("--clip_fps", type=float, default=4.0)
|
||||||
parser.add_argument("--clip_size", type=int, default=288)
|
parser.add_argument("--clip_size", type=int, default=288)
|
||||||
parser.add_argument("--sync_fps", type=float, default=25.0)
|
parser.add_argument("--sync_fps", type=float, default=25.0)
|
||||||
@@ -64,7 +65,6 @@ def main():
|
|||||||
t0 = _step(1, 6, "importing dependencies")
|
t0 = _step(1, 6, "importing dependencies")
|
||||||
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from decord import VideoReader, cpu
|
|
||||||
_done(t0)
|
_done(t0)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -78,6 +78,24 @@ def main():
|
|||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
t0 = _step(3, 6, "reading and preprocessing video")
|
t0 = _step(3, 6, "reading and preprocessing video")
|
||||||
|
if args.video.endswith(".npy"):
|
||||||
|
all_frames = np.load(args.video) # [T, H, W, C] uint8
|
||||||
|
fps = args.source_fps
|
||||||
|
total_frames = all_frames.shape[0]
|
||||||
|
duration = total_frames / fps
|
||||||
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
||||||
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
|
clip_frames = all_frames[clip_indices]
|
||||||
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||||
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
|
sync_frames = all_frames[sync_indices]
|
||||||
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
|
else:
|
||||||
|
from decord import VideoReader, cpu
|
||||||
vr = VideoReader(args.video, ctx=cpu(0))
|
vr = VideoReader(args.video, ctx=cpu(0))
|
||||||
fps = vr.get_avg_fps()
|
fps = vr.get_avg_fps()
|
||||||
total_frames = len(vr)
|
total_frames = len(vr)
|
||||||
@@ -89,6 +107,11 @@ def main():
|
|||||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
||||||
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
|
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||||
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
|
|
||||||
clip_transform = T.Compose([
|
clip_transform = T.Compose([
|
||||||
T.ToPILImage(),
|
T.ToPILImage(),
|
||||||
T.Resize(args.clip_size),
|
T.Resize(args.clip_size),
|
||||||
@@ -98,11 +121,6 @@ def main():
|
|||||||
])
|
])
|
||||||
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
|
||||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
|
||||||
|
|
||||||
sync_transform = T.Compose([
|
sync_transform = T.Compose([
|
||||||
T.ToPILImage(),
|
T.ToPILImage(),
|
||||||
T.Resize(args.sync_size),
|
T.Resize(args.sync_size),
|
||||||
|
|||||||
Reference in New Issue
Block a user