@@ -1160,15 +1160,17 @@ class TorchVision(nn.Module):
 import logging
 
 logger = logging.getLogger(__name__)
+USE_FLASH_ATTN = False
 try:
     import torch
     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
         from flash_attn.flash_attn_interface import flash_attn_func
+        USE_FLASH_ATTN = True
     else:
-        from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+        from torch.nn.functional import scaled_dot_product_attention as sdpa
         logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 except Exception:
-    from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+    from torch.nn.functional import scaled_dot_product_attention as sdpa
     logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 
 from timm.models.layers import trunc_normal_
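Note how the first hunk resolves: exactly one attention entry point is bound at import time. On an Ampere-or-newer GPU with flash-attn installed, flash_attn_func is imported and USE_FLASH_ATTN flips to True; on every other path (pre-Ampere GPU, no GPU, or a failed import) sdpa is bound and the flag stays False. A minimal standalone sketch of that selection logic, assuming torch itself imports cleanly (pick_attention_backend is a hypothetical helper, not part of the patch):

import torch

def pick_attention_backend() -> str:
    # Same gate as the patch: flash-attn kernels need CUDA compute capability >= 8 (Ampere).
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        try:
            from flash_attn.flash_attn_interface import flash_attn_func  # noqa: F401
            return "flash_attn"
        except Exception:
            return "sdpa"  # capable GPU, but flash-attn is not installed
    return "sdpa"  # CPU-only host or pre-Ampere GPU

print(pick_attention_backend())  # e.g. "sdpa" on a CPU-only machine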
@@ -1225,12 +1227,15 @@ class AAttn(nn.Module):
             [self.head_dim, self.head_dim, self.head_dim], dim=3
         )
 
-        if x.is_cuda:
+        if x.is_cuda and USE_FLASH_ATTN:
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
+        elif x.is_cuda and not USE_FLASH_ATTN:
+            x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
+            x = x.permute(0, 2, 1, 3)
         else:
             q = q.permute(0, 2, 3, 1)
             k = k.permute(0, 2, 3, 1)
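The permutes in the second hunk exist because the two backends disagree on tensor layout: flash_attn_func consumes and returns (batch, seq, heads, head_dim), while torch.nn.functional.scaled_dot_product_attention expects (batch, heads, seq, head_dim). A self-contained sketch of the fallback path, with illustrative sizes, checking that the permute round-trip reproduces plain softmax attention in the flash-attn layout:

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

B, N, num_heads, head_dim = 2, 64, 4, 32  # illustrative sizes, not from the patch
q, k, v = (torch.randn(B, N, num_heads, head_dim) for _ in range(3))

# Fallback path from the patch: (B, N, H, D) -> (B, H, N, D), attend, and back.
out = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
           attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.permute(0, 2, 1, 3)  # restore (B, N, H, D), the flash_attn_func layout

# Reference: the same attention written out by hand (sdpa's default scale is 1/sqrt(head_dim)).
qh, kh, vh = (t.permute(0, 2, 1, 3) for t in (q, k, v))
ref = ((qh @ kh.transpose(-2, -1)) * head_dim**-0.5).softmax(dim=-1) @ vh
ref = ref.permute(0, 2, 1, 3)

assert out.shape == (B, N, num_heads, head_dim)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)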