8b634923dd
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
30 lines
1.1 KiB
Python
30 lines
1.1 KiB
Python
import torch
|
|
|
|
|
|
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add a callback parameter for ComfyUI progress
    reporting. The original uses tqdm internally.

    Integrates ``x`` with ``steps`` uniform Euler steps along a timestep
    schedule running linearly from ``sigma_max`` down to ``0``, using
    ``model(x, t)`` as the velocity at each step.

    Args:
        model: The diffusion model (DiTWrapper). Called as
            ``model(x, t, **extra_args)`` where ``t`` is a 1-D tensor of
            length ``x.shape[0]`` holding the current timestep in ``x.dtype``.
        x: Initial noise tensor, presumably [B, C, T] (only the leading batch
            dimension is relied on here).
        steps: Number of sampling steps. With ``0`` the loop body never runs
            and ``x`` is returned unchanged.
        sigma_max: Starting timestep (default 1 for rectified flow).
        callback: Optional callable invoked after every step as
            ``callback({"i": step, "x": current_x})`` for progress reporting.
        **extra_args: Passed to model() — includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.

    Returns:
        The sample after the last step (timestep 0).
    """
    # Build the schedule in at least float32. Creating it directly in a
    # half-precision dtype (fp16/bf16) rounds every timestep, which makes the
    # step sizes dt non-uniform. Values are cast to x.dtype only where the
    # model consumes them; float32/float64 inputs are unaffected.
    schedule_dtype = torch.promote_types(x.dtype, torch.float32)
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=schedule_dtype)

    batch_size = x.shape[0]
    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        # The schedule decreases towards 0, so dt is negative for sigma_max > 0.
        # dt stays a 0-d tensor: multiplying it by the model output keeps the
        # output's dtype under PyTorch's type-promotion rules.
        dt = t_next - t_curr
        # One copy of the current timestep per batch element, in x's dtype.
        t_curr_tensor = t_curr.to(x.dtype).repeat(batch_size)
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})

    return x
|