diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py index bdba871..85916f2 100644 --- a/nodes/selva_ditto_optimizer.py +++ b/nodes/selva_ditto_optimizer.py @@ -283,6 +283,26 @@ def _do_optimize(net_generator, feature_utils, mel_converter, clip_f = features["clip_features"].to(device, dtype).clone() sync_f = features["sync_features"].to(device, dtype).clone() + # Strip inference-mode flags from all model weights and buffers BEFORE any + # forward pass. Parameters were loaded in ComfyUI's inference_mode context; + # operations on inference tensors produce inference tensors, so conditions + # computed from tainted weights would also be tainted. clone() outside + # inference_mode produces a normal tensor regardless of the source flag. + def _strip_inference(module): + for mod in module.modules(): + for name, buf in list(mod._buffers.items()): + if buf is not None: + mod._buffers[name] = buf.clone() + for name, param in list(mod._parameters.items()): + if param is not None: + mod._parameters[name] = torch.nn.Parameter( + param.data.clone(), requires_grad=False + ) + + _strip_inference(net_generator) + _strip_inference(feature_utils) + _strip_inference(mel_converter) + net_generator.update_seq_lengths( latent_seq_len=seq_cfg.latent_seq_len, clip_seq_len=clip_f.shape[1], @@ -300,6 +320,20 @@ def _do_optimize(net_generator, feature_utils, mel_converter, bs=1, negative_text_features=neg_text_clip ) + # Clone all tensors inside conditions/empty_conditions to ensure no inference + # flags survived from intermediate computations inside preprocess_conditions. + def _clone_nested(obj): + if isinstance(obj, torch.Tensor): + return obj.clone() + elif isinstance(obj, dict): + return {k: _clone_nested(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(_clone_nested(v) for v in obj) + return obj + + conditions = _clone_nested(conditions) + empty_conditions = _clone_nested(empty_conditions) + # Initial noise — x_0 is the parameter we optimize x0_init = torch.randn( 1, seq_cfg.latent_seq_len, net_generator.latent_dim,