diff --git a/nodes/selva_lora_trainer.py b/nodes/selva_lora_trainer.py index 42d8bf0..9bbb434 100644 --- a/nodes/selva_lora_trainer.py +++ b/nodes/selva_lora_trainer.py @@ -740,7 +740,7 @@ class SelvaLoraTrainer: sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype) text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype) - generator.normalize(x1) + x1 = generator.normalize(x1) if timestep_mode == "logit_normal" or ( timestep_mode == "curriculum" and step <= curriculum_switch_step