feat: add torch.compile option to Model Loader
Compiles the model graph on first generation (~30-60s warmup), then speeds up all subsequent generations in the session. Recommended for audiobook pipelines. Default off.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+15
-1
@@ -33,6 +33,18 @@ class OmniVoiceModelLoader:
|
||||
["float16", "bfloat16", "float32"],
|
||||
{"default": "float16"},
|
||||
),
|
||||
"compile": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": (
|
||||
"Run torch.compile() on the model after loading. "
|
||||
"First generation will be slow (~30-60s warmup) while the graph is compiled, "
|
||||
"then every subsequent generation in the session will be faster. "
|
||||
"Recommended for audiobook pipelines. Requires PyTorch 2.0+."
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -41,7 +53,7 @@ class OmniVoiceModelLoader:
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = "OmniVoice"
|
||||
|
||||
def load_model(self, device, dtype):
|
||||
def load_model(self, device, dtype, compile=False):
|
||||
if OmniVoice is None:
|
||||
raise ImportError(
|
||||
"omnivoice is not installed. Run: pip install omnivoice --no-deps"
|
||||
@@ -53,4 +65,6 @@ class OmniVoiceModelLoader:
|
||||
dtype=DTYPE_MAP[dtype],
|
||||
cache_dir=CACHE_DIR,
|
||||
)
|
||||
if compile:
|
||||
model = torch.compile(model)
|
||||
return (model,)
|
||||
|
||||
Reference in New Issue
Block a user