From cde280049b9d2827b683b372130bed924eb2f3e5 Mon Sep 17 00:00:00 2001
From: Ethanfel
Date: Sun, 5 Apr 2026 14:57:09 +0200
Subject: [PATCH] fix: correct LoRALinear dtype and remove unused import

- LoRALinear now creates lora_A/lora_B with dtype matching the base
  linear's weight, preventing a float32/bf16 mismatch at forward time
  when the generator is loaded in bf16 or fp16.
- Remove unused `import math` from train_lora.py.

Co-Authored-By: Claude Sonnet 4.6
---
 selva_core/model/lora.py | 5 +++--
 train_lora.py            | 1 -
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/selva_core/model/lora.py b/selva_core/model/lora.py
index bbf0132..c3b0c29 100644
--- a/selva_core/model/lora.py
+++ b/selva_core/model/lora.py
@@ -41,8 +41,9 @@ class LoRALinear(nn.Module):
         if linear.bias is not None:
             linear.bias.requires_grad_(False)
 
-        self.lora_A = nn.Parameter(torch.empty(rank, in_f))
-        self.lora_B = nn.Parameter(torch.zeros(out_f, rank))
+        ref_dtype = linear.weight.dtype
+        self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype))
+        self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype))
         self.scale = alpha / rank
 
         nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
diff --git a/train_lora.py b/train_lora.py
index bc3ab4b..f11c955 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,6 @@ Usage:
 import argparse
 import os
 import sys
-import math
 import random
 import json
 from pathlib import Path