@@ -1157,10 +1157,20 @@ class TorchVision(nn.Module):
             y = self.m(x)
         return y
 
+import logging
+logger = logging.getLogger(__name__)
+
 try:
-    from flash_attn.flash_attn_interface import flash_attn_func
+    import torch
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
+        from flash_attn.flash_attn_interface import flash_attn_func
+    else:
+        from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+        logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 except Exception:
-    assert False, "import FlashAttention error! Please install FlashAttention first."
+    from torch.nn.functional import scaled_dot_product_attention as flash_attn_func
+    logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
+
 from timm.models.layers import trunc_normal_
 
 class AAttn(nn.Module):
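
For reference, below is a minimal, self-contained sketch of the capability-gated fallback this hunk introduces, written with an explicit flag and a small wrapper instead of aliasing scaled_dot_product_attention under the flash_attn_func name; the names HAS_FLASH_ATTN and attn_forward are illustrative and are not part of the patch. It also spells out the layout difference between the two backends: flash_attn_func consumes (batch, seq_len, num_heads, head_dim), while torch.nn.functional.scaled_dot_product_attention consumes (batch, num_heads, seq_len, head_dim).

import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

try:
    # FlashAttention kernels need an NVIDIA GPU with compute capability 8.x (Ampere) or newer.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        from flash_attn.flash_attn_interface import flash_attn_func
        HAS_FLASH_ATTN = True
    else:
        HAS_FLASH_ATTN = False
        logger.warning("FlashAttention not usable on this device; falling back to scaled_dot_product_attention.")
except Exception:
    HAS_FLASH_ATTN = False
    logger.warning("FlashAttention import failed; falling back to scaled_dot_product_attention.")


def attn_forward(q, k, v):
    """Attention over (batch, seq_len, num_heads, head_dim) tensors with whichever backend is available."""
    if HAS_FLASH_ATTN:
        # flash_attn_func takes (batch, seq_len, num_heads, head_dim) directly
        # and requires CUDA fp16/bf16 inputs.
        return flash_attn_func(q, k, v)
    # scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim),
    # so transpose into that layout and back.
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)

Keeping the flag and wrapper separate avoids handing flash-attn's tensor layout to scaled_dot_product_attention by accident, which is why the sketch transposes explicitly on the fallback path.
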
@@ -1311,7 +1321,7 @@ class A2C2f(nn.Module):
             area (int, optional): Number of areas the feature map is divided. Defaults to 1;
             residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
             mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
-            e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5.
+            e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
             g (int, optional): Number of groups for grouped convolution. Defaults to 1;
             shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;