diff --git a/nodes/selva_activation_steering_extractor.py b/nodes/selva_activation_steering_extractor.py index a5307be..cefb482 100644 --- a/nodes/selva_activation_steering_extractor.py +++ b/nodes/selva_activation_steering_extractor.py @@ -27,7 +27,7 @@ from .selva_lora_trainer import _prepare_dataset def _collect_activations(generator, conditions, latent, t_tensor): """Run one predict_flow call, collecting latent hidden states per block. - Returns a list of [hidden_dim] float32 CPU tensors, + Returns a list of [seq, hidden_dim] float32 CPU tensors, one per block (joint_blocks first, then fused_blocks). """ activations = [] @@ -35,9 +35,7 @@ def _collect_activations(generator, conditions, latent, t_tensor): def make_hook(is_joint): def hook(module, input, output): h = output[0] if is_joint else output - # Mean over batch then seq → [hidden]: makes vectors length-agnostic so - # they broadcast to any inference duration without shape mismatches. - activations.append(h.detach().float().mean(0).mean(0).cpu()) # [hidden] + activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden] return hook handles = [] @@ -186,7 +184,7 @@ class SelvaActivationSteeringExtractor: n_joint = len(generator.joint_blocks) payload = { - "steering_vectors": steering_vectors, # list of [hidden] tensors + "steering_vectors": steering_vectors, # list of [seq, hidden] tensors "n_blocks": n_blocks, "n_joint": n_joint, "n_fused": len(generator.fused_blocks), diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index 1e3f056..95233e7 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -159,21 +159,50 @@ class SelvaSampler: device=gen_device, dtype=dtype, generator=rng, ).to(device) - # Activation steering hooks + # Activation steering: apply only during the conditional predict_flow pass + # so steering gets amplified by cfg_strength rather than canceling out. steering_handles = [] + _orig_predict_flow = None if steering_vectors is not None and steering_strength > 0.0: vecs = steering_vectors["steering_vectors"] n_joint = steering_vectors["n_joint"] + # Patch predict_flow to flag which pass is conditional. + # ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions); + # identity check tells us which is which. + _is_cond_pass = [False] + _orig_predict_flow = net_generator.predict_flow + + def _tracked_predict_flow(latent, t, cond): + _is_cond_pass[0] = (cond is conditions) + return _orig_predict_flow(latent, t, cond) + + net_generator.predict_flow = _tracked_predict_flow + def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt): - vec = vec_cpu.to(dev, dt) # [hidden] — broadcasts over [B, T, H] + vec = vec_cpu.to(dev, dt) # [seq, hidden] def hook(module, input, output): + if not _is_cond_pass[0]: + return # skip unconditional pass; steering amplified by cfg_strength + # Interpolate steering vec to match actual output seq length + # (handles generation at different duration than extraction) if is_joint: - # JointBlock returns (latent, clip, text) tuple - latent_out = output[0] + strength * vec + out_seq = output[0].shape[1] + else: + out_seq = output.shape[1] + v = vec + if v.shape[0] != out_seq: + v = torch.nn.functional.interpolate( + v.T.unsqueeze(0), # [1, hidden, seq_orig] + size=out_seq, + mode="linear", + align_corners=False, + ).squeeze(0).T # [seq_new, hidden] + if is_joint: + latent_out = output[0] + strength * v return (latent_out,) + output[1:] else: - return output + strength * vec + return output + strength * v return hook blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks) @@ -184,7 +213,8 @@ class SelvaSampler: _make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype) ) steering_handles.append(h) - print(f"[SelVA] Activation steering: {len(steering_handles)} blocks strength={steering_strength}", flush=True) + print(f"[SelVA] Activation steering: {len(steering_handles)} blocks " + f"strength={steering_strength} (conditional pass only)", flush=True) # Flow matching ODE (Euler) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) @@ -203,6 +233,8 @@ class SelvaSampler: "to 'offload_to_cpu', using a smaller variant, or reducing duration." ) finally: + if _orig_predict_flow is not None: + net_generator.predict_flow = _orig_predict_flow for h in steering_handles: h.remove()