田运杰 10 months ago
parent commit b70d368b3e
1 changed file with 23 additions and 27 deletions
+23 -27
ultralytics/nn/modules/block.py

@@ -1205,44 +1205,40 @@ class AAttn(nn.Module):
         """Processes the input tensor 'x' through the area-attention or attention mechanism."""
         B, C, H, W = x.shape
         N = H * W
+
+        qkv = self.qkv(x).flatten(2).transpose(1, 2)
+        if self.area > 1:
+            qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
+            B, N, _ = qkv.shape
+        q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
+            [self.head_dim, self.head_dim, self.head_dim], dim=3
+        )
+
         if x.is_cuda:
-            qkv = self.qkv(x).flatten(2).transpose(1, 2)
-            if self.area > 1:
-                qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
-                B, N, _ = qkv.shape
-            q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
-                [self.head_dim, self.head_dim, self.head_dim], dim=3
-            )
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
-            if self.area > 1:
-                x = x.reshape(B // self.area, N * self.area, C)
-                v = v.reshape(B // self.area, N * self.area, C)
-                B, N, _ = x.shape
-            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
-            v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
         else:
-            qkv = self.qkv(x).flatten(2)
-            if self.area > 1:
-                qkv = qkv.reshape(B * self.area, C * 3, N // self.area)
-                B, _, N = qkv.shape
-            q, k, v = qkv.view(B, self.num_heads, self.head_dim * 3, N).split(
-                [self.head_dim, self.head_dim, self.head_dim], dim=2
-            )
-            attn = (q.transpose(-2, -1) @ k) * (self.num_heads ** -0.5)
+            q = q.permute(0, 2, 3, 1)
+            k = k.permute(0, 2, 3, 1)
+            v = v.permute(0, 2, 3, 1)
+            attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
             max_attn = attn.max(dim=-1, keepdim=True).values 
             exp_attn = torch.exp(attn - max_attn)
             attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
             x = (v @ attn.transpose(-2, -1))
-            if self.area > 1:
-                x = x.reshape(B // self.area, C, N * self.area)
-                v = v.reshape(B // self.area, C, N * self.area)
-                B, _, N = x.shape
-            x = x.reshape(B, C, H, W)
-            v = v.reshape(B, C, H, W)
+            x = x.permute(0, 3, 1, 2)
+            v = v.permute(0, 3, 1, 2)
+
+        if self.area > 1:
+            x = x.reshape(B // self.area, N * self.area, C)
+            v = v.reshape(B // self.area, N * self.area, C)
+            B, N, _ = x.shape
+
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+        v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
         
         x = x + self.pe(v)
         x = self.proj(x)
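
The hunk above hoists the shared qkv projection and head split out of the CUDA/CPU branches, and corrects the attention scale from self.num_heads ** -0.5 to self.head_dim ** -0.5, the standard 1/sqrt(d_k) of scaled dot-product attention. The CPU branch keeps a hand-rolled, numerically stable softmax; a minimal sketch, with illustrative shapes not taken from the module, checking that it matches torch.softmax:

import torch

attn = torch.randn(2, 4, 64, 64) * 30.0  # large logits to stress numerical stability

max_attn = attn.max(dim=-1, keepdim=True).values  # subtract the per-row max
exp_attn = torch.exp(attn - max_attn)             # so exp() cannot overflow
manual = exp_attn / exp_attn.sum(dim=-1, keepdim=True)

assert torch.allclose(manual, torch.softmax(attn, dim=-1), atol=1e-6)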
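On CUDA the new code still calls flash_attn_func from the flash-attn package, which takes tensors shaped (B, N, num_heads, head_dim). A rough equivalent sketch using PyTorch's built-in F.scaled_dot_product_attention instead; SDPA expects (B, num_heads, N, head_dim), hence the transposes, and all shapes here are illustrative:

import torch
import torch.nn.functional as F

B, N, num_heads, head_dim = 2, 64, 4, 16
q = torch.randn(B, N, num_heads, head_dim)
k = torch.randn(B, N, num_heads, head_dim)
v = torch.randn(B, N, num_heads, head_dim)

# Move heads in front of tokens for SDPA, then move them back afterwards.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)  # back to (B, N, num_heads, head_dim)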
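The area > 1 bookkeeping, now shared by both branches, folds the token dimension N into area contiguous chunks along the batch axis before attention and unfolds it afterwards. A self-contained sketch of that fold/unfold round trip (B, N, C, area are illustrative values):

import torch

B, N, C, area = 2, 256, 32, 4
x = torch.randn(B, N, C)

folded = x.reshape(B * area, N // area, C)  # each chunk becomes its own batch entry,
                                            # so attention stays local to a chunk
unfolded = folded.reshape(B, N, C)          # restore the original layout

assert torch.equal(unfolded, x)             # the fold/unfold round trip is lossless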