diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index 5d68bdc..c4a9f4a 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -158,50 +158,52 @@ class SelvaFeatureExtractor: print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) pbar = comfy.utils.ProgressBar(3) - with torch.no_grad(): - # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- - clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] - clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] - if mask is not None: - clip_frames = _apply_mask(clip_frames, mask) - clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] - print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True) + try: + with torch.no_grad(): + # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- + clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] + clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] + if mask is not None: + clip_frames = _apply_mask(clip_frames, mask) + clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] + print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True) - clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] - pbar.update(1) + clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] + pbar.update(1) - # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- - sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] - sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] - if mask is not None: - sync_frames = _apply_mask(sync_frames, mask) - # Pad to minimum 16 frames (TextSynchformer segment size) - if sync_frames.shape[0] < 16: - pad = 16 - sync_frames.shape[0] - sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0) - # Normalize [0,1] → [-1,1] - mean = _SYNC_MEAN.to(sync_frames.device) - std = _SYNC_STD.to(sync_frames.device) - sync_frames = (sync_frames - mean) / std - sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224] - print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True) + # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- + sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] + sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] + if mask is not None: + sync_frames = _apply_mask(sync_frames, mask) + # Pad to minimum 16 frames (TextSynchformer segment size) + if sync_frames.shape[0] < 16: + pad = 16 - sync_frames.shape[0] + sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0) + # Normalize [0,1] → [-1,1] + mean = _SYNC_MEAN.to(sync_frames.device) + std = _SYNC_STD.to(sync_frames.device) + sync_frames = (sync_frames - mean) / std + sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224] + print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True) - # Encode T5 text + prepend supplementary tokens → text-conditioned sync features - text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L] - pbar.update(1) - text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask) - sync_features = net_video_enc.encode_video_with_sync( - sync_input, text_f=text_f, text_mask=text_mask - ) # [1, T_sync, 768] - pbar.update(1) + # Encode T5 text + prepend supplementary tokens → text-conditioned sync features + text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L] + pbar.update(1) + text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask) + sync_features = net_video_enc.encode_video_with_sync( + sync_input, text_f=text_f, text_mask=text_mask + ) # [1, T_sync, 768] + pbar.update(1) - print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True) - print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True) + print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True) + print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True) - if strategy == "offload_to_cpu": - feature_utils.to(get_offload_device()) - net_video_enc.to(get_offload_device()) - soft_empty_cache() + finally: + if strategy == "offload_to_cpu": + feature_utils.to(get_offload_device()) + net_video_enc.to(get_offload_device()) + soft_empty_cache() np.savez( cached_path, diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index 73101a0..249ccc4 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -54,7 +54,6 @@ class SelvaSampler: strategy = model["strategy"] net_generator = model["generator"] feature_utils = model["feature_utils"] - mode = model["mode"] # Validate that features were extracted with the same model variant feat_variant = features.get("variant") @@ -88,76 +87,78 @@ class SelvaSampler: feature_utils.to(device) soft_empty_cache() - clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024] - sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768] + try: + clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024] + sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768] - print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True) + print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True) - # Update model rotary position embeddings for actual feature shapes and duration. - # Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches. - net_generator.update_seq_lengths( - latent_seq_len=seq_cfg.latent_seq_len, - clip_seq_len=clip_f.shape[1], - sync_seq_len=sync_f.shape[1], - ) - print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True) - - with torch.no_grad(): - # Encode text conditioning - text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D] - - # Encode negative prompt (or use empty conditions) - neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ - if negative_prompt.strip() else None - - conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) - empty_conditions = net_generator.get_empty_conditions( - bs=1, negative_text_features=neg_text_clip + # Update model rotary position embeddings for actual feature shapes and duration. + # Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches. + net_generator.update_seq_lengths( + latent_seq_len=seq_cfg.latent_seq_len, + clip_seq_len=clip_f.shape[1], + sync_seq_len=sync_f.shape[1], ) + print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True) - # Initial noise (MPS doesn't support torch.Generator on device) - gen_device = "cpu" if device.type == "mps" else device - rng = torch.Generator(device=gen_device).manual_seed(seed) - x0 = torch.randn( - 1, seq_cfg.latent_seq_len, net_generator.latent_dim, - device=gen_device, dtype=dtype, generator=rng, - ).to(device) + with torch.no_grad(): + # Encode text conditioning + text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D] - # Flow matching ODE (Euler) - fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) - pbar = comfy.utils.ProgressBar(steps) + # Encode negative prompt (or use empty conditions) + neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ + if negative_prompt.strip() else None - def ode_wrapper_tracked(t, x): - comfy.model_management.throw_exception_if_processing_interrupted() - pbar.update(1) - return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) + conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) + empty_conditions = net_generator.get_empty_conditions( + bs=1, negative_text_features=neg_text_clip + ) + # Initial noise (MPS doesn't support torch.Generator on device) + gen_device = "cpu" if device.type == "mps" else device + rng = torch.Generator(device=gen_device).manual_seed(seed) + x0 = torch.randn( + 1, seq_cfg.latent_seq_len, net_generator.latent_dim, + device=gen_device, dtype=dtype, generator=rng, + ).to(device) + + # Flow matching ODE (Euler) + fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) + pbar = comfy.utils.ProgressBar(steps) + + def ode_wrapper_tracked(t, x): + comfy.model_management.throw_exception_if_processing_interrupted() + pbar.update(1) + return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) + + try: + x1 = fm.to_data(ode_wrapper_tracked, x0) + except torch.cuda.OutOfMemoryError: + raise RuntimeError( + "[SelVA] CUDA out of memory during generation. Try switching offload_strategy " + "to 'offload_to_cpu', using a smaller variant, or reducing duration." + ) + + print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) + + # Decode: latent → mel → audio try: - x1 = fm.to_data(ode_wrapper_tracked, x0) + with torch.no_grad(): + x1_unnorm = net_generator.unnormalize(x1) + spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram + audio = feature_utils.vocode(spec) # mel → waveform except torch.cuda.OutOfMemoryError: raise RuntimeError( - "[SelVA] CUDA out of memory during generation. Try switching offload_strategy " + "[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy " "to 'offload_to_cpu', using a smaller variant, or reducing duration." ) - print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) - - # Decode: latent → mel → audio - try: - with torch.no_grad(): - x1_unnorm = net_generator.unnormalize(x1) - spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram - audio = feature_utils.vocode(spec) # mel → waveform - except torch.cuda.OutOfMemoryError: - raise RuntimeError( - "[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy " - "to 'offload_to_cpu', using a smaller variant, or reducing duration." - ) - - if strategy == "offload_to_cpu": - net_generator.to(get_offload_device()) - feature_utils.to(get_offload_device()) - soft_empty_cache() + finally: + if strategy == "offload_to_cpu": + net_generator.to(get_offload_device()) + feature_utils.to(get_offload_device()) + soft_empty_cache() # Ensure [1, 1, samples] and normalize to [-1,1] audio = audio.float()