feat(steering): conditional-only injection + per-position vectors
Two improvements for a stronger steering effect:

1. Apply steering only during the conditional predict_flow pass, by monkey-patching predict_flow to set a flag via an identity check (cond is conditions). The hooks skip the unconditional pass, so steering is amplified by cfg_strength (~4.5x) instead of canceling out in the CFG guidance term.

2. Restore per-position [seq, hidden] steering vectors instead of seq-averaged [hidden] vectors. This is more spatially specific: it captures positional activation patterns rather than a global mean. A sequence-length mismatch at inference time is handled via linear interpolation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from .selva_lora_trainer import _prepare_dataset
|
|||||||
def _collect_activations(generator, conditions, latent, t_tensor):
|
def _collect_activations(generator, conditions, latent, t_tensor):
|
||||||
"""Run one predict_flow call, collecting latent hidden states per block.
|
"""Run one predict_flow call, collecting latent hidden states per block.
|
||||||
|
|
||||||
Returns a list of [hidden_dim] float32 CPU tensors,
|
Returns a list of [seq, hidden_dim] float32 CPU tensors,
|
||||||
one per block (joint_blocks first, then fused_blocks).
|
one per block (joint_blocks first, then fused_blocks).
|
||||||
"""
|
"""
|
||||||
activations = []
|
activations = []
|
||||||
@@ -35,9 +35,7 @@ def _collect_activations(generator, conditions, latent, t_tensor):
|
|||||||
def make_hook(is_joint):
|
def make_hook(is_joint):
|
||||||
def hook(module, input, output):
|
def hook(module, input, output):
|
||||||
h = output[0] if is_joint else output
|
h = output[0] if is_joint else output
|
||||||
# Mean over batch then seq → [hidden]: makes vectors length-agnostic so
|
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
|
||||||
# they broadcast to any inference duration without shape mismatches.
|
|
||||||
activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden]
|
|
||||||
return hook
|
return hook
|
||||||
|
|
||||||
handles = []
|
handles = []
|
||||||
@@ -186,7 +184,7 @@ class SelvaActivationSteeringExtractor:
|
|||||||
|
|
||||||
n_joint = len(generator.joint_blocks)
|
n_joint = len(generator.joint_blocks)
|
||||||
payload = {
|
payload = {
|
||||||
"steering_vectors": steering_vectors, # list of [hidden] tensors
|
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
|
||||||
"n_blocks": n_blocks,
|
"n_blocks": n_blocks,
|
||||||
"n_joint": n_joint,
|
"n_joint": n_joint,
|
||||||
"n_fused": len(generator.fused_blocks),
|
"n_fused": len(generator.fused_blocks),
|
||||||
|
|||||||
+38
-6
@@ -159,21 +159,50 @@ class SelvaSampler:
|
|||||||
device=gen_device, dtype=dtype, generator=rng,
|
device=gen_device, dtype=dtype, generator=rng,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# Activation steering hooks
|
# Activation steering: apply only during the conditional predict_flow pass
|
||||||
|
# so steering gets amplified by cfg_strength rather than canceling out.
|
||||||
steering_handles = []
|
steering_handles = []
|
||||||
|
_orig_predict_flow = None
|
||||||
if steering_vectors is not None and steering_strength > 0.0:
|
if steering_vectors is not None and steering_strength > 0.0:
|
||||||
vecs = steering_vectors["steering_vectors"]
|
vecs = steering_vectors["steering_vectors"]
|
||||||
n_joint = steering_vectors["n_joint"]
|
n_joint = steering_vectors["n_joint"]
|
||||||
|
|
||||||
|
# Patch predict_flow to flag which pass is conditional.
|
||||||
|
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
|
||||||
|
# identity check tells us which is which.
|
||||||
|
_is_cond_pass = [False]
|
||||||
|
_orig_predict_flow = net_generator.predict_flow
|
||||||
|
|
||||||
|
def _tracked_predict_flow(latent, t, cond):
|
||||||
|
_is_cond_pass[0] = (cond is conditions)
|
||||||
|
return _orig_predict_flow(latent, t, cond)
|
||||||
|
|
||||||
|
net_generator.predict_flow = _tracked_predict_flow
|
||||||
|
|
||||||
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
||||||
vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H]
|
vec = vec_cpu.to(dev, dt) # [seq, hidden]
|
||||||
def hook(module, input, output):
|
def hook(module, input, output):
|
||||||
|
if not _is_cond_pass[0]:
|
||||||
|
return # skip unconditional pass; steering amplified by cfg_strength
|
||||||
|
# Interpolate steering vec to match actual output seq length
|
||||||
|
# (handles generation at different duration than extraction)
|
||||||
if is_joint:
|
if is_joint:
|
||||||
# JointBlock returns (latent, clip, text) tuple
|
out_seq = output[0].shape[1]
|
||||||
latent_out = output[0] + strength * vec
|
else:
|
||||||
|
out_seq = output.shape[1]
|
||||||
|
v = vec
|
||||||
|
if v.shape[0] != out_seq:
|
||||||
|
v = torch.nn.functional.interpolate(
|
||||||
|
v.T.unsqueeze(0), # [1, hidden, seq_orig]
|
||||||
|
size=out_seq,
|
||||||
|
mode="linear",
|
||||||
|
align_corners=False,
|
||||||
|
).squeeze(0).T # [seq_new, hidden]
|
||||||
|
if is_joint:
|
||||||
|
latent_out = output[0] + strength * v
|
||||||
return (latent_out,) + output[1:]
|
return (latent_out,) + output[1:]
|
||||||
else:
|
else:
|
||||||
return output + strength * vec
|
return output + strength * v
|
||||||
return hook
|
return hook
|
||||||
|
|
||||||
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
||||||
@@ -184,7 +213,8 @@ class SelvaSampler:
|
|||||||
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
||||||
)
|
)
|
||||||
steering_handles.append(h)
|
steering_handles.append(h)
|
||||||
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True)
|
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
|
||||||
|
f"strength={steering_strength} (conditional pass only)", flush=True)
|
||||||
|
|
||||||
# Flow matching ODE (Euler)
|
# Flow matching ODE (Euler)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
@@ -203,6 +233,8 @@ class SelvaSampler:
|
|||||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
if _orig_predict_flow is not None:
|
||||||
|
net_generator.predict_flow = _orig_predict_flow
|
||||||
for h in steering_handles:
|
for h in steering_handles:
|
||||||
h.remove()
|
h.remove()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user