fix: correct LoRALinear dtype and remove unused import

- LoRALinear now creates lora_A/lora_B with the same dtype as the base
  linear's weight, preventing a dtype mismatch at forward time (float32
  LoRA parameters vs. a bf16/fp16 base weight) when the generator is
  loaded in bf16 or fp16.
- Remove unused `import math` from train_lora.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 14:57:09 +02:00
parent 437c62b28f
commit cde280049b
2 changed files with 3 additions and 3 deletions
+3 -2
View File
@@ -41,8 +41,9 @@ class LoRALinear(nn.Module):
if linear.bias is not None: if linear.bias is not None:
linear.bias.requires_grad_(False) linear.bias.requires_grad_(False)
self.lora_A = nn.Parameter(torch.empty(rank, in_f)) ref_dtype = linear.weight.dtype
self.lora_B = nn.Parameter(torch.zeros(out_f, rank)) self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype))
self.scale = alpha / rank self.scale = alpha / rank
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-1
View File
@@ -25,7 +25,6 @@ Usage:
import argparse import argparse
import os import os
import sys import sys
import math
import random import random
import json import json
from pathlib import Path from pathlib import Path