田运杰 10 months ago
parent
commit
cf93572c47
1 changed file with 16 additions and 44 deletions
      ultralytics/nn/modules/block.py

+ 16 - 44
ultralytics/nn/modules/block.py

@@ -1157,10 +1157,10 @@ class TorchVision(nn.Module):
             y = self.m(x)
         return y
 
-
-from flash_attn.flash_attn_interface import flash_attn_func
-from timm.models.layers import drop_path, trunc_normal_
-
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func
+except ImportError as e:
+    raise ImportError("FlashAttention is not installed. Please install FlashAttention first.") from e
 
 class AAttn(nn.Module):
     """
@@ -1190,7 +1190,6 @@ class AAttn(nn.Module):
     def __init__(self, dim, num_heads, area=1):
         """Initializes the area-attention module, a simple yet efficient attention module for YOLO."""
         super().__init__()
-
         self.area = area
 
         self.num_heads = num_heads
@@ -1199,7 +1198,7 @@ class AAttn(nn.Module):
 
         self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
         self.proj = Conv(all_head_dim, dim, 1, act=False)
-        self.pe = Conv(all_head_dim, dim, 9, 1, 4, g=dim, act=False)
+        self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
 
 
     def forward(self, x):
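The hunk above shrinks the positional-encoding depthwise conv from 9×9 (padding 4) to 7×7 (padding 3). With stride 1 and an odd kernel k, padding k // 2 preserves spatial size, so the change trims parameters and receptive field without affecting the output shape. A quick check, with plain `nn.Conv2d` standing in for ultralytics' `Conv` and 64 channels chosen arbitrarily:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)
pe_old = nn.Conv2d(64, 64, 9, 1, 4, groups=64, bias=False)  # 9x9, pad 4
pe_new = nn.Conv2d(64, 64, 7, 1, 3, groups=64, bias=False)  # 7x7, pad 3
# Both keep the spatial resolution: (32 + 2*pad - k) / 1 + 1 == 32.
assert pe_old(x).shape == pe_new(x).shape == x.shape
```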
@@ -1285,20 +1284,10 @@ class ABlock(nn.Module):
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
 
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        """Initialize the layers."""
-        if isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Conv2d) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
     def forward(self, x):
         """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
         x = x + self.attn(x)
         x = x + self.mlp(x)
-
         return x
 
 
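The hunk above drops the timm-based truncated-normal initialization, so ABlock's convolutions now fall back to PyTorch's default (Kaiming-uniform) init. If the old behaviour is wanted without the timm dependency, `torch.nn.init.trunc_normal_` is an equivalent; a minimal sketch, not part of this commit:

```python
import torch.nn as nn


def init_weights(m):
    """Truncated-normal init for Conv2d weights, zeros for biases."""
    if isinstance(m, nn.Conv2d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# Usage: block.apply(init_weights)
```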
@@ -1334,45 +1323,28 @@ class A2C2f(nn.Module):
         >>> print(output.shape)
     """
 
-    def __init__(self, c1, c2, n=1, a2=True, area=1, align=False, residual=False, e=0.5, mlp_ratio=1.2, g=1, shortcut=True):
+    def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
         super().__init__()
-
-        self.a2 = a2
-        self.residual = residual
         c_ = int(c2 * e)  # hidden channels
         assert c_ % 32 == 0, "Dimension of ABlock must be a multiple of 32."
 
         # num_heads = c_ // 64 if c_ // 64 >= 2 else c_ // 32
         num_heads = c_ // 32
 
-        if self.a2:
-            self.cv1 = Conv(c1, c_, 1, 1)
-            self.cv2 = Conv((1 + n) * c_, c2, 1)  # optional act=FReLU(c2)
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv((1 + n) * c_, c2, 1)  # optional act=FReLU(c2)
 
-            if residual:
-                self.align = Conv(c1, c2, 1, 1) if align else nn.Identity()
-                init_values = 0.01  # or smaller
-                self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True)
-            else:
-                self.align, self.gamma = None, None
-        else:
-            self.cv1 = Conv(c1, 2 * c_, 1, 1)
-            self.cv2 = Conv((2 + n) * c_, c2, 1)  # optional act=FReLU(c2)
+        init_values = 0.01  # or smaller
+        self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True) if a2 and residual else None
 
         self.m = nn.ModuleList(
-            nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else Bottleneck(c_, c_, shortcut, g) for _ in range(n)
+            nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else C3k(c_, c_, 2, shortcut, g) for _ in range(n)
         )
 
     def forward(self, x):
         """Forward pass through R-ELAN layer."""
-        if self.a2:
-            y = [self.cv1(x)]
-            y.extend(m(y[-1]) for m in self.m)
-            if self.residual:
-                return self.align(x) + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-            else:
-                return self.cv2(torch.cat(y, 1))
-        else:
-            y = list(self.cv1(x).chunk(2, 1))
-            y.extend(m(y[-1]) for m in self.m)
-            return self.cv2(torch.cat(y, 1))
+        y = [self.cv1(x)]
+        y.extend(m(y[-1]) for m in self.m)
+        if self.gamma is not None:
+            return x + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return self.cv2(torch.cat(y, 1))
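The rewritten A2C2f collapses the a2/residual branching: cv1 and cv2 are always built, the a2=False path now stacks C3k blocks instead of single Bottlenecks, and the residual branch survives only as the optional per-channel gamma scale (a layer scale in the style of CaiT/ConvNeXt, initialized at 0.01). Since the align conv is gone, the residual add `x + ...` assumes c1 == c2. The permutes exist only so the (c2,)-shaped gamma broadcasts per channel; an equivalent channels-first form, illustrated with stand-in tensors:

```python
import torch

x = torch.randn(2, 64, 32, 32)   # residual input; the add requires c1 == c2
y = torch.randn(2, 64, 32, 32)   # stand-in for cv2(torch.cat(y, 1))
gamma = 0.01 * torch.ones(64)    # per-channel layer-scale parameter

# As written in forward(): move channels last, scale, move channels back.
a = x + (gamma * y.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# Equivalent channels-first form without the permutes.
b = x + gamma.view(1, -1, 1, 1) * y
assert torch.allclose(a, b)
```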