From 211494a91c79a6c903c582d7b7e2cf28e2997267 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 12:10:02 +0200 Subject: [PATCH] fix: DITTO gradient never reached x0, remove unused imports and dead code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DITTO critical bug: x was reassigned on every ODE step, so by the time loss.backward() ran, x pointed to the final output tensor (a non-leaf tensor with a grad_fn) and x.grad was always None. The manual gradient transfer never fired — x0 was never updated. The optimization was a no-op. Fix: use a straight-through estimator after the no-grad prefix: x = x + (x0 - x0.detach()). This adds zero to the value but creates a grad_fn back to x0, so backward() propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. This is equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I. Also remove unused imports (SelvaSampler, _inject_tokens, random) that risked cascading ImportErrors, and remove the dead trainable_count variable in the BigVGAN trainer. 
Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_bigvgan_trainer.py | 5 ----- nodes/selva_ditto_optimizer.py | 40 +++++++++------------------------- 2 files changed, 10 insertions(+), 35 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 0bc2a11..3aa2d67 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -404,11 +404,6 @@ class SelvaBigvganTrainer: f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)" ) - trainable_count = sum( - 1 for n, _ in vocoder.named_parameters() if "alpha" in n - ) if train_mode == "snake_alpha_only" else sum( - 1 for _ in vocoder.parameters() - ) print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} " f"segment={segment_seconds}s steps={steps} lr={lr} " f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True) diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index cb0fdfd..00fb174 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -17,7 +17,6 @@ step's activations on demand. """ import dataclasses -import random import threading from pathlib import Path @@ -29,8 +28,6 @@ import comfy.model_management import folder_paths from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache -from .selva_sampler import SelvaSampler -from .selva_textual_inversion_trainer import _inject_tokens def _load_wav(path): @@ -329,26 +326,22 @@ def _do_optimize(net_generator, feature_utils, mel_converter, comfy.model_management.throw_exception_if_processing_interrupted() # ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ── - # This is cheaper than checkpointing all steps, at the cost of an - # approximate (truncated) gradient. The gradient still flows through - # n_grad_steps steps, which is sufficient for meaningful x_0 updates. + # Detach from x0 so Phase 1 does not build a computation graph. 
with torch.no_grad(): - x = x0 + x = x0.detach() for i in range(n_free_steps): t = ts[i] dt = ts[i + 1] - t flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) x = x + dt * flow - # Detach and re-leaf so backward only goes n_grad_steps deep. - # We treat x_k as a new leaf but seed it from x_0's value — so at - # opt step 0 the gradient is a true n_grad_steps truncated BPTT, - # and x_0 gets updated via x_k's dependence on x_0 through the - # no-grad prefix (approximation: gradient doesn't flow through prefix). - # - # Richer alternative: full checkpointing through all steps (uncomment - # the checkpoint block below and remove the no-grad prefix). - x = x.detach().requires_grad_(True) + # Straight-through estimator: reconnect x to x0's gradient path by + # adding the zero tensor (x0 - x0.detach()). This adds zero value but + # creates a grad_fn pointing back to x0, so loss.backward() will + # propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. + # The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is + # treated as identity for gradient purposes (truncated BPTT). + x = x + (x0 - x0.detach()) # ── Phase 2: run last n_grad_steps with gradient + checkpointing ── for i in range(n_free_steps, n_ode_steps): @@ -373,20 +366,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter, loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram) optimizer.zero_grad() - loss.backward() - - # Propagate gradient from x (grad_fn leaf) back to x_0. - # x was detached from x_0, so we manually transfer the gradient: - # the no-grad prefix is an approximation — skip this if doing full - # checkpointing (x would have grad_fn pointing back to x_0). - # Here x.grad is the gradient w.r.t. x at step n_free_steps; - # we directly add it to x_0.grad as an approximation. 
- if x.grad is not None: - if x0.grad is None: - x0.grad = x.grad.clone() - else: - x0.grad.add_(x.grad) - + loss.backward() # gradient flows through Phase 2 + STE back to x0.grad torch.nn.utils.clip_grad_norm_([x0], 1.0) optimizer.step()