diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py index 41c0bdf..59767f4 100644 --- a/nodes/selva_textual_inversion_trainer.py +++ b/nodes/selva_textual_inversion_trainer.py @@ -282,15 +282,19 @@ class SelvaTextualInversionTrainer: sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone() - # Inject learned tokens into last n_tokens positions - text_clip[:, -n_tokens:, :] = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) + # Inject learned tokens into last n_tokens positions. + # Must use torch.cat (not in-place assignment) so the computation graph + # links text_input → learned_tokens and gradients flow correctly. + text_front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D], no grad + tokens_expanded = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) # [B, K, D] + text_input = torch.cat([text_front, tokens_expanded], dim=1) # [B, 77, D] with grad x1 = generator.normalize(x1) t = torch.rand(batch_size, device=device, dtype=dtype) x0 = torch.randn_like(x1) xt = fm.get_conditional_flow(x0, x1, t) - v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t) + v_pred = generator.forward(xt, clip_f, sync_f, text_input, t) loss = fm.loss(v_pred, x0, x1).mean() loss.backward() @@ -340,9 +344,9 @@ class SelvaTextualInversionTrainer: import soundfile as sf sf.write(str(wav_path), wav.squeeze(0).numpy(), sr) - metrics = _spectral_metrics(wav.unsqueeze(0), sr) + metrics = _spectral_metrics(wav, sr) if metrics: - img = _save_spectrogram(wav.squeeze(0), sr, ckpt_dir / f"step_{step:05d}.png") + img = _save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png") print(f"[TI Trainer] step {step} " f"centroid={metrics['spectral_centroid_hz']:.0f}Hz " f"flatness={metrics['spectral_flatness']:.4f} "