feat(steering): conditional-only injection + per-position vectors
Two improvements for a stronger steering effect:
1. Apply steering only during the conditional predict_flow pass by monkey-patching predict_flow to set a flag via an identity check (`cond is conditions`). Hooks skip the unconditional pass, so steering is amplified by cfg_strength (~4.5x) instead of canceling out in the CFG guidance term.
2. Restore per-position [seq, hidden] steering vectors instead of seq-averaged [hidden] vectors. These are more spatially specific: they capture positional activation patterns rather than a global mean. A sequence-length mismatch at inference time is handled via linear interpolation.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from .selva_lora_trainer import _prepare_dataset
|
||||
def _collect_activations(generator, conditions, latent, t_tensor):
|
||||
"""Run one predict_flow call, collecting latent hidden states per block.
|
||||
|
||||
Returns a list of [hidden_dim] float32 CPU tensors,
|
||||
Returns a list of [seq, hidden_dim] float32 CPU tensors,
|
||||
one per block (joint_blocks first, then fused_blocks).
|
||||
"""
|
||||
activations = []
|
||||
@@ -35,9 +35,7 @@ def _collect_activations(generator, conditions, latent, t_tensor):
|
||||
def make_hook(is_joint):
|
||||
def hook(module, input, output):
|
||||
h = output[0] if is_joint else output
|
||||
# Mean over batch then seq → [hidden]: makes vectors length-agnostic so
|
||||
# they broadcast to any inference duration without shape mismatches.
|
||||
activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden]
|
||||
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
|
||||
return hook
|
||||
|
||||
handles = []
|
||||
@@ -186,7 +184,7 @@ class SelvaActivationSteeringExtractor:
|
||||
|
||||
n_joint = len(generator.joint_blocks)
|
||||
payload = {
|
||||
"steering_vectors": steering_vectors, # list of [hidden] tensors
|
||||
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
|
||||
"n_blocks": n_blocks,
|
||||
"n_joint": n_joint,
|
||||
"n_fused": len(generator.fused_blocks),
|
||||
|
||||
Reference in New Issue
Block a user