fix(ti-trainer): fix gradient flow and spectral metric shapes
- Replace in-place text_clip assignment with torch.cat so the computation graph correctly links text_input → learned_tokens; in-place assignment into a requires_grad=False leaf severs the graph and learned_tokens receives no gradients - _spectral_metrics(wav, sr): was passing wav.unsqueeze(0) [1,1,L] instead of wav [1,L]; stft mean(dim=1) would return wrong shape [1,T] not [n_freqs] - _save_spectrogram(wav, sr, ...): was passing wav.squeeze(0) [L] (1D) instead of wav [1,L] as the function expects Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -282,15 +282,19 @@ class SelvaTextualInversionTrainer:
|
||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
|
||||
|
||||
# Inject learned tokens into last n_tokens positions
|
||||
text_clip[:, -n_tokens:, :] = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
# Inject learned tokens into last n_tokens positions.
|
||||
# Must use torch.cat (not in-place assignment) so the computation graph
|
||||
# links text_input → learned_tokens and gradients flow correctly.
|
||||
text_front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D], no grad
|
||||
tokens_expanded = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) # [B, K, D]
|
||||
text_input = torch.cat([text_front, tokens_expanded], dim=1) # [B, 77, D] with grad
|
||||
|
||||
x1 = generator.normalize(x1)
|
||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||
x0 = torch.randn_like(x1)
|
||||
xt = fm.get_conditional_flow(x0, x1, t)
|
||||
|
||||
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||||
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
|
||||
loss = fm.loss(v_pred, x0, x1).mean()
|
||||
loss.backward()
|
||||
|
||||
@@ -340,9 +344,9 @@ class SelvaTextualInversionTrainer:
|
||||
import soundfile as sf
|
||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||
|
||||
metrics = _spectral_metrics(wav.unsqueeze(0), sr)
|
||||
metrics = _spectral_metrics(wav, sr)
|
||||
if metrics:
|
||||
img = _save_spectrogram(wav.squeeze(0), sr, ckpt_dir / f"step_{step:05d}.png")
|
||||
img = _save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
|
||||
print(f"[TI Trainer] step {step} "
|
||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
||||
f"flatness={metrics['spectral_flatness']:.4f} "
|
||||
|
||||
Reference in New Issue
Block a user