"""Inference helpers: rectified-flow sampling and audio preprocessing.

Reconstructed from a whitespace-collapsed patch that introduced three new
files under ``prismaudio_core/inference/`` (``__init__.py``, ``sampling.py``
and ``utils.py``).  The banner comments below record which file each
definition originally came from.
"""

import torch
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# prismaudio_core/inference/__init__.py
# ---------------------------------------------------------------------------
# The package initialiser re-exported exactly these three names from
# ``.sampling`` and ``.utils``.
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]


# ---------------------------------------------------------------------------
# prismaudio_core/inference/sampling.py
# ---------------------------------------------------------------------------
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add a callback parameter for ComfyUI progress
    reporting (the original uses tqdm internally).

    The time axis runs from ``sigma_max`` down to 0 in ``steps`` uniform
    intervals, and each step applies
    ``x <- x + (t_next - t_curr) * model(x, t_curr, **extra_args)``.

    Args:
        model: The diffusion model (DiTWrapper). Called as
            ``model(x, t, **extra_args)`` where ``t`` is a 1-D tensor holding
            the current time once per batch element.
        x: Initial noise tensor [B, C, T].
        steps: Number of sampling steps. With ``steps == 0`` the loop does
            not run and ``x`` is returned unchanged.
        sigma_max: Maximum sigma (default 1 for rectified flow).
        callback: Optional callable invoked after every step with
            ``{"i": step_index, "x": current_x}``.
        **extra_args: Passed through to ``model()`` -- includes
            cross_attn_cond, add_cond, sync_cond, cfg_scale, batch_cfg, etc.

    Returns:
        The tensor obtained after the final Euler step.
    """
    # steps + 1 grid points delimit `steps` intervals.
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)

    # Loop-invariant: broadcasts the scalar time to one entry per batch item.
    batch_ones = torch.ones(x.shape[0], dtype=x.dtype, device=x.device)

    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr  # negative: time runs from sigma_max towards 0
        x = x + dt * model(x, t_curr * batch_ones, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})

    return x


# ---------------------------------------------------------------------------
# prismaudio_core/inference/utils.py
# ---------------------------------------------------------------------------
def set_audio_channels(audio, target_channels):
    """Convert an audio tensor to the target number of channels.

    Args:
        audio: Audio tensor of shape [B, C, T].
        target_channels: Desired number of channels (1 for mono, 2 for
            stereo).

    Returns:
        For ``target_channels == 1``: the mean over the channel axis, shape
        [B, 1, T].  For ``target_channels == 2``: a mono input is duplicated
        to two channels, an input with more than two channels keeps its first
        two, and a two-channel input is returned as-is.  Any other
        ``target_channels`` value returns ``audio`` unchanged.
    """
    if target_channels == 1:
        # Downmix to mono by averaging across channels.
        return audio.mean(1, keepdim=True)

    if target_channels == 2:
        channels = audio.shape[1]
        if channels == 1:
            # Duplicate the single channel to obtain stereo.
            return audio.repeat(1, 2, 1)
        if channels > 2:
            # Keep only the first two channels.
            return audio[:, :2, :]

    return audio


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, pad/trim, and convert channels of an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]).
        in_sr: Input sample rate.
        target_sr: Target sample rate.
        target_length: Target length in samples (zero-padded on the right or
            cropped from the right).
        target_channels: Target number of channels, forwarded to
            ``set_audio_channels``.
        device: Torch device to place the audio on.

    Returns:
        Audio tensor of shape [B, target_channels, target_length] on device
        (for the channel counts ``set_audio_channels`` handles).
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        # Imported here rather than at module level so the module can be
        # imported, and the same-rate path used, without torchaudio.
        from torchaudio import transforms as T

        resampler = T.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    # Promote to [B, C, T]: a 1-D signal gains batch and channel axes, a 2-D
    # [C, T] signal gains only the batch axis.
    if audio.dim() == 1:
        audio = audio[None, None, :]
    elif audio.dim() == 2:
        audio = audio[None]

    # Pad with trailing zeros or crop so the last axis equals target_length.
    current_length = audio.shape[-1]
    if current_length < target_length:
        audio = F.pad(audio, (0, target_length - current_length))
    elif current_length > target_length:
        audio = audio[..., :target_length]

    return set_audio_channels(audio, target_channels)