diff --git a/nodes/loader.py b/nodes/loader.py index 80fefe5..43b1d7f 100644 --- a/nodes/loader.py +++ b/nodes/loader.py @@ -33,6 +33,18 @@ class OmniVoiceModelLoader: ["float16", "bfloat16", "float32"], {"default": "float16"}, ), + "compile": ( + "BOOLEAN", + { + "default": False, + "tooltip": ( + "Run torch.compile() on the model after loading. " + "First generation will be slow (~30-60s warmup) while the graph is compiled, " + "then every subsequent generation in the session will be faster. " + "Recommended for audiobook pipelines. Requires PyTorch 2.0+." + ), + }, + ), }, } @@ -41,7 +53,7 @@ class OmniVoiceModelLoader: FUNCTION = "load_model" CATEGORY = "OmniVoice" - def load_model(self, device, dtype): + def load_model(self, device, dtype, compile=False): if OmniVoice is None: raise ImportError( "omnivoice is not installed. Run: pip install omnivoice --no-deps" @@ -53,4 +65,6 @@ class OmniVoiceModelLoader: dtype=DTYPE_MAP[dtype], cache_dir=CACHE_DIR, ) + if compile: + model = torch.compile(model) return (model,) diff --git a/workflows/omnivoice_voice_cloning.json b/workflows/omnivoice_voice_cloning.json index e2be3c7..35b56cc 100644 --- a/workflows/omnivoice_voice_cloning.json +++ b/workflows/omnivoice_voice_cloning.json @@ -6,7 +6,7 @@ "id": 1, "type": "OmniVoiceModelLoader", "pos": [40, 80], - "size": {"0": 300, "1": 100}, + "size": {"0": 300, "1": 130}, "flags": {}, "order": 0, "mode": 0, @@ -14,7 +14,7 @@ {"name": "model", "type": "OMNIVOICE_MODEL", "links": [1], "shape": 3, "slot_index": 0} ], "properties": {"Node name for S&R": "OmniVoiceModelLoader"}, - "widgets_values": ["cuda:0", "float16"] + "widgets_values": ["cuda:0", "float16", false] }, { "id": 2,