feat(steering): conditional-only injection + per-position vectors

Two improvements for stronger steering effect:

1. Apply steering only during the conditional predict_flow pass by
   monkey-patching predict_flow to set a flag via identity check
   (cond is conditions). Hooks skip the unconditional pass, so
   steering is amplified by cfg_strength (~4.5x) instead of canceling
   out in the CFG guidance term.

2. Restore per-position [seq, hidden] steering vectors instead of
   seq-averaged [hidden]. More spatially specific — captures positional
   activation patterns rather than a global mean. Seq length mismatch
   at inference time handled via linear interpolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:02:51 +02:00
parent 95923cdf42
commit 115a0c3718
2 changed files with 41 additions and 11 deletions
+38 -6
View File
@@ -159,21 +159,50 @@ class SelvaSampler:
device=gen_device, dtype=dtype, generator=rng,
).to(device)
# Activation steering hooks
# Activation steering: apply only during the conditional predict_flow pass
# so steering gets amplified by cfg_strength rather than canceling out.
steering_handles = []
_orig_predict_flow = None
if steering_vectors is not None and steering_strength > 0.0:
vecs = steering_vectors["steering_vectors"]
n_joint = steering_vectors["n_joint"]
# Patch predict_flow to flag which pass is conditional.
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
# identity check tells us which is which.
_is_cond_pass = [False]
_orig_predict_flow = net_generator.predict_flow
def _tracked_predict_flow(latent, t, cond):
_is_cond_pass[0] = (cond is conditions)
return _orig_predict_flow(latent, t, cond)
net_generator.predict_flow = _tracked_predict_flow
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H]
vec = vec_cpu.to(dev, dt) # [seq, hidden]
def hook(module, input, output):
if not _is_cond_pass[0]:
return # skip unconditional pass; steering amplified by cfg_strength
# Interpolate steering vec to match actual output seq length
# (handles generation at different duration than extraction)
if is_joint:
# JointBlock returns (latent, clip, text) tuple
latent_out = output[0] + strength * vec
out_seq = output[0].shape[1]
else:
out_seq = output.shape[1]
v = vec
if v.shape[0] != out_seq:
v = torch.nn.functional.interpolate(
v.T.unsqueeze(0), # [1, hidden, seq_orig]
size=out_seq,
mode="linear",
align_corners=False,
).squeeze(0).T # [seq_new, hidden]
if is_joint:
latent_out = output[0] + strength * v
return (latent_out,) + output[1:]
else:
return output + strength * vec
return output + strength * v
return hook
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
@@ -184,7 +213,8 @@ class SelvaSampler:
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
)
steering_handles.append(h)
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True)
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
f"strength={steering_strength} (conditional pass only)", flush=True)
# Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
@@ -203,6 +233,8 @@ class SelvaSampler:
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
finally:
if _orig_predict_flow is not None:
net_generator.predict_flow = _orig_predict_flow
for h in steering_handles:
h.remove()