田运杰 9 months ago
parent
commit
805d7271f7
1 changed file with 21 additions and 32 deletions
    ultralytics/nn/modules/block.py  +21 -32

+ 21 - 32
ultralytics/nn/modules/block.py

@@ -1219,17 +1219,18 @@ class AAttn(nn.Module):
         B, C, H, W = x.shape
         N = H * W
 
+        qk = self.qk(x).flatten(2).transpose(1, 2)
+        v = self.v(x)
+        pp = self.pe(v)
+        v = v.flatten(2).transpose(1, 2)
+
+        if self.area > 1:
+            qk = qk.reshape(B * self.area, N // self.area, C * 2)
+            v = v.reshape(B * self.area, N // self.area, C)
+            B, N, _ = qk.shape
+        q, k = qk.split([C, C], dim=2)
+
         if x.is_cuda and USE_FLASH_ATTN:
-            qk = self.qk(x).flatten(2).transpose(1, 2)
-            v = self.v(x)
-            pp = self.pe(v)
-            v = v.flatten(2).transpose(1, 2)
-
-            if self.area > 1:
-                qk = qk.reshape(B * self.area, N // self.area, C * 2)
-                v = v.reshape(B * self.area, N // self.area, C)
-                B, N, _ = qk.shape
-            q, k = qk.split([C, C], dim=2)
             q = q.view(B, N, self.num_heads, self.head_dim)
             k = k.view(B, N, self.num_heads, self.head_dim)
             v = v.view(B, N, self.num_heads, self.head_dim)
@@ -1239,35 +1240,23 @@ class AAttn(nn.Module):
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
-
-            if self.area > 1:
-                x = x.reshape(B // self.area, N * self.area, C)
-                B, N, _ = x.shape
-            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
         else:
-            qk = self.qk(x).flatten(2)
-            v = self.v(x)
-            pp = self.pe(v)
-            v = v.flatten(2)
-            if self.area > 1:
-                qk = qk.reshape(B * self.area, C * 2, N // self.area)
-                v = v.reshape(B * self.area, C, N // self.area)
-                B, _, N = qk.shape
-
-            q, k = qk.split([C, C], dim=1)
-            q = q.view(B, self.num_heads, self.head_dim, N)
-            k = k.view(B, self.num_heads, self.head_dim, N)
-            v = v.view(B, self.num_heads, self.head_dim, N)
+            q = q.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
+            k = k.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
+            v = v.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
+
             attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
             max_attn = attn.max(dim=-1, keepdim=True).values
             exp_attn = torch.exp(attn - max_attn)
             attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
             x = (v @ attn.transpose(-2, -1))
 
-            if self.area > 1:
-                x = x.reshape(B // self.area, C, N * self.area)
-                B, _, N = x.shape
-            x = x.reshape(B, C, H, W)
+            x = x.permute(0, 3, 1, 2)
+
+        if self.area > 1:
+            x = x.reshape(B // self.area, N * self.area, C)
+            B, N, _ = x.shape
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
 
         return self.proj(x + pp)
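
For reference, below is a minimal, self-contained sketch of the restructured forward pass after this commit, keeping only the plain-PyTorch branch (the flash_attn_func path is omitted). The class name AreaAttnSketch and its constructor are assumptions made for illustration: AAttn.__init__ is not part of this diff, so the qk/v/pe/proj projections are written as plain convolutions with whatever shapes the forward pass requires. The layout mirrors the refactor itself: the qk/v projection, the area split, and the final un-split/reshape are shared code rather than being duplicated in each attention branch.

import torch
import torch.nn as nn


class AreaAttnSketch(nn.Module):
    """Minimal area-attention sketch mirroring the refactored forward pass (non-flash path)."""

    def __init__(self, dim: int, num_heads: int, area: int = 1):
        super().__init__()
        assert dim % num_heads == 0
        self.area = area
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Assumed stand-ins for the real projections (AAttn.__init__ is not shown in this diff).
        self.qk = nn.Conv2d(dim, dim * 2, 1)                      # joint query/key projection
        self.v = nn.Conv2d(dim, dim, 1)                           # value projection
        self.pe = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)   # depthwise positional encoding
        self.proj = nn.Conv2d(dim, dim, 1)                        # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        N = H * W

        # Shared preamble (previously duplicated in both branches).
        qk = self.qk(x).flatten(2).transpose(1, 2)   # (B, N, 2C)
        v = self.v(x)
        pp = self.pe(v)                              # positional term, added back before proj
        v = v.flatten(2).transpose(1, 2)             # (B, N, C)

        # Split the flattened spatial axis into `area` local windows.
        if self.area > 1:
            qk = qk.reshape(B * self.area, N // self.area, C * 2)
            v = v.reshape(B * self.area, N // self.area, C)
            B, N, _ = qk.shape
        q, k = qk.split([C, C], dim=2)

        # Plain softmax attention (the else-branch above); the flash-attn path is omitted here.
        q = q.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
        k = k.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
        v = v.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)

        attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
        max_attn = attn.max(dim=-1, keepdim=True).values
        exp_attn = torch.exp(attn - max_attn)
        attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
        x = v @ attn.transpose(-2, -1)               # (B, heads, head_dim, N)
        x = x.permute(0, 3, 1, 2)                    # (B, N, heads, head_dim)

        # Shared epilogue: undo the area split and restore the (B, C, H, W) layout.
        if self.area > 1:
            x = x.reshape(B // self.area, N * self.area, C)
            B, N, _ = x.shape
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

        return self.proj(x + pp)


# Example: 4 areas over a 64-channel 32x32 feature map; output keeps the input shape.
m = AreaAttnSketch(dim=64, num_heads=2, area=4)
print(m(torch.randn(1, 64, 32, 32)).shape)   # torch.Size([1, 64, 32, 32])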