diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py index c3b0c29..f726cdb 100644 --- a/selva_core/model/lora.py +++ b/selva_core/model/lora.py @@ -42,8 +42,9 @@ class LoRALinear(nn.Module): linear.bias.requires_grad_(False) ref_dtype = linear.weight.dtype - self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype)) - self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype)) + ref_device = linear.weight.device + self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device)) + self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device)) self.scale = alpha / rank nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))