fix: DITTO gradient never reached x0, remove unused imports and dead code
DITTO critical bug: x was reassigned on every ODE step, so by the time loss.backward() ran, x pointed to the final output tensor (a non-leaf tensor with a grad_fn) and x.grad was always None. The manual gradient transfer therefore never fired, x0 was never updated, and the optimization was a no-op. Fix: use a straight-through estimator after the no-grad prefix: `x = x + (x0 - x0.detach())`. This adds zero to the value but creates a grad_fn back to x0, so backward() propagates ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad. This is equivalent to truncated BPTT with ∂x_prefix/∂x0 ≈ I. Also remove unused imports (SelvaSampler, _inject_tokens, random) that risked cascading ImportErrors, and remove the dead trainable_count variable in the BigVGAN trainer. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -404,11 +404,6 @@ class SelvaBigvganTrainer:
|
|||||||
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
|
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
trainable_count = sum(
|
|
||||||
1 for n, _ in vocoder.named_parameters() if "alpha" in n
|
|
||||||
) if train_mode == "snake_alpha_only" else sum(
|
|
||||||
1 for _ in vocoder.parameters()
|
|
||||||
)
|
|
||||||
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
|
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
|
||||||
f"segment={segment_seconds}s steps={steps} lr={lr} "
|
f"segment={segment_seconds}s steps={steps} lr={lr} "
|
||||||
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
|
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ step's activations on demand.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import random
|
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -29,8 +28,6 @@ import comfy.model_management
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
from .selva_sampler import SelvaSampler
|
|
||||||
from .selva_textual_inversion_trainer import _inject_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def _load_wav(path):
|
def _load_wav(path):
|
||||||
@@ -329,26 +326,22 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
|
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
|
||||||
# This is cheaper than checkpointing all steps, at the cost of an
|
# Detach from x0 so Phase 1 does not build a computation graph.
|
||||||
# approximate (truncated) gradient. The gradient still flows through
|
|
||||||
# n_grad_steps steps, which is sufficient for meaningful x_0 updates.
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x = x0
|
x = x0.detach()
|
||||||
for i in range(n_free_steps):
|
for i in range(n_free_steps):
|
||||||
t = ts[i]
|
t = ts[i]
|
||||||
dt = ts[i + 1] - t
|
dt = ts[i + 1] - t
|
||||||
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||||
x = x + dt * flow
|
x = x + dt * flow
|
||||||
|
|
||||||
# Detach and re-leaf so backward only goes n_grad_steps deep.
|
# Straight-through estimator: reconnect x to x0's gradient path by
|
||||||
# We treat x_k as a new leaf but seed it from x_0's value — so at
|
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
|
||||||
# opt step 0 the gradient is a true n_grad_steps truncated BPTT,
|
# creates a grad_fn pointing back to x0, so loss.backward() will
|
||||||
# and x_0 gets updated via x_k's dependence on x_0 through the
|
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
|
||||||
# no-grad prefix (approximation: gradient doesn't flow through prefix).
|
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
|
||||||
#
|
# treated as identity for gradient purposes (truncated BPTT).
|
||||||
# Richer alternative: full checkpointing through all steps (uncomment
|
x = x + (x0 - x0.detach())
|
||||||
# the checkpoint block below and remove the no-grad prefix).
|
|
||||||
x = x.detach().requires_grad_(True)
|
|
||||||
|
|
||||||
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
|
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
|
||||||
for i in range(n_free_steps, n_ode_steps):
|
for i in range(n_free_steps, n_ode_steps):
|
||||||
@@ -373,20 +366,7 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
||||||
|
|
||||||
# Propagate gradient from x (grad_fn leaf) back to x_0.
|
|
||||||
# x was detached from x_0, so we manually transfer the gradient:
|
|
||||||
# the no-grad prefix is an approximation — skip this if doing full
|
|
||||||
# checkpointing (x would have grad_fn pointing back to x_0).
|
|
||||||
# Here x.grad is the gradient w.r.t. x at step n_free_steps;
|
|
||||||
# we directly add it to x_0.grad as an approximation.
|
|
||||||
if x.grad is not None:
|
|
||||||
if x0.grad is None:
|
|
||||||
x0.grad = x.grad.clone()
|
|
||||||
else:
|
|
||||||
x0.grad.add_(x.grad)
|
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user