fix: three bugs in OmniVoiceMixVoices

- _resample: squeeze batch dim before torchaudio.Resample (expected 2D)
- weight scaling: each clip now trims to natural_length*weight samples,
  dropping the broken target_per_unit double-multiplication
- empty trimmed guard: raise clear error when all weights are 0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 19:04:54 +02:00
parent c7c7123068
commit 219c74d7ed
+7 -12
View File
@@ -14,10 +14,10 @@ def _resample(waveform, src_sr, dst_sr):
return waveform return waveform
try: try:
import torchaudio import torchaudio
# Resample expects (channels, samples), not (batch, channels, samples)
resampler = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=dst_sr) resampler = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=dst_sr)
return resampler(waveform) return resampler(waveform.squeeze(0)).unsqueeze(0)
except Exception: except Exception:
# fallback: nearest-neighbour via interpolate
ratio = dst_sr / src_sr ratio = dst_sr / src_sr
new_len = int(waveform.shape[-1] * ratio) new_len = int(waveform.shape[-1] * ratio)
return torch.nn.functional.interpolate( return torch.nn.functional.interpolate(
@@ -94,30 +94,25 @@ class OmniVoiceMixVoices:
for audio, weight in zip(audios, weights): for audio, weight in zip(audios, weights):
w = _to_mono(audio["waveform"]) # (1, 1, samples) w = _to_mono(audio["waveform"]) # (1, 1, samples)
w = _resample(w, audio["sample_rate"], target_sr) w = _resample(w, audio["sample_rate"], target_sr)
# trim/repeat to match requested weight in seconds (normalise later)
clips.append((w, weight)) clips.append((w, weight))
# Determine target samples per unit weight # Each clip contributes (natural_length * weight) samples.
# Scale each clip so that weight=1.0 keeps its full length,
# and trim/tile accordingly relative to the largest weighted clip.
max_samples = max(c.shape[-1] * wt for c, wt in clips)
target_per_unit = max_samples # samples for weight=1.0
trimmed = [] trimmed = []
for clip, weight in clips: for clip, weight in clips:
n_samples = int(target_per_unit * weight) n_samples = int(clip.shape[-1] * weight)
if n_samples <= 0: if n_samples <= 0:
continue continue
src_len = clip.shape[-1] src_len = clip.shape[-1]
if src_len >= n_samples: if src_len >= n_samples:
trimmed.append(clip[..., :n_samples]) trimmed.append(clip[..., :n_samples])
else: else:
# tile then trim
reps = (n_samples // src_len) + 1 reps = (n_samples // src_len) + 1
tiled = clip.repeat(1, 1, reps) tiled = clip.repeat(1, 1, reps)
trimmed.append(tiled[..., :n_samples]) trimmed.append(tiled[..., :n_samples])
if not trimmed:
raise ValueError("OmniVoice Mix Voices: all weights are 0 — nothing to mix.")
mixed = torch.cat(trimmed, dim=-1) # (1, 1, total_samples) mixed = torch.cat(trimmed, dim=-1) # (1, 1, total_samples)
merged_text = " ".join(t.strip() for t in texts if t.strip()) merged_text = " ".join(t.strip() for t in texts if t.strip())