|
|
@@ -1157,10 +1157,22 @@ class TorchVision(nn.Module):
|
|
|
y = self.m(x)
|
|
|
return y
|
|
|
|
|
|
+import logging
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+USE_FLASH_ATTN = False
|
|
|
try:
|
|
|
- from flash_attn.flash_attn_interface import flash_attn_func
|
|
|
+ import torch
|
|
|
+ if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: # Ampere or newer
|
|
|
+ from flash_attn.flash_attn_interface import flash_attn_func
|
|
|
+ USE_FLASH_ATTN = True
|
|
|
+ else:
|
|
|
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
|
|
|
+ logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
|
|
|
except Exception:
|
|
|
- assert False, "import FlashAttention error! Please install FlashAttention first."
|
|
|
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
|
|
|
+ logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
|
|
|
+
|
|
|
from timm.models.layers import trunc_normal_
|
|
|
|
|
|
class AAttn(nn.Module):
|
|
|
@@ -1215,12 +1227,15 @@ class AAttn(nn.Module):
|
|
|
[self.head_dim, self.head_dim, self.head_dim], dim=3
|
|
|
)
|
|
|
|
|
|
- if x.is_cuda:
|
|
|
+ if x.is_cuda and USE_FLASH_ATTN:
|
|
|
x = flash_attn_func(
|
|
|
q.contiguous().half(),
|
|
|
k.contiguous().half(),
|
|
|
v.contiguous().half()
|
|
|
).to(q.dtype)
|
|
|
+ elif x.is_cuda and not USE_FLASH_ATTN:
|
|
|
+ x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
|
|
|
+ x = x.permute(0, 2, 1, 3)
|
|
|
else:
|
|
|
q = q.permute(0, 2, 3, 1)
|
|
|
k = k.permute(0, 2, 3, 1)
|
|
|
@@ -1311,7 +1326,7 @@ class A2C2f(nn.Module):
|
|
|
area (int, optional): Number of areas the feature map is divided. Defaults to 1;
|
|
|
residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
|
|
|
mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
|
|
|
- e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5.
|
|
|
+ e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
|
|
|
g (int, optional): Number of groups for grouped convolution. Defaults to 1;
|
|
|
shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;
|
|
|
|