diff --git a/nodes/__init__.py b/nodes/__init__.py index 4fe05fd..f9be34d 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -16,7 +16,9 @@ _NODES = { "SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"), "SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"), "SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"), - "SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"), + "SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"), + "SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"), + "SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"), } for key, (module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_activation_steering_extractor.py b/nodes/selva_activation_steering_extractor.py new file mode 100644 index 0000000..a5307be --- /dev/null +++ b/nodes/selva_activation_steering_extractor.py @@ -0,0 +1,203 @@ +"""SelVA Activation Steering Extractor. + +Computes per-block steering vectors by running the frozen generator on the +training dataset and recording how BJ's conditioning shifts the DiT hidden +states vs. empty/unconditional conditioning. + +For each block i: + steering[i] = mean(latent_hidden | BJ conditions) + - mean(latent_hidden | empty conditions) + +The resulting vectors are injected at inference time (via SelVA Sampler's +steering_strength input) to nudge the denoising trajectory toward BJ's +activation patterns without modifying any model weights. +""" + +import random +from pathlib import Path + +import torch +import comfy.utils +import folder_paths + +from .utils import SELVA_CATEGORY, get_device, soft_empty_cache +from .selva_lora_trainer import _prepare_dataset + + +def _collect_activations(generator, conditions, latent, t_tensor): + """Run one predict_flow call, collecting latent hidden states per block. + + Returns a list of [hidden_dim] float32 CPU tensors, + one per block (joint_blocks first, then fused_blocks). + """ + activations = [] + + def make_hook(is_joint): + def hook(module, input, output): + h = output[0] if is_joint else output + # Mean over batch then seq → [hidden]: makes vectors length-agnostic so + # they broadcast to any inference duration without shape mismatches. + activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden] + return hook + + handles = [] + for block in generator.joint_blocks: + handles.append(block.register_forward_hook(make_hook(is_joint=True))) + for block in generator.fused_blocks: + handles.append(block.register_forward_hook(make_hook(is_joint=False))) + + try: + with torch.no_grad(): + generator.predict_flow(latent, t_tensor, conditions) + finally: + for h in handles: + h.remove() + + return activations # list of n_blocks tensors [seq, hidden] + + +class SelvaActivationSteeringExtractor: + """Computes activation steering vectors from a training dataset. + + Runs the frozen generator on N clips at random timesteps with both + BJ-conditioned and empty-conditioned inputs, then saves the mean + difference per DiT block to a .pt file. + """ + + OUTPUT_NODE = True + CATEGORY = SELVA_CATEGORY + FUNCTION = "extract" + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("steering_path",) + OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",) + DESCRIPTION = ( + "Computes per-block activation steering vectors: mean(BJ activations) − " + "mean(empty activations) at each DiT block. Load the result with " + "SelVA Activation Steering Loader and connect to the Sampler." + ) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("SELVA_MODEL",), + "data_dir": ("STRING", { + "default": "", + "tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).", + }), + "output_path": ("STRING", { + "default": "steering_vectors.pt", + "tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.", + }), + "n_samples": ("INT", { + "default": 16, "min": 1, "max": 256, + "tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.", + }), + "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), + }, + } + + def extract(self, model, data_dir, output_path, n_samples, seed): + device = get_device() + dtype = model["dtype"] + seq_cfg = model["seq_cfg"] + + data_dir = Path(data_dir.strip()) + if not data_dir.is_absolute(): + data_dir = Path(folder_paths.models_dir) / data_dir + if not data_dir.exists(): + raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}") + + out_path = Path(output_path.strip()) + if not out_path.is_absolute(): + out_path = Path(folder_paths.get_output_directory()) / out_path + out_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True) + print(f"[Steering] data_dir = {data_dir}", flush=True) + print(f"[Steering] output = {out_path}\n", flush=True) + + dataset = _prepare_dataset(model, data_dir, device) + generator = model["generator"] + generator.eval() + + torch.manual_seed(seed) + random.seed(seed) + indices = random.choices(range(len(dataset)), k=n_samples) + + n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks) + bj_sums = [None] * n_blocks + empty_sums = [None] * n_blocks + counts = [0] * n_blocks + + pbar = comfy.utils.ProgressBar(n_samples) + + for sample_i, clip_idx in enumerate(indices): + x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx] + + clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024] + sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768] + text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024] + + # x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length. + clip_latent_seq_len = x1_cpu.shape[1] + + generator.update_seq_lengths( + latent_seq_len=clip_latent_seq_len, + clip_seq_len=clip_f.shape[1], + sync_seq_len=sync_f.shape[1], + ) + + conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip) + empty_conditions = generator.get_empty_conditions(bs=1) + + # Random timestep and noise latent for this clip + t_val = torch.rand(1).item() + t_tensor = torch.tensor([t_val], device=device, dtype=dtype) + latent = torch.randn( + 1, clip_latent_seq_len, generator.latent_dim, + device=device, dtype=dtype, + ) + + bj_acts = _collect_activations(generator, conditions, latent, t_tensor) + empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor) + + for i, (bj, em) in enumerate(zip(bj_acts, empty_acts)): + if bj_sums[i] is None: + bj_sums[i] = bj.clone() + empty_sums[i] = em.clone() + else: + bj_sums[i] += bj + empty_sums[i] += em + counts[i] += 1 + + pbar.update(1) + if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1: + print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True) + + # Steering vector per block: mean(BJ) - mean(empty) + steering_vectors = [] + for i in range(n_blocks): + vec = (bj_sums[i] - empty_sums[i]) / counts[i] # [hidden] + steering_vectors.append(vec) + + norm = vec.norm().item() + print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True) + + n_joint = len(generator.joint_blocks) + payload = { + "steering_vectors": steering_vectors, # list of [hidden] tensors + "n_blocks": n_blocks, + "n_joint": n_joint, + "n_fused": len(generator.fused_blocks), + "latent_seq_len": seq_cfg.latent_seq_len, + "n_samples": n_samples, + "seed": seed, + "mode": model["mode"], + "variant": model["variant"], + } + torch.save(payload, str(out_path)) + print(f"\n[Steering] Saved: {out_path}", flush=True) + + soft_empty_cache() + return (str(out_path),) diff --git a/nodes/selva_activation_steering_loader.py b/nodes/selva_activation_steering_loader.py new file mode 100644 index 0000000..efd034f --- /dev/null +++ b/nodes/selva_activation_steering_loader.py @@ -0,0 +1,62 @@ +"""SelVA Activation Steering Loader. + +Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor +and returns a STEERING_VECTORS dict for use by SelVA Sampler. +""" + +from pathlib import Path + +import torch +import folder_paths + +from .utils import SELVA_CATEGORY + + +class SelvaActivationSteeringLoader: + CATEGORY = SELVA_CATEGORY + FUNCTION = "load" + RETURN_TYPES = ("STEERING_VECTORS",) + RETURN_NAMES = ("steering_vectors",) + OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",) + DESCRIPTION = ( + "Loads activation steering vectors from a .pt file produced by " + "SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge " + "denoising toward the target activation patterns." + ) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": ("STRING", { + "default": "steering_vectors.pt", + "tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.", + }), + }, + } + + def load(self, path): + p = Path(path.strip()) + if not p.is_absolute(): + p = Path(folder_paths.get_output_directory()) / p + if not p.exists(): + raise FileNotFoundError(f"[Steering] File not found: {p}") + + payload = torch.load(str(p), map_location="cpu", weights_only=False) + + n_blocks = payload["n_blocks"] + n_joint = payload["n_joint"] + n_fused = payload["n_fused"] + n_vecs = len(payload["steering_vectors"]) + + print(f"[Steering] Loaded: {p}", flush=True) + print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) " + f"latent_seq_len={payload['latent_seq_len']} " + f"n_samples={payload['n_samples']}", flush=True) + print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True) + + norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)] + mean_norm = sum(norms) / len(norms) + print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True) + + return (payload,) diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index 14eab29..1e3f056 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -32,6 +32,15 @@ class SelvaSampler: "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), }, "optional": { + "steering_vectors": ("STEERING_VECTORS", { + "tooltip": "Activation steering bundle from SelVA Activation Steering Loader. " + "Nudges each DiT block's hidden state toward the extracted pattern.", + }), + "steering_strength": ("FLOAT", { + "default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05, + "tooltip": "Scale applied to each steering vector before adding to block output. " + "Start around 0.1–0.3; higher values risk destabilizing the ODE.", + }), "normalize": ("BOOLEAN", { "default": True, "tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.", @@ -59,7 +68,7 @@ class SelvaSampler: CATEGORY = SELVA_CATEGORY DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance." - def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0): + def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0): import dataclasses from selva_core.model.flow_matching import FlowMatching @@ -150,6 +159,33 @@ class SelvaSampler: device=gen_device, dtype=dtype, generator=rng, ).to(device) + # Activation steering hooks + steering_handles = [] + if steering_vectors is not None and steering_strength > 0.0: + vecs = steering_vectors["steering_vectors"] + n_joint = steering_vectors["n_joint"] + + def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt): + vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H] + def hook(module, input, output): + if is_joint: + # JointBlock returns (latent, clip, text) tuple + latent_out = output[0] + strength * vec + return (latent_out,) + output[1:] + else: + return output + strength * vec + return hook + + blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks) + for i, block in enumerate(blocks): + is_joint = i < n_joint + if i < len(vecs): + h = block.register_forward_hook( + _make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype) + ) + steering_handles.append(h) + print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True) + # Flow matching ODE (Euler) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) pbar = comfy.utils.ProgressBar(steps) @@ -166,6 +202,9 @@ class SelvaSampler: "[SelVA] CUDA out of memory during generation. Try switching offload_strategy " "to 'offload_to_cpu', using a smaller variant, or reducing duration." ) + finally: + for h in steering_handles: + h.remove() print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)