fix: add input validation and cpu() guard in OmniVoiceGenerate
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+6
-1
@@ -33,12 +33,17 @@ class OmniVoiceGenerate:
|
||||
def generate(self, model, text, mode, ref_audio=None, ref_text="", instruct="", speed=1.0, num_step=32):
|
||||
kwargs = {"text": text, "speed": speed, "num_step": num_step}
|
||||
|
||||
if mode == "voice_cloning" and ref_audio is None:
|
||||
raise ValueError("voice_cloning mode requires ref_audio to be connected")
|
||||
if mode == "voice_design" and not instruct:
|
||||
raise ValueError("voice_design mode requires an instruct string (e.g. 'female, low pitch')")
|
||||
|
||||
if mode == "voice_cloning" and ref_audio is not None:
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp_path = tmp.name
|
||||
tmp.close()
|
||||
try:
|
||||
waveform = ref_audio["waveform"].squeeze(0) # (channels, samples)
|
||||
waveform = ref_audio["waveform"].squeeze(0).cpu() # (channels, samples)
|
||||
torchaudio.save(tmp_path, waveform, ref_audio["sample_rate"])
|
||||
kwargs["ref_audio"] = tmp_path
|
||||
if ref_text:
|
||||
|
||||
Reference in New Issue
Block a user