fix: resolve critical bugs and quality issues in prismaudio_core/models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:56:02 +01:00
parent 6e1186d5bd
commit 30e85f0f99
6 changed files with 13 additions and 33 deletions
-8
View File
@@ -200,14 +200,6 @@ def zero_init(layer):
nn.init.zeros_(layer.bias)
return layer
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
#rms_norm = torch.compile(rms_norm)
class AdaRMSNorm(nn.Module):
def __init__(self, features, cond_features, eps=1e-6):
super().__init__()