feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual inversion. The extractor runs the frozen generator on the dataset with BJ vs. empty conditions, records the mean hidden-state delta for each block, and saves one [hidden_dim] vector per block (averaged over the sequence dimension so that it broadcasts to any inference duration). The loader reads the saved bundle. The sampler registers forward hooks for the duration of the ODE solve that add strength × vec to each block's output; the hooks are removed in a finally block. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+40
-1
@@ -32,6 +32,15 @@ class SelvaSampler:
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
"optional": {
|
||||
"steering_vectors": ("STEERING_VECTORS", {
|
||||
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
|
||||
"Nudges each DiT block's hidden state toward the extracted pattern.",
|
||||
}),
|
||||
"steering_strength": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
|
||||
"tooltip": "Scale applied to each steering vector before adding to block output. "
|
||||
"Start around 0.1–0.3; higher values risk destabilizing the ODE.",
|
||||
}),
|
||||
"normalize": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
|
||||
@@ -59,7 +68,7 @@ class SelvaSampler:
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
||||
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
||||
import dataclasses
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
@@ -150,6 +159,33 @@ class SelvaSampler:
|
||||
device=gen_device, dtype=dtype, generator=rng,
|
||||
).to(device)
|
||||
|
||||
# Activation steering hooks
|
||||
steering_handles = []
|
||||
if steering_vectors is not None and steering_strength > 0.0:
|
||||
vecs = steering_vectors["steering_vectors"]
|
||||
n_joint = steering_vectors["n_joint"]
|
||||
|
||||
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
||||
vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H]
|
||||
def hook(module, input, output):
|
||||
if is_joint:
|
||||
# JointBlock returns (latent, clip, text) tuple
|
||||
latent_out = output[0] + strength * vec
|
||||
return (latent_out,) + output[1:]
|
||||
else:
|
||||
return output + strength * vec
|
||||
return hook
|
||||
|
||||
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
||||
for i, block in enumerate(blocks):
|
||||
is_joint = i < n_joint
|
||||
if i < len(vecs):
|
||||
h = block.register_forward_hook(
|
||||
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
||||
)
|
||||
steering_handles.append(h)
|
||||
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True)
|
||||
|
||||
# Flow matching ODE (Euler)
|
||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
@@ -166,6 +202,9 @@ class SelvaSampler:
|
||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||
)
|
||||
finally:
|
||||
for h in steering_handles:
|
||||
h.remove()
|
||||
|
||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user