Peter Robicheaux 10 months ago
parent commit 7d113ff775
1 file changed with 8 additions and 3 deletions

ultralytics/nn/modules/block.py  (+8, -3)

@@ -1160,15 +1160,17 @@ class TorchVision(nn.Module):
 import logging
 logger = logging.getLogger(__name__)
 
+USE_FLASH_ATTN = False
 try:
     import torch
     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
         from flash_attn.flash_attn_interface import flash_attn_func
+        USE_FLASH_ATTN = True
     else:
-        from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+        from torch.nn.functional import scaled_dot_product_attention as sdpa
         logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 except Exception:
-    from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+    from torch.nn.functional import scaled_dot_product_attention as sdpa
     logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 
 from timm.models.layers import trunc_normal_
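
For context (not part of the commit): the SDPA fallback added in the hunk below permutes q, k and v because flash_attn_func consumes tensors laid out as (batch, seq_len, num_heads, head_dim), while torch.nn.functional.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim). A minimal sketch of that layout conversion, with illustrative shapes only:

```python
# Sketch only: illustrates the layout difference the fallback branch bridges.
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

B, N, H, D = 2, 64, 8, 32            # batch, tokens, heads, head_dim (assumed values)
q = k = v = torch.randn(B, N, H, D)  # flash-attn style layout: (B, N, H, D)

out = sdpa(
    q.permute(0, 2, 1, 3),           # -> (B, H, N, D), the layout SDPA expects
    k.permute(0, 2, 1, 3),
    v.permute(0, 2, 1, 3),
    attn_mask=None, dropout_p=0.0, is_causal=False,
).permute(0, 2, 1, 3)                # back to (B, N, H, D) for the rest of AAttn

assert out.shape == (B, N, H, D)
```
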
@@ -1225,12 +1227,15 @@ class AAttn(nn.Module):
             [self.head_dim, self.head_dim, self.head_dim], dim=3
         )
 
-        if x.is_cuda:
+        if x.is_cuda and USE_FLASH_ATTN:
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
+        elif x.is_cuda and not USE_FLASH_ATTN:
+            x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
+            x = x.permute(0, 2, 1, 3)
         else:
             q = q.permute(0, 2, 3, 1)
             k = k.permute(0, 2, 3, 1)