fix(ti-trainer): clamp token norm to CLIP manifold to prevent buzz artifacts
Diagnosis: learned tokens grew to norm ~3.2 while real CLIP content tokens sit at ~1.0. Model never trained on embeddings that large — activates buzz artifact instead of semantic style shift. Fix: measure mean token norm from content positions (1–20) of dataset CLIP embeddings at startup, clamp learned_tokens per-token after every optimizer step to max 1.5× that reference (50% headroom). Token norm is now logged as current/limit for easy monitoring. ti_sweep_1.json: rebuild around norm_clamp group — n4_clamped (primary diagnostic), prefix_clamped, n8_prefix_clamped, warm_clamped. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -288,6 +288,19 @@ class SelvaTextualInversionTrainer:
|
||||
)
|
||||
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
|
||||
|
||||
# --- Measure CLIP token norm from the dataset (content positions 1–20) ---
|
||||
# Learned tokens must stay within this range or the model treats them as
|
||||
# out-of-distribution and produces buzz artifacts instead of style shift.
|
||||
with torch.no_grad():
|
||||
sample_norms = []
|
||||
for item in dataset[:min(len(dataset), 20)]:
|
||||
tc = item[3].squeeze(0) # [77, 1024]
|
||||
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
|
||||
clip_norm_ref = torch.cat(sample_norms).mean().item()
|
||||
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
|
||||
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
|
||||
f"limit={clip_norm_limit:.4f}", flush=True)
|
||||
|
||||
# --- Optimizer + scheduler ---
|
||||
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
|
||||
|
||||
@@ -356,6 +369,13 @@ class SelvaTextualInversionTrainer:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Clamp token norm to CLIP manifold — prevents out-of-distribution
|
||||
# embeddings that cause buzz artifacts instead of style shift.
|
||||
with torch.no_grad():
|
||||
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
scale = (clip_norm_limit / norms).clamp(max=1.0)
|
||||
learned_tokens.data.mul_(scale)
|
||||
|
||||
running_loss += loss.item()
|
||||
pbar.update(1)
|
||||
|
||||
@@ -364,10 +384,10 @@ class SelvaTextualInversionTrainer:
|
||||
loss_history.append(round(avg, 6))
|
||||
running_loss = 0.0
|
||||
lr_now = scheduler.get_last_lr()[0]
|
||||
norm = learned_tokens.norm().item()
|
||||
norm = learned_tokens.norm(dim=-1).mean().item()
|
||||
print(f"[TI Trainer] step {step:5d}/{steps} "
|
||||
f"loss={avg:.4f} lr={lr_now:.2e} "
|
||||
f"token_norm={norm:.4f}", flush=True)
|
||||
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
|
||||
|
||||
if step % save_every == 0 or step == steps:
|
||||
# Save checkpoint
|
||||
|
||||
Reference in New Issue
Block a user