From fb255edaf05c954ecc72f256324d95fc8f128481 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 17:35:15 +0200 Subject: [PATCH] fix: strip inference-mode tensor flags in DITTO before conditions computation Root cause: net_generator/feature_utils/mel_converter parameters were loaded in ComfyUI's inference_mode; operations on inference tensors propagate the flag, so conditions computed from tainted weights were also tainted. checkpoint() with use_reentrant=False then failed trying to save inference tensors during the backward recompute pass. Fix: _strip_inference() clones all params/buffers of all three models before any forward pass, and _clone_nested() cleans any residual inference flags in the conditions/empty_conditions output tensors. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_ditto_optimizer.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index bdba871..85916f2 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -283,6 +283,26 @@ def _do_optimize(net_generator, feature_utils, mel_converter, clip_f = features["clip_features"].to(device, dtype).clone() sync_f = features["sync_features"].to(device, dtype).clone() + # Strip inference-mode flags from all model weights and buffers BEFORE any + # forward pass. Parameters were loaded in ComfyUI's inference_mode context; + # operations on inference tensors produce inference tensors, so conditions + # computed from tainted weights would also be tainted. clone() outside + # inference_mode produces a normal tensor regardless of the source flag. + def _strip_inference(module): + for mod in module.modules(): + for name, buf in list(mod._buffers.items()): + if buf is not None: + mod._buffers[name] = buf.clone() + for name, param in list(mod._parameters.items()): + if param is not None: + mod._parameters[name] = torch.nn.Parameter( + param.data.clone(), requires_grad=False + ) + + _strip_inference(net_generator) + _strip_inference(feature_utils) + _strip_inference(mel_converter) + net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, clip_seq_len=clip_f.shape[1], @@ -300,6 +320,20 @@ def _do_optimize(net_generator, feature_utils, mel_converter, bs=1, negative_text_features=neg_text_clip ) + # Clone all tensors inside conditions/empty_conditions to ensure no inference + # flags survived from intermediate computations inside preprocess_conditions. + def _clone_nested(obj): + if isinstance(obj, torch.Tensor): + return obj.clone() + elif isinstance(obj, dict): + return {k: _clone_nested(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(_clone_nested(v) for v in obj) + return obj + + conditions = _clone_nested(conditions) + empty_conditions = _clone_nested(empty_conditions) + # Initial noise — x_0 is the parameter we optimize x0_init = torch.randn( 1, seq_cfg.latent_seq_len, net_generator.latent_dim,