87bea21d49
Add inference subpackage with: - sampling.py: sample_discrete_euler modified from upstream to add callback parameter for ComfyUI progress bars (uses enumerate for step index) - utils.py: set_audio_channels and prepare_audio for audio preprocessing Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torchaudio import transforms as T
|
|
|
|
|
|
def set_audio_channels(audio, target_channels):
    """Convert audio tensor to target number of channels.

    Args:
        audio: Audio tensor of shape [B, C, T]
        target_channels: Desired number of channels (1 for mono, 2 for stereo)

    Returns:
        Audio tensor with the target number of channels. For any
        ``target_channels`` other than 1 or 2 the input is returned
        unchanged, as is input that is already stereo when stereo is
        requested.
    """
    if target_channels == 1:
        # Convert to mono: average over the channel axis, keeping it as size 1
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        # No conversion is defined for other channel counts
        return audio

    # Convert to stereo
    channel_count = audio.shape[1]
    if channel_count == 1:
        # Duplicate the single channel
        return audio.repeat(1, 2, 1)
    if channel_count > 2:
        # Keep only the first two channels
        return audio[:, :2, :]
    return audio
|
|
|
|
|
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, pad/trim, and convert channels of an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T])
        in_sr: Input sample rate
        target_sr: Target sample rate
        target_length: Target length in samples (padded or cropped)
        target_channels: Target number of channels
        device: Torch device to place the audio on

    Returns:
        Audio tensor on device with a leading batch dimension and a last
        dimension of target_length. The channel count is whatever
        set_audio_channels produces for target_channels.
    """
    audio = audio.to(device)

    # Resample only when the rates differ; this runs before the batch
    # dimension is added, so it operates on the tensor as passed in.
    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    # Add batch dimension
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    # Pad or crop to target_length
    current_length = audio.shape[-1]
    if current_length < target_length:
        # Zero-pad on the right of the last (time) dimension
        audio = F.pad(audio, (0, target_length - current_length))
    elif current_length > target_length:
        audio = audio[:, :, :target_length]

    audio = set_audio_channels(audio, target_channels)

    return audio
|