fix: strip inference-mode tensor flags in DITTO before conditions computation

Root cause: net_generator/feature_utils/mel_converter parameters were loaded
in ComfyUI's inference_mode; operations on inference tensors propagate the flag,
so conditions computed from tainted weights were also tainted. checkpoint()
with use_reentrant=False then failed trying to save inference tensors during
the backward recompute pass.

Fix: _strip_inference() clones all params/buffers of all three models before
any forward pass, and _clone_nested() cleans any residual inference flags in
the conditions/empty_conditions output tensors.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 17:35:15 +02:00
parent 8ccc2438e4
commit fb255edaf0
+34
View File
@@ -283,6 +283,26 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
clip_f = features["clip_features"].to(device, dtype).clone() clip_f = features["clip_features"].to(device, dtype).clone()
sync_f = features["sync_features"].to(device, dtype).clone() sync_f = features["sync_features"].to(device, dtype).clone()
# Strip inference-mode flags from all model weights and buffers BEFORE any
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
# operations on inference tensors produce inference tensors, so conditions
# computed from tainted weights would also be tainted. clone() outside
# inference_mode produces a normal tensor regardless of the source flag.
def _strip_inference(module):
for mod in module.modules():
for name, buf in list(mod._buffers.items()):
if buf is not None:
mod._buffers[name] = buf.clone()
for name, param in list(mod._parameters.items()):
if param is not None:
mod._parameters[name] = torch.nn.Parameter(
param.data.clone(), requires_grad=False
)
_strip_inference(net_generator)
_strip_inference(feature_utils)
_strip_inference(mel_converter)
net_generator.update_seq_lengths( net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len, latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1], clip_seq_len=clip_f.shape[1],
@@ -300,6 +320,20 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
bs=1, negative_text_features=neg_text_clip bs=1, negative_text_features=neg_text_clip
) )
# Clone all tensors inside conditions/empty_conditions to ensure no inference
# flags survived from intermediate computations inside preprocess_conditions.
def _clone_nested(obj):
if isinstance(obj, torch.Tensor):
return obj.clone()
elif isinstance(obj, dict):
return {k: _clone_nested(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_clone_nested(v) for v in obj)
return obj
conditions = _clone_nested(conditions)
empty_conditions = _clone_nested(empty_conditions)
# Initial noise — x_0 is the parameter we optimize # Initial noise — x_0 is the parameter we optimize
x0_init = torch.randn( x0_init = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim, 1, seq_cfg.latent_seq_len, net_generator.latent_dim,