fix: resolve critical bugs and quality issues in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -200,14 +200,6 @@ def zero_init(layer):
|
||||
nn.init.zeros_(layer.bias)
|
||||
return layer
|
||||
|
||||
def rms_norm(x, scale, eps):
|
||||
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
||||
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
||||
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
||||
return x * scale.to(x.dtype)
|
||||
|
||||
#rms_norm = torch.compile(rms_norm)
|
||||
|
||||
class AdaRMSNorm(nn.Module):
|
||||
def __init__(self, features, cond_features, eps=1e-6):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user