feat: add inject_mode (suffix/prefix) to TI pipeline

Observation: n4_baseline loss barely moved (1.025→0.965 over 3000 steps),
token_norm grew linearly without plateau — generator likely ignores last-K
CLIP positions (EOS/padding zone) where suffix injects.

Fix: add inject_mode parameter throughout the pipeline:
- "suffix": replace last K positions (original behavior, model may ignore)
- "prefix": replace positions 1:1+K right after BOS — highest attention
  weight in CLIP, much stronger gradient signal expected

Changes:
- selva_textual_inversion_trainer.py: _inject_tokens() helper centralises
  the torch.cat construction for both modes; used in training loop and eval;
  inject_mode stored in checkpoint files
- selva_textual_inversion_loader.py: reads inject_mode from checkpoint,
  includes in TEXTUAL_INVERSION bundle
- selva_sampler.py: uses _inject_tokens() via bundle's inject_mode field
- selva_ti_scheduler.py: inject_mode in _PARAM_DEFAULTS, config, and
  _train_inner call
- ti_sweep_1.json: updated with prefix_inject group (n4, n8, n4+warm);
  n4_baseline marked completed; suffix experiments retained for comparison

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-08 23:31:52 +02:00
parent f96265da23
commit e1a2f0ed7d
5 changed files with 105 additions and 59 deletions
+6 -2
View File
@@ -82,6 +82,7 @@ _PARAM_DEFAULTS = {
"seed": 42,
"save_every": 1000,
"init_text": "",
"inject_mode": "suffix",
}
_PALETTE = [
@@ -356,6 +357,7 @@ class SelvaTiScheduler:
seed = int(cfg["seed"])
save_every = int(cfg["save_every"])
init_text = str(cfg["init_text"])
inject_mode = str(cfg["inject_mode"])
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
@@ -365,7 +367,8 @@ class SelvaTiScheduler:
if exp_desc:
print(f"[TI Scheduler] {exp_desc}", flush=True)
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
f"batch_size={batch_size} warmup={warmup} seed={seed}", flush=True)
f"batch_size={batch_size} warmup={warmup} seed={seed} "
f"inject_mode={inject_mode}", flush=True)
if init_text:
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
@@ -381,6 +384,7 @@ class SelvaTiScheduler:
"seed": seed,
"save_every": save_every,
"init_text": init_text,
"inject_mode": inject_mode,
},
"results": {"status": "running"},
"embeddings_path": None,
@@ -397,7 +401,7 @@ class SelvaTiScheduler:
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup, seed, save_every, init_text,
warmup, seed, save_every, init_text, inject_mode,
)
duration = time.monotonic() - t_start