feat: extract prismaudio_core inference with callback-enabled sampling
Add inference subpackage with:
- sampling.py: sample_discrete_euler, modified from upstream to add a callback parameter for ComfyUI progress bars (uses enumerate for the step index)
- utils.py: set_audio_channels and prepare_audio for audio preprocessing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,4 @@
|
|||||||
|
from .sampling import sample_discrete_euler
|
||||||
|
from .utils import set_audio_channels, prepare_audio
|
||||||
|
|
||||||
|
# Public API of this package: the names imported above from the
# .sampling and .utils submodules.
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Run a fixed-step Euler integration of a rectified-flow model.

    Modified from PrismAudio so that a callback can be supplied for ComfyUI
    progress reporting (the original relies on tqdm internally).

    The time schedule is ``steps + 1`` evenly spaced values running from
    ``sigma_max`` down to 0. At every step the model is evaluated at the
    current time and ``x`` is advanced by ``dt * model(...)``, where ``dt``
    is the (negative) gap to the next schedule value.

    Args:
        model: The diffusion model (DiTWrapper), called as
            ``model(x, t, **extra_args)`` with ``t`` holding one time value
            per entry along the first dimension of ``x``.
        x: Initial noise tensor [B, C, T].
        steps: Number of sampling steps.
        sigma_max: Maximum sigma (default 1.0 for rectified flow).
        callback: Optional callable invoked after every step with
            ``{"i": step_index, "x": current_x}``.
        **extra_args: Forwarded unchanged to ``model`` — includes
            cross_attn_cond, add_cond, sync_cond, cfg_scale, batch_cfg, etc.

    Returns:
        The tensor obtained after the final Euler step.
    """
    schedule = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
    # Element-wise gaps between consecutive schedule points (next - current).
    step_sizes = schedule[1:] - schedule[:-1]

    for index, (sigma, delta) in enumerate(zip(schedule[:-1], step_sizes)):
        # Broadcast the scalar time to one entry per batch element; built from
        # the current x so dtype/device always follow the running sample.
        sigma_batch = x.new_ones(x.shape[0]) * sigma
        velocity = model(x, sigma_batch, **extra_args)
        x = x + delta * velocity

        if callback is None:
            continue
        callback({"i": index, "x": x})

    return x
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
|
||||||
|
|
||||||
|
def set_audio_channels(audio, target_channels):
    """Return ``audio`` adjusted to ``target_channels`` channels.

    Args:
        audio: Audio tensor of shape [B, C, T].
        target_channels: Desired number of channels (1 for mono, 2 for stereo).

    Returns:
        For mono, the mean over the channel dimension (kept as size 1).
        For stereo, a mono input is duplicated and an input with more than
        two channels is cut down to its first two. In every other case the
        input tensor is returned untouched.
    """
    if target_channels == 1:
        # Downmix: average every channel into a single one.
        return audio.mean(1, keepdim=True)

    if target_channels == 2:
        channel_count = audio.shape[1]
        if channel_count == 1:
            # Upmix: copy the single channel to both sides.
            return audio.repeat(1, 2, 1)
        if channel_count > 2:
            # Keep only the first two channels.
            return audio[:, :2, :]

    return audio
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move audio to a device, then resample, batch, fit length and fix channels.

    Steps, in order: transfer to ``device``; resample when the input and
    target rates differ; promote 1D/2D input to 3D by adding leading
    dimensions; zero-pad on the right or truncate the last dimension to
    ``target_length``; convert the channel count via ``set_audio_channels``.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]).
        in_sr: Input sample rate.
        target_sr: Target sample rate.
        target_length: Target length in samples (padded or cropped).
        target_channels: Target number of channels.
        device: Torch device to place the audio on.

    Returns:
        Audio tensor of shape [B, target_channels, target_length] on device.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resampler = T.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    # Promote to three dimensions by prepending singleton axes.
    rank = audio.dim()
    if rank == 1:
        audio = audio[None, None]
    elif rank == 2:
        audio = audio[None]

    # Positive: samples missing (pad right); negative: too long (crop).
    missing = target_length - audio.shape[-1]
    if missing > 0:
        audio = F.pad(audio, (0, missing))
    elif missing < 0:
        audio = audio[:, :, :target_length]

    return set_audio_channels(audio, target_channels)
|
||||||
Reference in New Issue
Block a user