
Update block.py

田运杰 10 months ago
commit 98829e05da
1 changed file with 44 additions and 38 deletions

+ 44 - 38
ultralytics/nn/modules/block.py

@@ -1207,9 +1207,11 @@ class AAttn(nn.Module):
         self.head_dim = head_dim = dim // num_heads
         all_head_dim = head_dim * self.num_heads
 
-        self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
+        self.qk = Conv(dim, all_head_dim * 2, 1, act=False)
+        self.v = Conv(dim, all_head_dim, 1, act=False)
         self.proj = Conv(all_head_dim, dim, 1, act=False)
-        self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
+
+        self.pe = Conv(all_head_dim, dim, 5, 1, 2, g=dim, act=False)
 
 
     def forward(self, x):
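
Splitting the fused qkv projection into separate qk and v convolutions lets the positional encoding pe run on v while it is still a 2D feature map, and the depthwise pe kernel shrinks from 7 to 5. A minimal sketch of the new layer shapes, assuming dim=256 and num_heads=8 (illustrative values only) and using plain nn.Conv2d in place of ultralytics' Conv wrapper:

import torch
import torch.nn as nn

dim, num_heads = 256, 8                      # assumed example values
head_dim = dim // num_heads
all_head_dim = head_dim * num_heads          # equals dim here

qk = nn.Conv2d(dim, all_head_dim * 2, 1, bias=False)                # fused q/k projection
v  = nn.Conv2d(dim, all_head_dim, 1, bias=False)                    # value projection, now separate
pe = nn.Conv2d(all_head_dim, dim, 5, 1, 2, groups=dim, bias=False)  # depthwise positional encoding, kernel 7 -> 5

x = torch.randn(2, dim, 32, 32)
print(qk(x).shape, v(x).shape, pe(v(x)).shape)
# torch.Size([2, 512, 32, 32]) torch.Size([2, 256, 32, 32]) torch.Size([2, 256, 32, 32])
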
@@ -1217,53 +1219,57 @@ class AAttn(nn.Module):
         B, C, H, W = x.shape
         N = H * W
 
-        qkv = self.qkv(x).flatten(2).transpose(1, 2)
-        if self.area > 1:
-            qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
-            B, N, _ = qkv.shape
-        q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
-            [self.head_dim, self.head_dim, self.head_dim], dim=3
-        )
-
         if x.is_cuda and USE_FLASH_ATTN:
+            qk = self.qk(x).flatten(2).transpose(1, 2)
+            v = self.v(x)
+            pp = self.pe(v)
+            v = v.flatten(2).transpose(1, 2)
+
+            if self.area > 1:
+                qk = qk.reshape(B * self.area, N // self.area, C * 2)
+                v = v.reshape(B * self.area, N // self.area, C)
+                B, N, _ = qk.shape
+            q, k = qk.split([C, C], dim=2)
+            q = q.view(B, N, self.num_heads, self.head_dim)
+            k = k.view(B, N, self.num_heads, self.head_dim)
+            v = v.view(B, N, self.num_heads, self.head_dim)
+
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
-        elif x.is_cuda and not USE_FLASH_ATTN:
-            x = sdpa(
-                q.permute(0, 2, 1, 3).contiguous(), 
-                k.permute(0, 2, 1, 3).contiguous(), 
-                v.permute(0, 2, 1, 3).contiguous(), 
-                attn_mask=None, 
-                dropout_p=0.0, 
-                is_causal=False
-            )
-            x = x.permute(0, 2, 1, 3)
+
+            if self.area > 1:
+                x = x.reshape(B // self.area, N * self.area, C)
+                B, N, _ = x.shape
+            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
         else:
-            q = q.permute(0, 2, 3, 1)
-            k = k.permute(0, 2, 3, 1)
-            v = v.permute(0, 2, 3, 1)
+            qk = self.qk(x).flatten(2)
+            v = self.v(x)
+            pp = self.pe(v)
+            v = v.flatten(2)
+            if self.area > 1:
+                qk = qk.reshape(B * self.area, C * 2, N // self.area)
+                v = v.reshape(B * self.area, C, N // self.area)
+                B, _, N = qk.shape
+
+            q, k = qk.split([C, C], dim=1)
+            q = q.view(B, self.num_heads, self.head_dim, N)
+            k = k.view(B, self.num_heads, self.head_dim, N)
+            v = v.view(B, self.num_heads, self.head_dim, N)
             attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
-            max_attn = attn.max(dim=-1, keepdim=True).values 
+            max_attn = attn.max(dim=-1, keepdim=True).values
             exp_attn = torch.exp(attn - max_attn)
             attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
             x = (v @ attn.transpose(-2, -1))
-            x = x.permute(0, 3, 1, 2)
-            v = v.permute(0, 3, 1, 2)
-
-        if self.area > 1:
-            x = x.reshape(B // self.area, N * self.area, C)
-            v = v.reshape(B // self.area, N * self.area, C)
-            B, N, _ = x.shape
-
-        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
-        v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
-        
-        x = x + self.pe(v)
-        x = self.proj(x)
-        return x
+
+            if self.area > 1:
+                x = x.reshape(B // self.area, C, N * self.area)
+                B, _, N = x.shape
+            x = x.reshape(B, C, H, W)
+
+        return self.proj(x + pp)
     
 
 class ABlock(nn.Module):
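
With this change, the forward pass computes the positional term pp = pe(v) before flattening, keeps the area reshapes inside each branch, and adds pp just before the output projection. A self-contained sketch of the non-flash branch under assumed conditions (plain nn.Conv2d instead of Conv, H*W divisible by area, dim divisible by num_heads), restated from the hunk above rather than copied from the repository:

import torch
import torch.nn as nn


class AreaAttentionSketch(nn.Module):
    """Condensed restatement of the new non-flash AAttn path (illustrative, not the shipped module)."""

    def __init__(self, dim=256, num_heads=8, area=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.area = area
        self.qk = nn.Conv2d(dim, dim * 2, 1, bias=False)
        self.v = nn.Conv2d(dim, dim, 1, bias=False)
        self.pe = nn.Conv2d(dim, dim, 5, 1, 2, groups=dim, bias=False)
        self.proj = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        qk = self.qk(x).flatten(2)            # (B, 2C, N)
        v = self.v(x)                         # (B, C, H, W)
        pp = self.pe(v)                       # positional term from the 2D value map
        v = v.flatten(2)                      # (B, C, N)

        if self.area > 1:                     # split the sequence into `area` chunks
            qk = qk.reshape(B * self.area, C * 2, N // self.area)
            v = v.reshape(B * self.area, C, N // self.area)
            B, _, N = qk.shape

        q, k = qk.split([C, C], dim=1)
        q = q.view(B, self.num_heads, self.head_dim, N)
        k = k.view(B, self.num_heads, self.head_dim, N)
        v = v.view(B, self.num_heads, self.head_dim, N)

        # scaled dot-product attention with a manually stabilized softmax over keys
        attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
        attn = attn - attn.max(dim=-1, keepdim=True).values
        attn = attn.exp()
        attn = attn / attn.sum(dim=-1, keepdim=True)
        x = v @ attn.transpose(-2, -1)        # (B, heads, head_dim, N)

        if self.area > 1:                     # merge the area chunks back
            x = x.reshape(B // self.area, C, N * self.area)
            B, _, N = x.shape
        x = x.reshape(B, C, H, W)

        return self.proj(x + pp)              # positional encoding added before the output projection


# quick shape check with assumed sizes
m = AreaAttentionSketch(dim=256, num_heads=8, area=4)
out = m(torch.randn(2, 256, 32, 32))
print(out.shape)  # torch.Size([2, 256, 32, 32])
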