@@ -1232,7 +1232,14 @@ class AAttn(nn.Module):
                 v.contiguous().half(),
             ).to(q.dtype)
         elif x.is_cuda and not USE_FLASH_ATTN:
-            x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
+            x = sdpa(
+                q.permute(0, 2, 1, 3).contiguous(),
+                k.permute(0, 2, 1, 3).contiguous(),
+                v.permute(0, 2, 1, 3).contiguous(),
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+            )
             x = x.permute(0, 2, 1, 3)
         else:
             q = q.permute(0, 2, 3, 1)
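For context on why the added `.contiguous()` calls matter: `Tensor.permute()` only rearranges strides and returns a non-contiguous view, and the change presumably materializes the operands so that `scaled_dot_product_attention` can satisfy the memory-layout expectations of its fused CUDA backends. Below is a minimal standalone sketch of the same pattern; the shapes are illustrative, not taken from the diff, and it assumes PyTorch >= 2.0 (where `scaled_dot_product_attention` is available).

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

# Illustrative shapes: (batch, seq_len, num_heads, head_dim).
B, N, H, D = 2, 64, 4, 32
q, k, v = (torch.randn(B, N, H, D) for _ in range(3))

# permute() returns a strided view, so the result is non-contiguous.
q_heads = q.permute(0, 2, 1, 3)  # -> (B, H, N, D)
assert not q_heads.is_contiguous()

# .contiguous() copies the data into a densely laid-out tensor
# before handing it to the attention kernel, as in the diff above.
x = sdpa(
    q_heads.contiguous(),
    k.permute(0, 2, 1, 3).contiguous(),
    v.permute(0, 2, 1, 3).contiguous(),
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
)
print(x.shape)  # torch.Size([2, 4, 64, 32])

The copies cost extra memory traffic, so this is a trade-off: the explicit `.contiguous()` is only a win when the attention backend would otherwise fall back to a slower path (or reject the input) on strided views.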