fix: strip inference-mode tensor flags in DITTO before conditions computation
Root cause: net_generator/feature_utils/mel_converter parameters were loaded in ComfyUI's inference_mode; operations on inference tensors propagate the flag, so conditions computed from tainted weights were also tainted. checkpoint() with use_reentrant=False then failed trying to save inference tensors during the backward recompute pass. Fix: _strip_inference() clones all params/buffers of all three models before any forward pass, and _clone_nested() cleans any residual inference flags in the conditions/empty_conditions output tensors. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -283,6 +283,26 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
clip_f = features["clip_features"].to(device, dtype).clone()
|
||||
sync_f = features["sync_features"].to(device, dtype).clone()
|
||||
|
||||
# Strip inference-mode flags from all model weights and buffers BEFORE any
|
||||
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
|
||||
# operations on inference tensors produce inference tensors, so conditions
|
||||
# computed from tainted weights would also be tainted. clone() outside
|
||||
# inference_mode produces a normal tensor regardless of the source flag.
|
||||
def _strip_inference(module):
|
||||
for mod in module.modules():
|
||||
for name, buf in list(mod._buffers.items()):
|
||||
if buf is not None:
|
||||
mod._buffers[name] = buf.clone()
|
||||
for name, param in list(mod._parameters.items()):
|
||||
if param is not None:
|
||||
mod._parameters[name] = torch.nn.Parameter(
|
||||
param.data.clone(), requires_grad=False
|
||||
)
|
||||
|
||||
_strip_inference(net_generator)
|
||||
_strip_inference(feature_utils)
|
||||
_strip_inference(mel_converter)
|
||||
|
||||
net_generator.update_seq_lengths(
|
||||
latent_seq_len=seq_cfg.latent_seq_len,
|
||||
clip_seq_len=clip_f.shape[1],
|
||||
@@ -300,6 +320,20 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
||||
bs=1, negative_text_features=neg_text_clip
|
||||
)
|
||||
|
||||
# Clone all tensors inside conditions/empty_conditions to ensure no inference
|
||||
# flags survived from intermediate computations inside preprocess_conditions.
|
||||
def _clone_nested(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return obj.clone()
|
||||
elif isinstance(obj, dict):
|
||||
return {k: _clone_nested(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return type(obj)(_clone_nested(v) for v in obj)
|
||||
return obj
|
||||
|
||||
conditions = _clone_nested(conditions)
|
||||
empty_conditions = _clone_nested(empty_conditions)
|
||||
|
||||
# Initial noise — x_0 is the parameter we optimize
|
||||
x0_init = torch.randn(
|
||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||
|
||||
Reference in New Issue
Block a user