fix(ti-trainer): clamp token norm to CLIP manifold to prevent buzz artifacts
Diagnosis: learned tokens grew to norm ~3.2 while real CLIP content tokens sit at ~1.0. Model never trained on embeddings that large — activates buzz artifact instead of semantic style shift. Fix: measure mean token norm from content positions (1–20) of dataset CLIP embeddings at startup, clamp learned_tokens per-token after every optimizer step to max 1.5× that reference (50% headroom). Token norm is now logged as current/limit for easy monitoring. ti_sweep_1.json: rebuild around norm_clamp group — n4_clamped (primary diagnostic), prefix_clamped, n8_prefix_clamped, warm_clamped. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+18
-50
@@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"name": "ti_sweep_1",
|
"name": "ti_sweep_1",
|
||||||
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — loss 1.025→0.963, plateau after step 1500, token_norm grew linearly without saturation (overshoot sign). Now testing: prefix injection, lower LR, smaller batch.",
|
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
|
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
|
||||||
"base": {
|
"base": {
|
||||||
"steps": 3000,
|
"steps": 3000,
|
||||||
"batch_size": 16,
|
"batch_size": 4,
|
||||||
"warmup_steps": 100,
|
"warmup_steps": 100,
|
||||||
"save_every": 1000,
|
"save_every": 1000,
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"init_text": "",
|
"init_text": "",
|
||||||
"lr": 1e-3,
|
"lr": 2e-4,
|
||||||
"n_tokens": 4,
|
"n_tokens": 4,
|
||||||
"inject_mode": "suffix"
|
"inject_mode": "suffix"
|
||||||
},
|
},
|
||||||
@@ -19,65 +19,33 @@
|
|||||||
{
|
{
|
||||||
"id": "n4_baseline",
|
"id": "n4_baseline",
|
||||||
"group": "reference",
|
"group": "reference",
|
||||||
"description": "COMPLETED. batch=16, lr=1e-3, suffix. Reference. Loss plateau ~0.963, token_norm linear growth to 3.2 — LR too high for the parameter count."
|
"description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "n4_prefix",
|
"id": "n4_clamped",
|
||||||
"group": "prefix_inject",
|
"group": "norm_clamp",
|
||||||
"description": "Same as baseline but prefix injection. Tests whether suffix positions are limiting signal — if prefix loss goes lower or converges faster, suffix was the bottleneck.",
|
"description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "n4_prefix_clamped",
|
||||||
|
"group": "norm_clamp",
|
||||||
|
"description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
|
||||||
"inject_mode": "prefix"
|
"inject_mode": "prefix"
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
"id": "lr_low_b4",
|
"id": "n8_prefix_clamped",
|
||||||
"group": "lr_batch",
|
"group": "norm_clamp",
|
||||||
"description": "lr=2e-4, batch=4. Matches LoRA's working regime. Smaller batch = noisier but more diverse gradients; lower LR = smaller steps, token_norm should plateau rather than drift.",
|
"description": "8 tokens, prefix, clamped. More capacity without the artifact.",
|
||||||
"lr": 2e-4,
|
|
||||||
"batch_size": 4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_mid_b8",
|
|
||||||
"group": "lr_batch",
|
|
||||||
"description": "lr=5e-4, batch=8. Middle ground — half the baseline LR and batch. Token norm should grow slower and saturate.",
|
|
||||||
"lr": 5e-4,
|
|
||||||
"batch_size": 8
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_low_b4_prefix",
|
|
||||||
"group": "lr_batch",
|
|
||||||
"description": "lr=2e-4, batch=4, prefix. Best LR/batch regime + best injection position combined.",
|
|
||||||
"lr": 2e-4,
|
|
||||||
"batch_size": 4,
|
|
||||||
"inject_mode": "prefix"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "n8_prefix",
|
|
||||||
"group": "prefix_inject",
|
|
||||||
"description": "8 tokens, prefix, baseline LR/batch. More capacity at the better injection position.",
|
|
||||||
"n_tokens": 8,
|
"n_tokens": 8,
|
||||||
"inject_mode": "prefix"
|
"inject_mode": "prefix"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "n4_prefix_warm",
|
"id": "n4_prefix_warm_clamped",
|
||||||
"group": "prefix_inject",
|
"group": "norm_clamp",
|
||||||
"description": "4 tokens, prefix, warm-started from 'mechanical impact sound design'.",
|
"description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
|
||||||
"inject_mode": "prefix",
|
"inject_mode": "prefix",
|
||||||
"init_text": "mechanical impact sound design"
|
"init_text": "mechanical impact sound design"
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "n8",
|
|
||||||
"group": "suffix_token_count",
|
|
||||||
"description": "8 tokens, suffix, baseline LR/batch. Capacity ablation vs n4_baseline.",
|
|
||||||
"n_tokens": 8
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_2e3",
|
|
||||||
"group": "lr_batch",
|
|
||||||
"description": "lr=2e-3, baseline batch. Expected to plateau earlier and higher than baseline — confirms LR is the issue.",
|
|
||||||
"lr": 2e-3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -288,6 +288,19 @@ class SelvaTextualInversionTrainer:
|
|||||||
)
|
)
|
||||||
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
|
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
|
||||||
|
|
||||||
|
# --- Measure CLIP token norm from the dataset (content positions 1–20) ---
|
||||||
|
# Learned tokens must stay within this range or the model treats them as
|
||||||
|
# out-of-distribution and produces buzz artifacts instead of style shift.
|
||||||
|
with torch.no_grad():
|
||||||
|
sample_norms = []
|
||||||
|
for item in dataset[:min(len(dataset), 20)]:
|
||||||
|
tc = item[3].squeeze(0) # [77, 1024]
|
||||||
|
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
|
||||||
|
clip_norm_ref = torch.cat(sample_norms).mean().item()
|
||||||
|
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
|
||||||
|
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
|
||||||
|
f"limit={clip_norm_limit:.4f}", flush=True)
|
||||||
|
|
||||||
# --- Optimizer + scheduler ---
|
# --- Optimizer + scheduler ---
|
||||||
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
|
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
|
||||||
|
|
||||||
@@ -356,6 +369,13 @@ class SelvaTextualInversionTrainer:
|
|||||||
scheduler.step()
|
scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Clamp token norm to CLIP manifold — prevents out-of-distribution
|
||||||
|
# embeddings that cause buzz artifacts instead of style shift.
|
||||||
|
with torch.no_grad():
|
||||||
|
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||||
|
scale = (clip_norm_limit / norms).clamp(max=1.0)
|
||||||
|
learned_tokens.data.mul_(scale)
|
||||||
|
|
||||||
running_loss += loss.item()
|
running_loss += loss.item()
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
@@ -364,10 +384,10 @@ class SelvaTextualInversionTrainer:
|
|||||||
loss_history.append(round(avg, 6))
|
loss_history.append(round(avg, 6))
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
lr_now = scheduler.get_last_lr()[0]
|
lr_now = scheduler.get_last_lr()[0]
|
||||||
norm = learned_tokens.norm().item()
|
norm = learned_tokens.norm(dim=-1).mean().item()
|
||||||
print(f"[TI Trainer] step {step:5d}/{steps} "
|
print(f"[TI Trainer] step {step:5d}/{steps} "
|
||||||
f"loss={avg:.4f} lr={lr_now:.2e} "
|
f"loss={avg:.4f} lr={lr_now:.2e} "
|
||||||
f"token_norm={norm:.4f}", flush=True)
|
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
if step % save_every == 0 or step == steps:
|
||||||
# Save checkpoint
|
# Save checkpoint
|
||||||
|
|||||||
Reference in New Issue
Block a user