feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual inversion. Extractor runs frozen generator on dataset with BJ vs empty conditions, records mean hidden-state delta per block, saves [hidden_dim] vectors (seq-averaged so they broadcast to any inference duration). Loader reads the bundle. Sampler registers forward hooks during the ODE that add strength × vec to each block output, cleaned up in a finally block. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
"""SelVA Activation Steering Loader.
|
||||
|
||||
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
|
||||
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY
|
||||
|
||||
|
||||
class SelvaActivationSteeringLoader:
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
FUNCTION = "load"
|
||||
RETURN_TYPES = ("STEERING_VECTORS",)
|
||||
RETURN_NAMES = ("steering_vectors",)
|
||||
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
|
||||
DESCRIPTION = (
|
||||
"Loads activation steering vectors from a .pt file produced by "
|
||||
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
|
||||
"denoising toward the target activation patterns."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"path": ("STRING", {
|
||||
"default": "steering_vectors.pt",
|
||||
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
def load(self, path):
|
||||
p = Path(path.strip())
|
||||
if not p.is_absolute():
|
||||
p = Path(folder_paths.get_output_directory()) / p
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"[Steering] File not found: {p}")
|
||||
|
||||
payload = torch.load(str(p), map_location="cpu", weights_only=False)
|
||||
|
||||
n_blocks = payload["n_blocks"]
|
||||
n_joint = payload["n_joint"]
|
||||
n_fused = payload["n_fused"]
|
||||
n_vecs = len(payload["steering_vectors"])
|
||||
|
||||
print(f"[Steering] Loaded: {p}", flush=True)
|
||||
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
|
||||
f"latent_seq_len={payload['latent_seq_len']} "
|
||||
f"n_samples={payload['n_samples']}", flush=True)
|
||||
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
|
||||
|
||||
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
|
||||
mean_norm = sum(norms) / len(norms)
|
||||
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
|
||||
|
||||
return (payload,)
|
||||
Reference in New Issue
Block a user