diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py
index 0ff28d6..2db1526 100644
--- a/nodes/selva_feature_extractor.py
+++ b/nodes/selva_feature_extractor.py
@@ -34,11 +34,12 @@ def _resize_frames(frames, size):
     return x.clamp(0.0, 1.0)  # [N, C, H, W]
 
 
-def _hash_inputs(video_tensor, prompt, fps, variant):
+def _hash_inputs(video_tensor, prompt, fps, duration, variant):
     h = hashlib.sha256()
     h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])
     h.update(prompt.encode())
     h.update(str(fps).encode())
+    h.update(f"|{round(duration, 3)}|".encode())  # duration changes frame count; "|" keeps fps/duration digits apart
     h.update(variant.encode())
     return h.hexdigest()[:16]
 
@@ -86,7 +87,7 @@ class SelvaFeatureExtractor:
         if not cache_dir:
             cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
         os.makedirs(cache_dir, exist_ok=True)
-        cache_key = _hash_inputs(video, prompt, fps, model["variant"])
+        cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
         cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
 
         if os.path.exists(cached_path):
diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py
index 10dc5c0..3681a1d 100644
--- a/nodes/selva_sampler.py
+++ b/nodes/selva_sampler.py
@@ -93,12 +93,13 @@ class SelvaSampler:
             bs=1, negative_text_features=neg_text_clip
         )
 
-        # Initial noise
-        rng = torch.Generator(device=device).manual_seed(seed)
+        # Initial noise (MPS doesn't support torch.Generator on device)
+        gen_device = "cpu" if device.type == "mps" else device
+        rng = torch.Generator(device=gen_device).manual_seed(seed)
         x0 = torch.randn(
             1, seq_cfg.latent_seq_len, net_generator.latent_dim,
-            device=device, dtype=dtype, generator=rng,
-        )
+            device=gen_device, dtype=dtype, generator=rng,
+        ).to(device)
 
         # Flow matching ODE (Euler)
         fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)