Files
ComfyUI-SelVA/nodes/selva_activation_steering_extractor.py
Ethanfel f745e241c4 chore: sanitize tooltips/comments + add experiment configs
- Replace all BJ references with generic "target style/audio" in
  activation steering, DITTO optimizer, and BigVGAN trainer
- Add latent_mixup_alpha/latent_noise_sigma to LoRA scheduler defaults
- Add bigvgan_disc_fm_retest.json and lora_optimized_dataset.json

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 13:44:37 +02:00

202 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""SelVA Activation Steering Extractor.

Computes per-block steering vectors by running the frozen generator on the
training dataset and recording how the target-style conditioning shifts the
DiT hidden states relative to empty/unconditional conditioning.

For each block i:

    steering[i] = mean(latent_hidden | target-style conditions)
                - mean(latent_hidden | empty conditions)

The resulting vectors are injected at inference time (via the SelVA Sampler's
steering_strength input) to nudge the denoising trajectory toward the target
style's activation patterns without modifying any model weights.
"""
import random
from pathlib import Path
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import _prepare_dataset
def _collect_activations(generator, conditions, latent, t_tensor):
    """Run one ``predict_flow`` call and record each block's latent output.

    A forward hook is attached to every module in ``generator.joint_blocks``
    and ``generator.fused_blocks``. For joint blocks the first element of the
    block output is recorded; for fused blocks the output itself is recorded.
    Each recorded tensor is detached, cast to float32, averaged over dim 0
    and moved to the CPU.

    Args:
        generator: Object exposing ``joint_blocks``, ``fused_blocks`` and
            ``predict_flow(latent, t_tensor, conditions)``.
        conditions: Conditioning object forwarded unchanged to ``predict_flow``.
        latent: Latent tensor forwarded unchanged to ``predict_flow``.
        t_tensor: Timestep tensor forwarded unchanged to ``predict_flow``.

    Returns:
        A list of float32 CPU tensors, one per hook invocation, in the order
        the hooks fired. Assuming block outputs are ``[batch, seq, hidden]``
        and ``predict_flow`` runs the joint blocks before the fused blocks
        (TODO confirm against the generator), this is one ``[seq, hidden]``
        tensor per block, joint blocks first.
    """
    activations = []

    def make_hook(is_joint):
        def hook(_module, _inputs, output):
            hidden = output[0] if is_joint else output
            activations.append(hidden.detach().float().mean(0).cpu())
        return hook

    handles = []
    try:
        # Registration lives inside the try so that a failure part-way through
        # (e.g. a missing ``fused_blocks`` attribute) still removes the hooks
        # that were already attached instead of leaking them onto the model.
        for block in generator.joint_blocks:
            handles.append(block.register_forward_hook(make_hook(is_joint=True)))
        for block in generator.fused_blocks:
            handles.append(block.register_forward_hook(make_hook(is_joint=False)))
        with torch.no_grad():
            generator.predict_flow(latent, t_tensor, conditions)
    finally:
        for handle in handles:
            handle.remove()
    return activations
class SelvaActivationSteeringExtractor:
    """Compute activation steering vectors from a training dataset.

    Runs the generator (in eval mode, under ``torch.no_grad``) on
    ``n_samples`` clips drawn with replacement from the dataset. Each clip is
    evaluated at one random timestep on one random noise latent, once with
    the clip's own conditions and once with the generator's empty conditions.
    The per-block difference of the averaged activations is saved to a
    ``.pt`` file.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "extract"
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("steering_path",)
    OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
    # The subtraction operator was missing between the two mean(...) terms.
    DESCRIPTION = (
        "Computes per-block activation steering vectors: mean(target style activations) "
        "- mean(empty activations) at each DiT block. Load the result with "
        "SelVA Activation Steering Loader and connect to the Sampler."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "data_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
                }),
                "output_path": ("STRING", {
                    "default": "steering_vectors.pt",
                    "tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
                }),
                "n_samples": ("INT", {
                    "default": 16, "min": 1, "max": 256,
                    "tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
                }),
                "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    def extract(self, model, data_dir, output_path, n_samples, seed):
        """Extract steering vectors and save them with ``torch.save``.

        Args:
            model: Mapping providing the keys ``"dtype"``, ``"seq_cfg"``,
                ``"generator"``, ``"mode"`` and ``"variant"``.
            data_dir: Dataset directory; a relative path is resolved against
                ``folder_paths.models_dir``.
            output_path: Destination file; a relative path is resolved
                against ``folder_paths.get_output_directory()``.
            n_samples: Number of clips to draw (with replacement).
            seed: Seed applied to the global ``torch`` and ``random`` RNGs.

        Returns:
            A one-element tuple containing the saved file path as a string.

        Raises:
            FileNotFoundError: If the resolved ``data_dir`` does not exist.
            ValueError: If the prepared dataset contains no clips.
            RuntimeError: If the number of recorded activations does not
                match the number of blocks, or if activation shapes differ
                between clips (e.g. clips with different latent lengths).

        Side effects visible in this method: the global ``torch`` and
        ``random`` RNGs are reseeded, the generator is switched to eval mode,
        and ``generator.update_seq_lengths`` is called per clip without the
        previous lengths being restored afterwards.
        """
        device = get_device()
        dtype = model["dtype"]
        seq_cfg = model["seq_cfg"]

        data_dir = Path(data_dir.strip())
        if not data_dir.is_absolute():
            data_dir = Path(folder_paths.models_dir) / data_dir
        if not data_dir.exists():
            raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")

        out_path = Path(output_path.strip())
        if not out_path.is_absolute():
            out_path = Path(folder_paths.get_output_directory()) / out_path
        out_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
        print(f"[Steering] data_dir = {data_dir}", flush=True)
        print(f"[Steering] output = {out_path}\n", flush=True)

        dataset = _prepare_dataset(model, data_dir, device)
        n_clips = len(dataset)
        if n_clips == 0:
            # random.choices would otherwise fail with an opaque IndexError.
            raise ValueError(f"[Steering] dataset is empty, no clips found in: {data_dir}")

        generator = model["generator"]
        generator.eval()

        torch.manual_seed(seed)
        random.seed(seed)
        # Sampling is with replacement, so n_samples may exceed the clip count.
        indices = random.choices(range(n_clips), k=n_samples)

        n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
        style_sums = [None] * n_blocks
        empty_sums = [None] * n_blocks
        counts = [0] * n_blocks

        pbar = comfy.utils.ProgressBar(n_samples)
        for sample_i, clip_idx in enumerate(indices):
            x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
            # Shape notes below are carried over from the original comments.
            clip_f = clip_f_cpu.to(device, dtype)        # [1, T_clip, 1024]
            sync_f = sync_f_cpu.to(device, dtype)        # [1, T_sync, 768]
            text_clip = text_clip_cpu.to(device, dtype)  # [1, 77, 1024]

            # x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
            clip_latent_seq_len = x1_cpu.shape[1]
            generator.update_seq_lengths(
                latent_seq_len=clip_latent_seq_len,
                clip_seq_len=clip_f.shape[1],
                sync_seq_len=sync_f.shape[1],
            )

            conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
            empty_conditions = generator.get_empty_conditions(bs=1)

            # Random timestep and noise latent for this clip; the same latent
            # and timestep are used for both the conditioned and empty pass.
            t_val = torch.rand(1).item()
            t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
            latent = torch.randn(
                1, clip_latent_seq_len, generator.latent_dim,
                device=device, dtype=dtype,
            )

            style_acts = _collect_activations(generator, conditions, latent, t_tensor)
            empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)

            # zip() below would silently misalign blocks if a hook fired a
            # different number of times than there are blocks.
            if len(style_acts) != n_blocks or len(empty_acts) != n_blocks:
                raise RuntimeError(
                    f"[Steering] expected {n_blocks} block activations, got "
                    f"{len(style_acts)} (conditioned) and {len(empty_acts)} (empty)"
                )

            for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
                if style_sums[i] is None:
                    style_sums[i] = st.clone()
                    empty_sums[i] = em.clone()
                else:
                    # In-place += would either fail with an opaque torch
                    # error or broadcast silently when shapes differ.
                    if st.shape != style_sums[i].shape or em.shape != empty_sums[i].shape:
                        raise RuntimeError(
                            f"[Steering] activation shape mismatch at block {i} for clip "
                            f"{clip_idx}: got {tuple(st.shape)}, expected "
                            f"{tuple(style_sums[i].shape)}. All clips must share the same "
                            f"latent sequence length."
                        )
                    style_sums[i] += st
                    empty_sums[i] += em
                counts[i] += 1

            pbar.update(1)
            if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
                print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)

        # Steering vector per block: mean(target style) - mean(empty)
        steering_vectors = []
        for i, (style_sum, empty_sum, count) in enumerate(zip(style_sums, empty_sums, counts)):
            vec = (style_sum - empty_sum) / count  # same shape as one recorded activation
            steering_vectors.append(vec)
            norm = vec.norm().item()
            print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)

        n_joint = len(generator.joint_blocks)
        payload = {
            "steering_vectors": steering_vectors,  # list of [seq, hidden] tensors
            "n_blocks": n_blocks,
            "n_joint": n_joint,
            "n_fused": len(generator.fused_blocks),
            # NOTE(review): taken from seq_cfg, whereas the vectors' sequence
            # length comes from the clips themselves — confirm these agree.
            "latent_seq_len": seq_cfg.latent_seq_len,
            "n_samples": n_samples,
            "seed": seed,
            "mode": model["mode"],
            "variant": model["variant"],
        }
        torch.save(payload, str(out_path))
        print(f"\n[Steering] Saved: {out_path}", flush=True)

        soft_empty_cache()
        return (str(out_path),)