feat(steering): conditional-only injection + per-position vectors

Two improvements for stronger steering effect:

1. Apply steering only during the conditional predict_flow pass by
   monkey-patching predict_flow to set a flag via identity check
   (cond is conditions). Hooks skip the unconditional pass, so
   steering is amplified by cfg_strength (~4.5x) instead of canceling
   out in the CFG guidance term.

2. Restore per-position [seq, hidden] steering vectors instead of
   seq-averaged [hidden]. More spatially specific — captures positional
   activation patterns rather than a global mean. Seq length mismatch
   at inference time handled via linear interpolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:02:51 +02:00
parent 95923cdf42
commit 115a0c3718
2 changed files with 41 additions and 11 deletions
+3 -5
View File
@@ -27,7 +27,7 @@ from .selva_lora_trainer import _prepare_dataset
def _collect_activations(generator, conditions, latent, t_tensor):
"""Run one predict_flow call, collecting latent hidden states per block.
Returns a list of [hidden_dim] float32 CPU tensors,
Returns a list of [seq, hidden_dim] float32 CPU tensors,
one per block (joint_blocks first, then fused_blocks).
"""
activations = []
@@ -35,9 +35,7 @@ def _collect_activations(generator, conditions, latent, t_tensor):
def make_hook(is_joint):
def hook(module, input, output):
h = output[0] if is_joint else output
# Mean over batch then seq → [hidden]: makes vectors length-agnostic so
# they broadcast to any inference duration without shape mismatches.
activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden]
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
return hook
handles = []
@@ -186,7 +184,7 @@ class SelvaActivationSteeringExtractor:
n_joint = len(generator.joint_blocks)
payload = {
"steering_vectors": steering_vectors, # list of [hidden] tensors
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
"n_blocks": n_blocks,
"n_joint": n_joint,
"n_fused": len(generator.fused_blocks),