feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual inversion. Extractor runs frozen generator on dataset with BJ vs empty conditions, records mean hidden-state delta per block, saves [hidden_dim] vectors (seq-averaged so they broadcast to any inference duration). Loader reads the bundle. Sampler registers forward hooks during the ODE that add strength × vec to each block output, cleaned up in a finally block. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,8 @@ _NODES = {
|
|||||||
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
|
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
|
||||||
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
|
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
|
||||||
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
|
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
|
||||||
|
"SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
|
||||||
|
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
|
|||||||
@@ -0,0 +1,203 @@
|
|||||||
|
"""SelVA Activation Steering Extractor.
|
||||||
|
|
||||||
|
Computes per-block steering vectors by running the frozen generator on the
|
||||||
|
training dataset and recording how BJ's conditioning shifts the DiT hidden
|
||||||
|
states vs. empty/unconditional conditioning.
|
||||||
|
|
||||||
|
For each block i:
|
||||||
|
steering[i] = mean(latent_hidden | BJ conditions)
|
||||||
|
- mean(latent_hidden | empty conditions)
|
||||||
|
|
||||||
|
The resulting vectors are injected at inference time (via SelVA Sampler's
|
||||||
|
steering_strength input) to nudge the denoising trajectory toward BJ's
|
||||||
|
activation patterns without modifying any model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
|
from .selva_lora_trainer import _prepare_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_activations(generator, conditions, latent, t_tensor):
|
||||||
|
"""Run one predict_flow call, collecting latent hidden states per block.
|
||||||
|
|
||||||
|
Returns a list of [hidden_dim] float32 CPU tensors,
|
||||||
|
one per block (joint_blocks first, then fused_blocks).
|
||||||
|
"""
|
||||||
|
activations = []
|
||||||
|
|
||||||
|
def make_hook(is_joint):
|
||||||
|
def hook(module, input, output):
|
||||||
|
h = output[0] if is_joint else output
|
||||||
|
# Mean over batch then seq → [hidden]: makes vectors length-agnostic so
|
||||||
|
# they broadcast to any inference duration without shape mismatches.
|
||||||
|
activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden]
|
||||||
|
return hook
|
||||||
|
|
||||||
|
handles = []
|
||||||
|
for block in generator.joint_blocks:
|
||||||
|
handles.append(block.register_forward_hook(make_hook(is_joint=True)))
|
||||||
|
for block in generator.fused_blocks:
|
||||||
|
handles.append(block.register_forward_hook(make_hook(is_joint=False)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
generator.predict_flow(latent, t_tensor, conditions)
|
||||||
|
finally:
|
||||||
|
for h in handles:
|
||||||
|
h.remove()
|
||||||
|
|
||||||
|
return activations # list of n_blocks tensors [seq, hidden]
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaActivationSteeringExtractor:
|
||||||
|
"""Computes activation steering vectors from a training dataset.
|
||||||
|
|
||||||
|
Runs the frozen generator on N clips at random timesteps with both
|
||||||
|
BJ-conditioned and empty-conditioned inputs, then saves the mean
|
||||||
|
difference per DiT block to a .pt file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
FUNCTION = "extract"
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("steering_path",)
|
||||||
|
OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Computes per-block activation steering vectors: mean(BJ activations) − "
|
||||||
|
"mean(empty activations) at each DiT block. Load the result with "
|
||||||
|
"SelVA Activation Steering Loader and connect to the Sampler."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"data_dir": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
|
||||||
|
}),
|
||||||
|
"output_path": ("STRING", {
|
||||||
|
"default": "steering_vectors.pt",
|
||||||
|
"tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
|
||||||
|
}),
|
||||||
|
"n_samples": ("INT", {
|
||||||
|
"default": 16, "min": 1, "max": 256,
|
||||||
|
"tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
|
||||||
|
}),
|
||||||
|
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def extract(self, model, data_dir, output_path, n_samples, seed):
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
seq_cfg = model["seq_cfg"]
|
||||||
|
|
||||||
|
data_dir = Path(data_dir.strip())
|
||||||
|
if not data_dir.is_absolute():
|
||||||
|
data_dir = Path(folder_paths.models_dir) / data_dir
|
||||||
|
if not data_dir.exists():
|
||||||
|
raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")
|
||||||
|
|
||||||
|
out_path = Path(output_path.strip())
|
||||||
|
if not out_path.is_absolute():
|
||||||
|
out_path = Path(folder_paths.get_output_directory()) / out_path
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
|
||||||
|
print(f"[Steering] data_dir = {data_dir}", flush=True)
|
||||||
|
print(f"[Steering] output = {out_path}\n", flush=True)
|
||||||
|
|
||||||
|
dataset = _prepare_dataset(model, data_dir, device)
|
||||||
|
generator = model["generator"]
|
||||||
|
generator.eval()
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
indices = random.choices(range(len(dataset)), k=n_samples)
|
||||||
|
|
||||||
|
n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
|
||||||
|
bj_sums = [None] * n_blocks
|
||||||
|
empty_sums = [None] * n_blocks
|
||||||
|
counts = [0] * n_blocks
|
||||||
|
|
||||||
|
pbar = comfy.utils.ProgressBar(n_samples)
|
||||||
|
|
||||||
|
for sample_i, clip_idx in enumerate(indices):
|
||||||
|
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
||||||
|
|
||||||
|
clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024]
|
||||||
|
sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768]
|
||||||
|
text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024]
|
||||||
|
|
||||||
|
# x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
|
||||||
|
clip_latent_seq_len = x1_cpu.shape[1]
|
||||||
|
|
||||||
|
generator.update_seq_lengths(
|
||||||
|
latent_seq_len=clip_latent_seq_len,
|
||||||
|
clip_seq_len=clip_f.shape[1],
|
||||||
|
sync_seq_len=sync_f.shape[1],
|
||||||
|
)
|
||||||
|
|
||||||
|
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
empty_conditions = generator.get_empty_conditions(bs=1)
|
||||||
|
|
||||||
|
# Random timestep and noise latent for this clip
|
||||||
|
t_val = torch.rand(1).item()
|
||||||
|
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
|
||||||
|
latent = torch.randn(
|
||||||
|
1, clip_latent_seq_len, generator.latent_dim,
|
||||||
|
device=device, dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
bj_acts = _collect_activations(generator, conditions, latent, t_tensor)
|
||||||
|
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
|
||||||
|
|
||||||
|
for i, (bj, em) in enumerate(zip(bj_acts, empty_acts)):
|
||||||
|
if bj_sums[i] is None:
|
||||||
|
bj_sums[i] = bj.clone()
|
||||||
|
empty_sums[i] = em.clone()
|
||||||
|
else:
|
||||||
|
bj_sums[i] += bj
|
||||||
|
empty_sums[i] += em
|
||||||
|
counts[i] += 1
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
|
||||||
|
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
|
||||||
|
|
||||||
|
# Steering vector per block: mean(BJ) - mean(empty)
|
||||||
|
steering_vectors = []
|
||||||
|
for i in range(n_blocks):
|
||||||
|
vec = (bj_sums[i] - empty_sums[i]) / counts[i] # [hidden]
|
||||||
|
steering_vectors.append(vec)
|
||||||
|
|
||||||
|
norm = vec.norm().item()
|
||||||
|
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
|
||||||
|
|
||||||
|
n_joint = len(generator.joint_blocks)
|
||||||
|
payload = {
|
||||||
|
"steering_vectors": steering_vectors, # list of [hidden] tensors
|
||||||
|
"n_blocks": n_blocks,
|
||||||
|
"n_joint": n_joint,
|
||||||
|
"n_fused": len(generator.fused_blocks),
|
||||||
|
"latent_seq_len": seq_cfg.latent_seq_len,
|
||||||
|
"n_samples": n_samples,
|
||||||
|
"seed": seed,
|
||||||
|
"mode": model["mode"],
|
||||||
|
"variant": model["variant"],
|
||||||
|
}
|
||||||
|
torch.save(payload, str(out_path))
|
||||||
|
print(f"\n[Steering] Saved: {out_path}", flush=True)
|
||||||
|
|
||||||
|
soft_empty_cache()
|
||||||
|
return (str(out_path),)
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""SelVA Activation Steering Loader.
|
||||||
|
|
||||||
|
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
|
||||||
|
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaActivationSteeringLoader:
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
FUNCTION = "load"
|
||||||
|
RETURN_TYPES = ("STEERING_VECTORS",)
|
||||||
|
RETURN_NAMES = ("steering_vectors",)
|
||||||
|
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Loads activation steering vectors from a .pt file produced by "
|
||||||
|
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
|
||||||
|
"denoising toward the target activation patterns."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"path": ("STRING", {
|
||||||
|
"default": "steering_vectors.pt",
|
||||||
|
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def load(self, path):
|
||||||
|
p = Path(path.strip())
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = Path(folder_paths.get_output_directory()) / p
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"[Steering] File not found: {p}")
|
||||||
|
|
||||||
|
payload = torch.load(str(p), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
n_blocks = payload["n_blocks"]
|
||||||
|
n_joint = payload["n_joint"]
|
||||||
|
n_fused = payload["n_fused"]
|
||||||
|
n_vecs = len(payload["steering_vectors"])
|
||||||
|
|
||||||
|
print(f"[Steering] Loaded: {p}", flush=True)
|
||||||
|
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
|
||||||
|
f"latent_seq_len={payload['latent_seq_len']} "
|
||||||
|
f"n_samples={payload['n_samples']}", flush=True)
|
||||||
|
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
|
||||||
|
|
||||||
|
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
|
||||||
|
mean_norm = sum(norms) / len(norms)
|
||||||
|
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
|
||||||
|
|
||||||
|
return (payload,)
|
||||||
+40
-1
@@ -32,6 +32,15 @@ class SelvaSampler:
|
|||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
|
"steering_vectors": ("STEERING_VECTORS", {
|
||||||
|
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
|
||||||
|
"Nudges each DiT block's hidden state toward the extracted pattern.",
|
||||||
|
}),
|
||||||
|
"steering_strength": ("FLOAT", {
|
||||||
|
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
|
||||||
|
"tooltip": "Scale applied to each steering vector before adding to block output. "
|
||||||
|
"Start around 0.1–0.3; higher values risk destabilizing the ODE.",
|
||||||
|
}),
|
||||||
"normalize": ("BOOLEAN", {
|
"normalize": ("BOOLEAN", {
|
||||||
"default": True,
|
"default": True,
|
||||||
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
|
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
|
||||||
@@ -59,7 +68,7 @@ class SelvaSampler:
|
|||||||
CATEGORY = SELVA_CATEGORY
|
CATEGORY = SELVA_CATEGORY
|
||||||
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
||||||
|
|
||||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
@@ -150,6 +159,33 @@ class SelvaSampler:
|
|||||||
device=gen_device, dtype=dtype, generator=rng,
|
device=gen_device, dtype=dtype, generator=rng,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
|
# Activation steering hooks
|
||||||
|
steering_handles = []
|
||||||
|
if steering_vectors is not None and steering_strength > 0.0:
|
||||||
|
vecs = steering_vectors["steering_vectors"]
|
||||||
|
n_joint = steering_vectors["n_joint"]
|
||||||
|
|
||||||
|
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
||||||
|
vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H]
|
||||||
|
def hook(module, input, output):
|
||||||
|
if is_joint:
|
||||||
|
# JointBlock returns (latent, clip, text) tuple
|
||||||
|
latent_out = output[0] + strength * vec
|
||||||
|
return (latent_out,) + output[1:]
|
||||||
|
else:
|
||||||
|
return output + strength * vec
|
||||||
|
return hook
|
||||||
|
|
||||||
|
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
||||||
|
for i, block in enumerate(blocks):
|
||||||
|
is_joint = i < n_joint
|
||||||
|
if i < len(vecs):
|
||||||
|
h = block.register_forward_hook(
|
||||||
|
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
||||||
|
)
|
||||||
|
steering_handles.append(h)
|
||||||
|
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True)
|
||||||
|
|
||||||
# Flow matching ODE (Euler)
|
# Flow matching ODE (Euler)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
@@ -166,6 +202,9 @@ class SelvaSampler:
|
|||||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
for h in steering_handles:
|
||||||
|
h.remove()
|
||||||
|
|
||||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user