Peter Robicheaux 10 months ago
Parent
Commit
558af32786
1 changed file with 13 additions and 3 deletions

+ 13 - 3
ultralytics/nn/modules/block.py

@@ -1157,10 +1157,20 @@ class TorchVision(nn.Module):
             y = self.m(x)
         return y
 
+import logging
+logger = logging.getLogger(__name__)
+
 try:
-    from flash_attn.flash_attn_interface import flash_attn_func
+    import torch
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
+        from flash_attn.flash_attn_interface import flash_attn_func
+    else:
+        from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+        logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 except Exception:
-    assert False, "import FlashAttention error! Please install FlashAttention first."
+    from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+    logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
+
 from timm.models.layers import trunc_normal_
 
 class AAttn(nn.Module):
@@ -1311,7 +1321,7 @@ class A2C2f(nn.Module):
         area (int, optional): Number of areas the feature map is divided. Defaults to 1;
         residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
         mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
-        e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5.
+        e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
         g (int, optional): Number of groups for grouped convolution. Defaults to 1;
         shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;
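
Note on the fallback in the first hunk: flash_attn_func and torch.nn.functional.scaled_dot_product_attention do not share a call signature. flash_attn_func takes (batch, seqlen, heads, head_dim) fp16/bf16 CUDA tensors, while scaled_dot_product_attention expects (batch, heads, seqlen, head_dim), so aliasing one name to the other only works if the call sites account for the layout difference. Below is a minimal sketch, not part of this commit, of a backend-agnostic wrapper; the helper name attention and the HAS_FLASH_ATTN flag are hypothetical.

import torch
import torch.nn.functional as F

try:
    from flash_attn.flash_attn_interface import flash_attn_func
    # Only use FlashAttention on Ampere (SM 8.x) or newer GPUs, matching the check in the hunk above.
    HAS_FLASH_ATTN = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
except Exception:
    HAS_FLASH_ATTN = False


def attention(q, k, v):
    """Hypothetical helper: q, k, v are (batch, seqlen, heads, head_dim)."""
    if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        # FlashAttention consumes the (B, S, H, D) layout directly.
        return flash_attn_func(q, k, v)
    # PyTorch SDPA expects (B, H, S, D): transpose in, transpose back out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2)

With a wrapper along these lines, the try/except import block would not need to rename scaled_dot_product_attention to flash_attn_func, and callers would not depend on which backend was selected.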