From e1a2f0ed7d4ced8d1bfc5f79b70abd41adc0d8e7 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Wed, 8 Apr 2026 23:31:52 +0200 Subject: [PATCH] feat: add inject_mode (suffix/prefix) to TI pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps), token_norm grew linearly without plateau — generator likely ignores last-K CLIP positions (EOS/padding zone) where suffix injects. Fix: add inject_mode parameter throughout the pipeline: - "suffix": replace last K positions (original behavior, model may ignore) - "prefix": replace positions 1:1+K right after BOS — highest attention weight in CLIP, much stronger gradient signal expected Changes: - selva_textual_inversion_trainer.py: _inject_tokens() helper centralises the torch.cat construction for both modes; used in training loop and eval; inject_mode stored in checkpoint files - selva_textual_inversion_loader.py: reads inject_mode from checkpoint, includes in TEXTUAL_INVERSION bundle - selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field - selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and _train_inner call - ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm); n4_baseline marked completed; suffix experiments retained for comparison Co-Authored-By: Claude Sonnet 4.6 --- experiments/ti_sweep_1.json | 49 +++++++++------ nodes/selva_sampler.py | 16 ++--- nodes/selva_textual_inversion_loader.py | 12 ++-- nodes/selva_textual_inversion_trainer.py | 79 ++++++++++++++++-------- nodes/selva_ti_scheduler.py | 8 ++- 5 files changed, 105 insertions(+), 59 deletions(-) diff --git a/experiments/ti_sweep_1.json b/experiments/ti_sweep_1.json index 034b2a8..40f4ab5 100644 --- a/experiments/ti_sweep_1.json +++ b/experiments/ti_sweep_1.json @@ -1,6 +1,6 @@ { "name": "ti_sweep_1", - "description": "First TI sweep: token count, learning rate, and warm init. All generator weights frozen throughout. Baseline = n_tokens=4, lr=1e-3, random init. Primary goal: find a working (n_tokens, lr) pair before optimising further.", + "description": "First TI sweep: inject position, token count, learning rate, and warm init. n4_baseline already completed (suffix, loss barely moved — model likely ignores last-K positions). Priority: prefix injection group.", "data_dir": "/media/unraid/davinci/Selva/BJ/features", "output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1", "base": { @@ -11,52 +11,61 @@ "seed": 42, "init_text": "", "lr": 1e-3, - "n_tokens": 4 + "n_tokens": 4, + "inject_mode": "suffix" }, "experiments": [ { "id": "n4_baseline", - "group": "token_count", - "description": "4 tokens, lr=1e-3, random init. Primary reference point — all other experiments are measured against this." + "group": "suffix_token_count", + "description": "4 tokens, suffix, lr=1e-3, random init. COMPLETED — loss 1.025→0.965, nearly flat. Token norm grew linearly to 3.2 with no plateau. Model appears to ignore last-K positions." }, { "id": "n8", - "group": "token_count", - "description": "8 tokens, lr=1e-3, random init. Double the capacity — does it capture more style or just overfit faster?", + "group": "suffix_token_count", + "description": "8 tokens, suffix, lr=1e-3. More capacity — does it do better than n4_baseline?", "n_tokens": 8 }, + { - "id": "n16", - "group": "token_count", - "description": "16 tokens, lr=1e-3, random init. Maximum expressiveness — worth the extra convergence difficulty?", - "n_tokens": 16 + "id": "n4_prefix", + "group": "prefix_inject", + "description": "4 tokens at positions 1:5 (after BOS). Prefix positions carry the highest attention weight in CLIP — should produce much stronger loss signal than suffix.", + "inject_mode": "prefix" + }, + { + "id": "n8_prefix", + "group": "prefix_inject", + "description": "8 tokens at prefix positions. More capacity + high-attention positions.", + "n_tokens": 8, + "inject_mode": "prefix" + }, + { + "id": "n4_prefix_warm", + "group": "prefix_inject", + "description": "4 tokens, prefix, warm-started from 'mechanical impact sound design'. Best of both: semantically meaningful start + strong gradient signal.", + "inject_mode": "prefix", + "init_text": "mechanical impact sound design" }, { "id": "lr_5e4", "group": "learning_rate", - "description": "n_tokens=4, lr=5e-4. Half the default LR — smoother convergence, possibly better generalisation.", + "description": "4 tokens, suffix, lr=5e-4. Slower convergence — mainly a baseline comparison for the prefix group.", "lr": 5e-4 }, { "id": "lr_2e3", "group": "learning_rate", - "description": "n_tokens=4, lr=2e-3. Double the default LR — faster early convergence, risk of oscillation.", + "description": "4 tokens, suffix, lr=2e-3. Faster early movement — does token norm plateau earlier?", "lr": 2e-3 }, { "id": "n4_warm", "group": "warm_init", - "description": "4 tokens warm-started from 'mechanical impact sound design'. CLIP embedding initialises tokens in a semantically relevant region of the space — may converge faster and to a better style representation.", - "init_text": "mechanical impact sound design" - }, - { - "id": "n8_warm", - "group": "warm_init", - "description": "8 tokens warm-started from 'mechanical impact sound design'. Combines the warm-init advantage with more expressive capacity.", - "n_tokens": 8, + "description": "4 tokens, suffix, warm-started from 'mechanical impact sound design'.", "init_text": "mechanical impact sound design" } diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index 1bae020..134af12 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -3,6 +3,7 @@ import comfy.utils import comfy.model_management from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache +from .selva_textual_inversion_trainer import _inject_tokens class SelvaSampler: @@ -118,16 +119,15 @@ class SelvaSampler: neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ if negative_prompt.strip() else None - # Inject textual inversion tokens into last K positions of CLIP embedding + # Inject textual inversion tokens into CLIP conditioning if textual_inversion is not None: - emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024] - K = emb.shape[0] - text_clip = text_clip.clone() - text_clip[:, -K:, :] = emb.unsqueeze(0) + emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024] + K = emb.shape[0] + inject_mode = textual_inversion.get("inject_mode", "suffix") + text_clip = _inject_tokens(text_clip, emb, K, inject_mode) if neg_text_clip is not None: - neg_text_clip = neg_text_clip.clone() - neg_text_clip[:, -K:, :] = emb.unsqueeze(0) - print(f"[SelVA] Textual inversion: injected {K} tokens into CLIP conditioning", + neg_text_clip = _inject_tokens(neg_text_clip, emb, K, inject_mode) + print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode}", flush=True) conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) diff --git a/nodes/selva_textual_inversion_loader.py b/nodes/selva_textual_inversion_loader.py index 2626000..dae900e 100644 --- a/nodes/selva_textual_inversion_loader.py +++ b/nodes/selva_textual_inversion_loader.py @@ -57,10 +57,14 @@ class SelvaTextualInversionLoader: print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps " f"lr={data.get('lr', '?')}", flush=True) + inject_mode = data.get("inject_mode", "suffix") + print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True) + bundle = { - "embeddings": embeddings, # [K, 1024] float32 on CPU - "n_tokens": n_tokens, - "path": str(p), - "init_text": data.get("init_text", ""), + "embeddings": embeddings, # [K, 1024] float32 on CPU + "n_tokens": n_tokens, + "inject_mode": inject_mode, + "path": str(p), + "init_text": data.get("init_text", ""), } return (bundle,) diff --git a/nodes/selva_textual_inversion_trainer.py b/nodes/selva_textual_inversion_trainer.py index ab79ff6..8c2e7d7 100644 --- a/nodes/selva_textual_inversion_trainer.py +++ b/nodes/selva_textual_inversion_trainer.py @@ -42,7 +42,29 @@ from .selva_lora_trainer import ( # Eval helper with token injection # --------------------------------------------------------------------------- -def _eval_sample_ti(generator, learned_tokens, n_tokens, +def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor, + n_tokens: int, inject_mode: str) -> torch.Tensor: + """Build a text_clip tensor with learned tokens injected. + + inject_mode: + "suffix" — replace last n_tokens positions (EOS/padding zone) + "prefix" — replace positions 1:1+n_tokens (after BOS, before content) + + Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad. + Works for both training (tokens is a Parameter) and eval (tokens is detached). + """ + if inject_mode == "prefix": + bos = text_clip[:, :1, :].detach() # [B, 1, D] + toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D] + rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D] + return torch.cat([bos, toks, rest], dim=1) # [B, 77, D] + else: # suffix (default) + front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D] + toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D] + return torch.cat([front, toks], dim=1) # [B, 77, D] + + +def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode, feature_utils_orig, dataset, seq_cfg, device, dtype, num_steps=25, seed=42, clip_idx=0): """Inference pass with learned tokens injected into text conditioning.""" @@ -53,7 +75,8 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens, sync_f = sync_f_cpu.to(device, dtype) text_clip = text_clip_cpu.to(device, dtype).clone() - text_clip[:, -n_tokens:, :] = learned_tokens.detach().unsqueeze(0).to(device, dtype) + emb = learned_tokens.detach().to(device, dtype) + text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode) rng = torch.Generator(device=device).manual_seed(seed) x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim, @@ -62,7 +85,7 @@ def _eval_sample_ti(generator, learned_tokens, n_tokens, eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps) def velocity_fn(t, x): - return generator.forward(x, clip_f, sync_f, text_clip, + return generator.forward(x, clip_f, sync_f, text_input, t.reshape(1).to(device, dtype)) with torch.no_grad(): @@ -165,6 +188,14 @@ class SelvaTextualInversionTrainer: }), }, "optional": { + "inject_mode": (["suffix", "prefix"], { + "default": "suffix", + "tooltip": ( + "Where to inject the learned tokens in the 77-token CLIP sequence. " + "'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). " + "'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal." + ), + }), "init_text": ("STRING", { "default": "", "tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.", @@ -178,7 +209,7 @@ class SelvaTextualInversionTrainer: def train(self, model, data_dir, output_path, n_tokens, steps, lr, batch_size, seed, save_every, - init_text="", warmup_steps=100): + inject_mode="suffix", init_text="", warmup_steps=100): device = get_device() dtype = model["dtype"] @@ -212,7 +243,7 @@ class SelvaTextualInversionTrainer: device, dtype, mode, data_dir, out_path, n_tokens, steps, lr, batch_size, - warmup_steps, seed, save_every, init_text, + warmup_steps, seed, save_every, init_text, inject_mode, ) smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else [] curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed) @@ -223,7 +254,7 @@ class SelvaTextualInversionTrainer: device, dtype, mode, data_dir, out_path, n_tokens, steps, lr, batch_size, - warmup_steps, seed, save_every, init_text, + warmup_steps, seed, save_every, init_text, inject_mode="suffix", ): torch.manual_seed(seed) @@ -289,12 +320,8 @@ class SelvaTextualInversionTrainer: sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone() - # Inject learned tokens into last n_tokens positions. - # Must use torch.cat (not in-place assignment) so the computation graph - # links text_input → learned_tokens and gradients flow correctly. - text_front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D], no grad - tokens_expanded = learned_tokens.unsqueeze(0).expand(batch_size, -1, -1) # [B, K, D] - text_input = torch.cat([text_front, tokens_expanded], dim=1) # [B, 77, D] with grad + # Inject learned tokens — gradient flows via torch.cat (not in-place assignment). + text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode) x1 = generator.normalize(x1) t = torch.rand(batch_size, device=device, dtype=dtype) @@ -326,12 +353,13 @@ class SelvaTextualInversionTrainer: if step % save_every == 0 or step == steps: # Save checkpoint ckpt = { - "embeddings": learned_tokens.detach().cpu(), - "n_tokens": n_tokens, - "step": step, - "init_text": init_text, - "lr": lr, - "steps": steps, + "embeddings": learned_tokens.detach().cpu(), + "n_tokens": n_tokens, + "inject_mode": inject_mode, + "step": step, + "init_text": init_text, + "lr": lr, + "steps": steps, "loss_history": loss_history, } ckpt_path = ckpt_dir / f"step_{step:05d}.pt" @@ -339,7 +367,7 @@ class SelvaTextualInversionTrainer: # Eval sample wav, sr = _eval_sample_ti( - generator, learned_tokens, n_tokens, + generator, learned_tokens, n_tokens, inject_mode, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed, ) @@ -365,12 +393,13 @@ class SelvaTextualInversionTrainer: # --- Final save --- final = { - "embeddings": learned_tokens.detach().cpu(), - "n_tokens": n_tokens, - "step": steps, - "init_text": init_text, - "lr": lr, - "steps": steps, + "embeddings": learned_tokens.detach().cpu(), + "n_tokens": n_tokens, + "inject_mode": inject_mode, + "step": steps, + "init_text": init_text, + "lr": lr, + "steps": steps, "loss_history": loss_history, } torch.save(final, str(out_path)) diff --git a/nodes/selva_ti_scheduler.py b/nodes/selva_ti_scheduler.py index 75a1b4d..65cdcb7 100644 --- a/nodes/selva_ti_scheduler.py +++ b/nodes/selva_ti_scheduler.py @@ -82,6 +82,7 @@ _PARAM_DEFAULTS = { "seed": 42, "save_every": 1000, "init_text": "", + "inject_mode": "suffix", } _PALETTE = [ @@ -356,6 +357,7 @@ class SelvaTiScheduler: seed = int(cfg["seed"]) save_every = int(cfg["save_every"]) init_text = str(cfg["init_text"]) + inject_mode = str(cfg["inject_mode"]) output_dir = output_root / exp_id output_dir.mkdir(parents=True, exist_ok=True) @@ -365,7 +367,8 @@ class SelvaTiScheduler: if exp_desc: print(f"[TI Scheduler] {exp_desc}", flush=True) print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} " - f"batch_size={batch_size} warmup={warmup} seed={seed}", flush=True) + f"batch_size={batch_size} warmup={warmup} seed={seed} " + f"inject_mode={inject_mode}", flush=True) if init_text: print(f"[TI Scheduler] init_text='{init_text}'", flush=True) @@ -381,6 +384,7 @@ class SelvaTiScheduler: "seed": seed, "save_every": save_every, "init_text": init_text, + "inject_mode": inject_mode, }, "results": {"status": "running"}, "embeddings_path": None, @@ -397,7 +401,7 @@ class SelvaTiScheduler: device, dtype, mode, data_dir, out_path, n_tokens, steps, lr, batch_size, - warmup, seed, save_every, init_text, + warmup, seed, save_every, init_text, inject_mode, ) duration = time.monotonic() - t_start