fix: two bugs in SelVA nodes
- selva_feature_extractor: cache hash now includes resolved duration; same video + different duration override no longer returns stale features - selva_sampler: MPS-safe noise generation (torch.Generator on CPU then move to device, same pattern as PrismAudioSampler) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -34,11 +34,12 @@ def _resize_frames(frames, size):
|
|||||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, prompt, fps, variant):
|
def _hash_inputs(video_tensor, prompt, fps, duration, variant):
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])
|
||||||
h.update(prompt.encode())
|
h.update(prompt.encode())
|
||||||
h.update(str(fps).encode())
|
h.update(str(fps).encode())
|
||||||
|
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||||
h.update(variant.encode())
|
h.update(variant.encode())
|
||||||
return h.hexdigest()[:16]
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
@@ -86,7 +87,7 @@ class SelvaFeatureExtractor:
|
|||||||
if not cache_dir:
|
if not cache_dir:
|
||||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
cache_key = _hash_inputs(video, prompt, fps, model["variant"])
|
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||||
|
|
||||||
if os.path.exists(cached_path):
|
if os.path.exists(cached_path):
|
||||||
|
|||||||
@@ -93,12 +93,13 @@ class SelvaSampler:
|
|||||||
bs=1, negative_text_features=neg_text_clip
|
bs=1, negative_text_features=neg_text_clip
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initial noise
|
# Initial noise (MPS doesn't support torch.Generator on device)
|
||||||
rng = torch.Generator(device=device).manual_seed(seed)
|
gen_device = "cpu" if device.type == "mps" else device
|
||||||
|
rng = torch.Generator(device=gen_device).manual_seed(seed)
|
||||||
x0 = torch.randn(
|
x0 = torch.randn(
|
||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
device=device, dtype=dtype, generator=rng,
|
device=gen_device, dtype=dtype, generator=rng,
|
||||||
)
|
).to(device)
|
||||||
|
|
||||||
# Flow matching ODE (Euler)
|
# Flow matching ODE (Euler)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user