Files
Ethanfel 87bea21d49 feat: extract prismaudio_core inference with callback-enabled sampling
Add inference subpackage with:
- sampling.py: sample_discrete_euler modified from upstream to add callback
  parameter for ComfyUI progress bars (uses enumerate for step index)
- utils.py: set_audio_channels and prepare_audio for audio preprocessing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-27 17:59:37 +01:00

63 lines
1.9 KiB
Python

import torch
import torch.nn.functional as F
from torchaudio import transforms as T
def set_audio_channels(audio, target_channels):
    """Return ``audio`` with its channel axis adjusted to ``target_channels``.

    Args:
        audio: Audio tensor laid out as [B, C, T].
        target_channels: Requested channel count. ``1`` averages every
            channel into a single one; ``2`` duplicates a mono signal or
            keeps the first two channels of a wider signal. Any other value
            leaves ``audio`` untouched.

    Returns:
        The converted tensor, or the input tensor itself when no conversion
        applies.
    """
    if target_channels == 1:
        # Downmix: average over the channel axis, keeping it as a size-1 dim.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        # Only mono and stereo targets are converted; pass everything else through.
        return audio

    channel_count = audio.shape[1]
    if channel_count == 1:
        # Mono -> stereo: copy the single channel into both slots.
        return audio.repeat(1, 2, 1)
    if channel_count > 2:
        # Multichannel -> stereo: keep the first two channels, drop the rest.
        return audio[:, :2, :]
    return audio
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Bring an audio tensor to a target device, rate, length and channel count.

    The steps run in this order: move to ``device``, resample if the rates
    differ, promote 1D/2D input to 3D, right-pad with zeros or crop the last
    axis to ``target_length``, then convert the channel count.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]).
        in_sr: Sample rate of the incoming audio.
        target_sr: Sample rate the output should have.
        target_length: Length of the last axis, in samples, after padding or
            cropping.
        target_channels: Channel count passed on to ``set_audio_channels``.
        device: Torch device the audio (and the resampler) are placed on.

    Returns:
        Audio tensor on ``device``; for 1D-3D input and a mono/stereo target
        its shape is [B, target_channels, target_length].
    """
    waveform = audio.to(device)

    if in_sr != target_sr:
        # The resampler acts on the last axis, so it runs before any reshaping.
        resampler = T.Resample(in_sr, target_sr).to(device)
        waveform = resampler(waveform)

    # Promote to [B, C, T] by prepending singleton axes where they are missing.
    rank = waveform.dim()
    if rank == 1:
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        waveform = waveform.unsqueeze(0)

    # Positive -> too short (pad on the right); negative -> too long (crop).
    shortfall = target_length - waveform.shape[-1]
    if shortfall > 0:
        waveform = F.pad(waveform, (0, shortfall))
    elif shortfall < 0:
        waveform = waveform[:, :, :target_length]

    return set_audio_channels(waveform, target_channels)