fix: resolve critical bugs and quality issues in prismaudio_core/models
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -18,16 +18,12 @@ except ImportError:
|
||||
flash_attn_kvpacked_func = None
|
||||
flash_attn_func = None
|
||||
|
||||
from .utils import compile
|
||||
from .utils import compile, checkpoint
|
||||
try:
|
||||
import natten
|
||||
except ImportError:
|
||||
natten = None
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
|
||||
kwargs.setdefault("use_reentrant", False)
|
||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
@@ -389,8 +385,6 @@ class Attention(nn.Module):
|
||||
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
self.causal = causal
|
||||
if causal:
|
||||
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
|
||||
|
||||
@compile
|
||||
def apply_qk_layernorm(self, q, k):
|
||||
@@ -409,14 +403,8 @@ class Attention(nn.Module):
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
flash_attn_available = HAS_FLASH_ATTN
|
||||
if flash_attn_sliding_window is not None and (not flash_attn_available):
|
||||
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
|
||||
|
||||
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
|
||||
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
|
||||
|
||||
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
||||
print(f"Disabling FlexAttention because causal is set")
|
||||
flex_attention_block_mask = None
|
||||
flex_attention_score_mod = None
|
||||
|
||||
@@ -606,7 +594,6 @@ class TransformerBlock(nn.Module):
|
||||
self.dim_context = dim_context
|
||||
self.causal = causal
|
||||
if layer_scale and zero_init_branch_outputs:
|
||||
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
|
||||
zero_init_branch_outputs = False
|
||||
|
||||
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||
@@ -909,6 +896,8 @@ class ContinuousTransformer(nn.Module):
|
||||
model_dtype = next(self.parameters()).dtype
|
||||
x = x.to(model_dtype)
|
||||
|
||||
prepend_length = 0
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user