Files
ComfyUI-SelVA/nodes/selva_activation_steering_loader.py
Ethanfel 95923cdf42 feat: add activation steering pipeline (extractor, loader, sampler injection)
Implements per-block DiT activation steering as an alternative to textual
inversion. Extractor runs frozen generator on dataset with BJ vs empty
conditions, records mean hidden-state delta per block, saves [hidden_dim]
vectors (seq-averaged so they broadcast to any inference duration). Loader
reads the bundle. Sampler registers forward hooks during the ODE that add
strength × vec to each block output, cleaned up in a finally block.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 00:38:26 +02:00

63 lines
2.2 KiB
Python

"""SelVA Activation Steering Loader.
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaActivationSteeringLoader:
    """ComfyUI node that loads an activation steering bundle from disk.

    The node reads a ``.pt`` file with ``torch.load``, prints a short summary
    of its contents, and passes the loaded payload through unchanged as a
    ``STEERING_VECTORS`` output.
    """

    CATEGORY = SELVA_CATEGORY
    FUNCTION = "load"
    RETURN_TYPES = ("STEERING_VECTORS",)
    RETURN_NAMES = ("steering_vectors",)
    OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
    DESCRIPTION = (
        "Loads activation steering vectors from a .pt file produced by "
        "SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
        "denoising toward the target activation patterns."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's single required input: the bundle path."""
        return {
            "required": {
                "path": ("STRING", {
                    "default": "steering_vectors.pt",
                    "tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
                }),
            },
        }

    def load(self, path: str) -> tuple:
        """Load the steering bundle at ``path`` and log a summary of it.

        Args:
            path: Location of the bundle file. Surrounding whitespace is
                stripped; a relative path is resolved against
                ``folder_paths.get_output_directory()``.

        Returns:
            A one-element tuple holding the payload exactly as returned by
            ``torch.load``.

        Raises:
            FileNotFoundError: If the resolved path does not exist.
            KeyError: If the payload lacks ``n_blocks``, ``n_joint``,
                ``n_fused``, ``steering_vectors``, ``latent_seq_len`` or
                ``n_samples``.
        """
        p = Path(path.strip())
        if not p.is_absolute():
            p = Path(folder_paths.get_output_directory()) / p
        if not p.exists():
            raise FileNotFoundError(f"[Steering] File not found: {p}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
        # bundle from an untrusted source can execute code on load. Only load
        # files you trust.
        # NOTE(review): if bundles hold only tensors and plain containers,
        # weights_only=True would be safer — confirm against the extractor.
        payload = torch.load(str(p), map_location="cpu", weights_only=False)

        n_blocks = payload["n_blocks"]
        n_joint = payload["n_joint"]
        n_fused = payload["n_fused"]
        vectors = payload["steering_vectors"]
        n_vecs = len(vectors)

        print(f"[Steering] Loaded: {p}", flush=True)
        print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
              f"latent_seq_len={payload['latent_seq_len']} "
              f"n_samples={payload['n_samples']}", flush=True)
        print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)

        # Iterate the vectors themselves rather than indexing 0..n-1, so a
        # dict bundle works regardless of how its keys are numbered.
        entries = vectors.values() if isinstance(vectors, dict) else vectors
        norms = [vec.norm().item() for vec in entries]
        if norms:
            mean_norm = sum(norms) / len(norms)
            print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
        else:
            # An empty bundle previously crashed here with ZeroDivisionError.
            print("[Steering] Warning: bundle contains no steering vectors", flush=True)

        return (payload,)