Просмотр исходного кода

Merge pull request #7 from roboflow/main

Adds SDPA Fallback For Attention Impl
田运杰 10 месяцев назад
Родитель
Сommit
18b4bee520
1 измененных файлов с 19 добавлено и 4 удалено
  1. 19 4
      ultralytics/nn/modules/block.py

+ 19 - 4
ultralytics/nn/modules/block.py

@@ -1157,10 +1157,22 @@ class TorchVision(nn.Module):
             y = self.m(x)
         return y
 
+import logging
+logger = logging.getLogger(__name__)
+
+USE_FLASH_ATTN = False
 try:
-    from flash_attn.flash_attn_interface import flash_attn_func
+    import torch
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
+        from flash_attn.flash_attn_interface import flash_attn_func
+        USE_FLASH_ATTN = True
+    else:
+        from torch.nn.functional import scaled_dot_product_attention as sdpa
+        logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 except Exception:
-    assert False, "import FlashAttention error! Please install FlashAttention first."
+    from torch.nn.functional import scaled_dot_product_attention as sdpa
+    logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
+
 from timm.models.layers import trunc_normal_
 
 class AAttn(nn.Module):
@@ -1215,12 +1227,15 @@ class AAttn(nn.Module):
             [self.head_dim, self.head_dim, self.head_dim], dim=3
         )
 
-        if x.is_cuda:
+        if x.is_cuda and USE_FLASH_ATTN:
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
+        elif x.is_cuda and not USE_FLASH_ATTN:
+            x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
+            x = x.permute(0, 2, 1, 3)
         else:
             q = q.permute(0, 2, 3, 1)
             k = k.permute(0, 2, 3, 1)
@@ -1311,7 +1326,7 @@ class A2C2f(nn.Module):
         area (int, optional): Number of areas the feature map is divided. Defaults to 1;
         residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
         mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
-        e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5.
+        e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
         g (int, optional): Number of groups for grouped convolution. Defaults to 1;
         shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;