fix: correct LoRALinear dtype and remove unused import
- LoRALinear now creates lora_A/lora_B with a dtype matching the base linear's weight, preventing a float32 vs. bf16/fp16 dtype mismatch at forward time when the generator is loaded in bf16 or fp16.
- Remove unused `import math` from train_lora.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -41,8 +41,9 @@ class LoRALinear(nn.Module):
|
|||||||
if linear.bias is not None:
|
if linear.bias is not None:
|
||||||
linear.bias.requires_grad_(False)
|
linear.bias.requires_grad_(False)
|
||||||
|
|
||||||
self.lora_A = nn.Parameter(torch.empty(rank, in_f))
|
ref_dtype = linear.weight.dtype
|
||||||
self.lora_B = nn.Parameter(torch.zeros(out_f, rank))
|
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype))
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype))
|
||||||
self.scale = alpha / rank
|
self.scale = alpha / rank
|
||||||
|
|
||||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|||||||
Reference in New Issue
Block a user