feat: add inject_mode (suffix/prefix) to TI pipeline

Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps),
token_norm grew linearly without plateau — generator likely ignores last-K
CLIP positions (EOS/padding zone) where suffix injects.

Fix: add inject_mode parameter throughout the pipeline:
- "suffix": replace last K positions (original behavior, model may ignore)
- "prefix": replace positions 1:1+K right after BOS — highest attention
  weight in CLIP, much stronger gradient signal expected

Changes:
- selva_textual_inversion_trainer.py: _inject_tokens() helper centralises
  the torch.cat construction for both modes; used in training loop and eval;
  inject_mode stored in checkpoint files
- selva_textual_inversion_loader.py: reads inject_mode from checkpoint,
  includes in TEXTUAL_INVERSION bundle
- selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field
- selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and
  _train_inner call
- ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm);
  n4_baseline marked completed; suffix experiments retained for comparison

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:31:52 +02:00
parent f96265da23
commit e1a2f0ed7d
5 changed files with 105 additions and 59 deletions
+54 -25
View File
@@ -42,7 +42,29 @@ from .selva_lora_trainer import (
# Eval helper with token injection
# ---------------------------------------------------------------------------
def _eval_sample_ti(generator, learned_tokens, n_tokens,
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
n_tokens: int, inject_mode: str) -> torch.Tensor:
"""Build a text_clip tensor with learned tokens injected.
inject_mode:
"suffix" — replace last n_tokens positions (EOS/padding zone)
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
Works for both training (tokens is a Parameter) and eval (tokens is detached).
"""
if inject_mode == "prefix":
bos = text_clip[:, :1, :].detach() # [B, 1, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D]
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
else: # suffix (default)
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
return torch.cat([front, toks], dim=1) # [B, 77, D]
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, num_steps=25, seed=42, clip_idx=0):
"""Inference pass with learned tokens injected into text conditioning."""
@@ -53,7 +75,8 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens,
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype).clone()
text_clip[:, -n_tokens:, :] = learned_tokens.detach().unsqueeze(0).to(device, dtype)
emb = learned_tokens.detach().to(device, dtype)
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
@@ -62,7 +85,7 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens,
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_clip,
return generator.forward(x, clip_f, sync_f, text_input,
t.reshape(1).to(device, dtype))
with torch.no_grad():
@@ -165,6 +188,14 @@ class SelvaTextualInversionTrainer:
}),
},
"optional": {
"inject_mode": (["suffix", "prefix"], {
"default": "suffix",
"tooltip": (
"Where to inject the learned tokens in the 77-token CLIP sequence. "
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
),
}),
"init_text": ("STRING", {
"default": "",
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
@@ -178,7 +209,7 @@ class SelvaTextualInversionTrainer:
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
batch_size, seed, save_every,
init_text="", warmup_steps=100):
inject_mode="suffix", init_text="", warmup_steps=100):
device = get_device()
dtype = model["dtype"]
@@ -212,7 +243,7 @@ class SelvaTextualInversionTrainer:
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text,
warmup_steps, seed, save_every, init_text, inject_mode,
)
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
@@ -223,7 +254,7 @@ class SelvaTextualInversionTrainer:
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text,
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
):
torch.manual_seed(seed)
@@ -289,12 +320,8 @@ class SelvaTextualInversionTrainer:
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
# Inject learned tokens into last n_tokens positions.
# Must use torch.cat (not in-place assignment) so the computation graph
# links text_input → learned_tokens and gradients flow correctly.
text_front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D], no grad
tokens_expanded = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) # [B, K, D]
text_input = torch.cat([text_front, tokens_expanded], dim=1) # [B, 77, D] with grad
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
x1 = generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
@@ -326,12 +353,13 @@ class SelvaTextualInversionTrainer:
if step % save_every == 0 or step == steps:
# Save checkpoint
ckpt = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"step": step,
"init_text": init_text,
"lr": lr,
"steps": steps,
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": step,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
@@ -339,7 +367,7 @@ class SelvaTextualInversionTrainer:
# Eval sample
wav, sr = _eval_sample_ti(
generator, learned_tokens, n_tokens,
generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, seed=seed,
)
@@ -365,12 +393,13 @@ class SelvaTextualInversionTrainer:
# --- Final save ---
final = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"step": steps,
"init_text": init_text,
"lr": lr,
"steps": steps,
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": steps,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
torch.save(final, str(out_path))