fix: resolve critical bugs and quality issues in prismaudio_core/models

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:56:02 +01:00
parent 6e1186d5bd
commit 30e85f0f99
6 changed files with 13 additions and 33 deletions
+3 -14
View File
@@ -18,16 +18,12 @@ except ImportError:
flash_attn_kvpacked_func = None
flash_attn_func = None
from .utils import compile
from .utils import compile, checkpoint
try:
import natten
except ImportError:
natten = None
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
@@ -389,8 +385,6 @@ class Attention(nn.Module):
self.lambda_hf = nn.Parameter(torch.zeros(dim))
self.causal = causal
if causal:
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
@compile
def apply_qk_layernorm(self, q, k):
@@ -409,14 +403,8 @@ class Attention(nn.Module):
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
flash_attn_available = HAS_FLASH_ATTN
if flash_attn_sliding_window is not None and (not flash_attn_available):
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
print(f"Disabling FlexAttention because causal is set")
flex_attention_block_mask = None
flex_attention_score_mod = None
@@ -606,7 +594,6 @@ class TransformerBlock(nn.Module):
self.dim_context = dim_context
self.causal = causal
if layer_scale and zero_init_branch_outputs:
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
zero_init_branch_outputs = False
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
@@ -909,6 +896,8 @@ class ContinuousTransformer(nn.Module):
model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype)
prepend_length = 0
info = {
"hidden_states": [],
}