fix: initialize LoRA params on same device as wrapped linear

apply_lora() is called after generator.to(device), so lora_A/lora_B were
being created on CPU while the rest of the model was on CUDA.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 22:17:29 +02:00
parent ad57432803
commit 8fade1b0e3
+3 -2
View File
@@ -42,8 +42,9 @@ class LoRALinear(nn.Module):
 linear.bias.requires_grad_(False)
 ref_dtype = linear.weight.dtype
-self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype))
-self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype))
+ref_device = linear.weight.device
+self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
+self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
 self.scale = alpha / rank
 nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))