feat: extract prismaudio_core inference with callback-enabled sampling

Add inference subpackage with:
- sampling.py: sample_discrete_euler modified from upstream to add callback
  parameter for ComfyUI progress bars (uses enumerate for step index)
- utils.py: set_audio_channels and prepare_audio for audio preprocessing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:59:37 +01:00
parent 30e85f0f99
commit 87bea21d49
3 changed files with 96 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
# Re-export the sampler and the audio preprocessing helpers so callers can
# import them directly from this package instead of from its submodules.
from .sampling import sample_discrete_euler
from .utils import set_audio_channels, prepare_audio
# Names exposed by `from <package> import *`.
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
+30
View File
@@ -0,0 +1,30 @@
import torch
from tqdm import trange
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Integrate a rectified-flow model with fixed-size Euler steps.

    Adapted from the PrismAudio sampler: the upstream version reports
    progress through tqdm, whereas this one invokes an optional callback
    after every step so the caller (e.g. a ComfyUI progress bar) can track it.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``; its output
            is scaled by the step size and added to ``x``.
        x: Starting tensor; dimension 0 is used as the batch size.
        steps: Number of Euler steps to take.
        sigma_max: Time value the schedule starts from; the schedule runs
            linearly from ``sigma_max`` down to 0.
        callback: Optional callable receiving ``{"i": step_index, "x": x}``
            after each update.
        **extra_args: Forwarded verbatim to every ``model`` call.

    Returns:
        The tensor obtained after the final step (``x`` itself when the
        schedule contains no steps).
    """
    # steps + 1 evenly spaced time points => `steps` intervals to integrate over.
    schedule = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)

    for step_idx in range(len(schedule) - 1):
        sigma_now = schedule[step_idx]
        step_size = schedule[step_idx + 1] - sigma_now

        # One timestep entry per batch element; built from the *current* x so
        # dtype/device follow x even if the update changed them.
        batch_sigma = sigma_now * x.new_ones(x.shape[0])

        velocity = model(x, batch_sigma, **extra_args)
        # Out-of-place update: never mutate the caller's tensor or a tensor
        # already handed to the callback.
        x = x + step_size * velocity

        if callback is not None:
            callback({"i": step_idx, "x": x})

    return x
+62
View File
@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from torchaudio import transforms as T
def set_audio_channels(audio, target_channels):
    """Return ``audio`` adjusted to ``target_channels`` channels.

    Args:
        audio: Tensor whose dimension 1 is the channel axis ([B, C, T]).
        target_channels: 1 to downmix to mono, 2 to produce stereo. Any
            other value leaves the tensor untouched.

    Returns:
        The converted tensor, or ``audio`` itself when no conversion applies.
    """
    if target_channels == 1:
        # Mono: average across the channel axis, keeping it as a size-1 dim.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        # Only mono and stereo targets are handled; pass everything else through.
        return audio

    channel_count = audio.shape[1]
    if channel_count == 1:
        # Mono -> stereo by duplicating the single channel.
        return audio.repeat(1, 2, 1)
    if channel_count > 2:
        # Multichannel -> stereo by keeping the first two channels.
        return audio[:, :2, :]
    return audio
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move, resample, batch, length-fit and channel-convert an audio tensor.

    Args:
        audio: Audio tensor, either 1D, 2D [C, T] or 3D [B, C, T].
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Number of samples the last dimension should end up with.
        target_channels: Channel count handed to ``set_audio_channels``.
        device: Torch device the audio (and the resampler) are moved to.

    Returns:
        Tensor on ``device`` with leading batch/channel dims added as needed,
        its last dimension zero-padded on the right or cropped to
        ``target_length``, and channels converted by ``set_audio_channels``.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        # Resampling happens before the batch dims are added.
        audio = T.Resample(in_sr, target_sr).to(device)(audio)

    # Promote to [B, C, T]: a 1D signal gains batch and channel dims,
    # a 2D [C, T] signal gains only the batch dim.
    rank = audio.dim()
    if rank == 1:
        audio = audio[None, None, :]
    elif rank == 2:
        audio = audio[None]

    # Positive => too short (pad at the end); negative => too long (crop).
    missing = target_length - audio.shape[-1]
    if missing > 0:
        audio = F.pad(audio, (0, missing))
    elif missing < 0:
        audio = audio[:, :, :target_length]

    return set_audio_channels(audio, target_channels)