fix: DITTO gradient never reached x0, remove unused imports and dead code

DITTO critical bug: x was reassigned on every ODE step, so by the time
loss.backward() ran, x pointed to the final output tensor (grad_fn, not
a leaf) and x.grad was always None. The manual gradient transfer never
fired — x0 was never updated. The optimization was a no-op.

Fix: use a straight-through estimator after the no-grad prefix:
  x = x + (x0 - x0.detach())
This adds zero value but creates a grad_fn back to x0, so backward()
propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
Equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I.

Also remove unused imports (SelvaSampler, _inject_tokens, random) that
caused cascade ImportError risk, and remove dead trainable_count variable
in BigVGAN trainer.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 12:10:02 +02:00
parent 1e9551152e
commit 211494a91c
2 changed files with 10 additions and 35 deletions
-5
View File
@@ -404,11 +404,6 @@ class SelvaBigvganTrainer:
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
)
trainable_count = sum(
1 for n, _ in vocoder.named_parameters() if "alpha" in n
) if train_mode == "snake_alpha_only" else sum(
1 for _ in vocoder.parameters()
)
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
f"segment={segment_seconds}s steps={steps} lr={lr} "
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
+10 -30
View File
@@ -17,7 +17,6 @@ step's activations on demand.
"""
import dataclasses
import random
import threading
from pathlib import Path
@@ -29,8 +28,6 @@ import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_sampler import SelvaSampler
from .selva_textual_inversion_trainer import _inject_tokens
def _load_wav(path):
@@ -329,26 +326,22 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
comfy.model_management.throw_exception_if_processing_interrupted()
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
# This is cheaper than checkpointing all steps, at the cost of an
# approximate (truncated) gradient. The gradient still flows through
# n_grad_steps steps, which is sufficient for meaningful x_0 updates.
# Detach from x0 so Phase 1 does not build a computation graph.
with torch.no_grad():
x = x0
x = x0.detach()
for i in range(n_free_steps):
t = ts[i]
dt = ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
# Detach and re-leaf so backward only goes n_grad_steps deep.
# We treat x_k as a new leaf but seed it from x_0's value — so at
# opt step 0 the gradient is a true n_grad_steps truncated BPTT,
# and x_0 gets updated via x_k's dependence on x_0 through the
# no-grad prefix (approximation: gradient doesn't flow through prefix).
#
# Richer alternative: full checkpointing through all steps (uncomment
# the checkpoint block below and remove the no-grad prefix).
x = x.detach().requires_grad_(True)
# Straight-through estimator: reconnect x to x0's gradient path by
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
# creates a grad_fn pointing back to x0, so loss.backward() will
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
# treated as identity for gradient purposes (truncated BPTT).
x = x + (x0 - x0.detach())
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
for i in range(n_free_steps, n_ode_steps):
@@ -373,20 +366,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
optimizer.zero_grad()
loss.backward()
# Propagate gradient from x (grad_fn leaf) back to x_0.
# x was detached from x_0, so we manually transfer the gradient:
# the no-grad prefix is an approximation — skip this if doing full
# checkpointing (x would have grad_fn pointing back to x_0).
# Here x.grad is the gradient w.r.t. x at step n_free_steps;
# we directly add it to x_0.grad as an approximation.
if x.grad is not None:
if x0.grad is None:
x0.grad = x.grad.clone()
else:
x0.grad.add_(x.grad)
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()