fix: strip inference-mode tensor flags in DITTO before conditions computation
Root cause: the net_generator, feature_utils, and mel_converter parameters were loaded inside ComfyUI's inference_mode context. Operations on inference tensors propagate the flag, so conditions computed from the tainted weights were tainted as well. checkpoint() with use_reentrant=False then failed when it tried to save those inference tensors during the backward recompute pass.

Fix: _strip_inference() clones all parameters and buffers of all three models before any forward pass, and _clone_nested() clears any residual inference flags on the tensors inside the conditions/empty_conditions outputs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -283,6 +283,26 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
clip_f = features["clip_features"].to(device, dtype).clone()
|
clip_f = features["clip_features"].to(device, dtype).clone()
|
||||||
sync_f = features["sync_features"].to(device, dtype).clone()
|
sync_f = features["sync_features"].to(device, dtype).clone()
|
||||||
|
|
||||||
|
# Strip inference-mode flags from all model weights and buffers BEFORE any
|
||||||
|
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
|
||||||
|
# operations on inference tensors produce inference tensors, so conditions
|
||||||
|
# computed from tainted weights would also be tainted. clone() outside
|
||||||
|
# inference_mode produces a normal tensor regardless of the source flag.
|
||||||
|
def _strip_inference(module):
    """Replace every parameter and buffer under *module* with a fresh clone.

    Cloning outside ``inference_mode`` yields a normal tensor even when the
    source tensor carries the inference flag, so after this call no weight
    or buffer reachable from *module* is an inference tensor.

    All parameters are re-created with ``requires_grad=False``.  A tensor
    that is registered in more than one place (tied weights, shared
    buffers) is cloned exactly once and the same replacement is installed
    everywhere, so sharing is preserved instead of being silently broken.

    Args:
        module: ``torch.nn.Module`` whose sub-modules are rewritten in place.
    """
    # id(original) -> (original, replacement).  The original is kept in the
    # tuple so it stays alive and its id() cannot be reused by another
    # tensor while the walk is still in progress.
    seen = {}

    def _fresh(tensor, as_param):
        entry = seen.get(id(tensor))
        if entry is None:
            if as_param:
                new = torch.nn.Parameter(
                    tensor.data.clone(), requires_grad=False
                )
            else:
                new = tensor.clone()
            entry = (tensor, new)
            seen[id(tensor)] = entry
        return entry[1]

    for mod in module.modules():
        # Snapshot the items so the dicts can be reassigned while looping.
        for name, buf in list(mod._buffers.items()):
            if buf is not None:
                mod._buffers[name] = _fresh(buf, as_param=False)
        for name, param in list(mod._parameters.items()):
            if param is not None:
                mod._parameters[name] = _fresh(param, as_param=True)
|
||||||
|
|
||||||
|
_strip_inference(net_generator)
|
||||||
|
_strip_inference(feature_utils)
|
||||||
|
_strip_inference(mel_converter)
|
||||||
|
|
||||||
net_generator.update_seq_lengths(
|
net_generator.update_seq_lengths(
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
clip_seq_len=clip_f.shape[1],
|
clip_seq_len=clip_f.shape[1],
|
||||||
@@ -300,6 +320,20 @@ def _do_optimize(net_generator, feature_utils, mel_converter,
|
|||||||
bs=1, negative_text_features=neg_text_clip
|
bs=1, negative_text_features=neg_text_clip
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clone all tensors inside conditions/empty_conditions to ensure no inference
|
||||||
|
# flags survived from intermediate computations inside preprocess_conditions.
|
||||||
|
def _clone_nested(obj):
    """Return a copy of *obj* in which every contained tensor is cloned.

    Containers are rebuilt recursively:

    * ``torch.Tensor`` -> ``obj.clone()``
    * ``dict``         -> new plain ``dict`` with cloned values
    * namedtuple       -> same namedtuple type, fields passed positionally
    * ``list``/``tuple`` -> same type, built from the cloned items
    * dataclass instance -> shallow copy whose fields are cloned

    Anything else is returned unchanged.

    Args:
        obj: Arbitrarily nested structure that may contain tensors.

    Returns:
        A structure of the same shape whose tensors are fresh clones.
    """
    import copy
    import dataclasses

    if isinstance(obj, torch.Tensor):
        return obj.clone()
    if isinstance(obj, dict):
        return {k: _clone_nested(v) for k, v in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple constructors take one positional argument per field,
        # not a single iterable, so the items must be unpacked.
        return type(obj)(*(_clone_nested(v) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(_clone_nested(v) for v in obj)
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        dup = copy.copy(obj)
        for field in dataclasses.fields(obj):
            if hasattr(obj, field.name):
                # object.__setattr__ also works for frozen dataclasses.
                object.__setattr__(
                    dup, field.name, _clone_nested(getattr(obj, field.name))
                )
        return dup
    return obj
|
||||||
|
|
||||||
|
conditions = _clone_nested(conditions)
|
||||||
|
empty_conditions = _clone_nested(empty_conditions)
|
||||||
|
|
||||||
# Initial noise — x_0 is the parameter we optimize
|
# Initial noise — x_0 is the parameter we optimize
|
||||||
x0_init = torch.randn(
|
x0_init = torch.randn(
|
||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user