diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py index bbf0132..c3b0c29 100644 --- a/selva_core/model/lora.py +++ b/selva_core/model/lora.py @@ -41,8 +41,9 @@ class LoRALinear(nn.Module): if linear.bias is not None: linear.bias.requires_grad_(False) - self.lora_A = nn.Parameter(torch.empty(rank, in_f)) - self.lora_B = nn.Parameter(torch.zeros(out_f, rank)) + ref_dtype = linear.weight.dtype + self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype)) + self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype)) self.scale = alpha / rank nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) diff --git a/train_lora.py b/train_lora.py index bc3ab4b..f11c955 100644 --- a/train_lora.py +++ b/train_lora.py @@ -25,7 +25,6 @@ Usage: import argparse import os import sys -import math import random import json from pathlib import Path