diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index c802026..995adf9 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -65,8 +65,8 @@ class SelvaFeatureExtractor: }, } - RETURN_TYPES = ("SELVA_FEATURES", "FLOAT") - RETURN_NAMES = ("features", "fps") + RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING") + RETURN_NAMES = ("features", "fps", "prompt") FUNCTION = "extract_features" CATEGORY = PRISMAUDIO_CATEGORY @@ -92,7 +92,8 @@ class SelvaFeatureExtractor: if os.path.exists(cached_path): print(f"[SelVA] Using cached features: {cached_path}", flush=True) - return (_load_cached(cached_path), float(fps)) + cached = _load_cached(cached_path) + return (cached, float(fps), cached.get("prompt", prompt)) device = get_device() dtype = model["dtype"] @@ -159,7 +160,7 @@ class SelvaFeatureExtractor: "sync_features": sync_features.cpu(), "duration": float(duration), "prompt": prompt, - }, float(fps)) + }, float(fps), prompt) def _load_cached(path):