田运杰 10 months ago
Parent
Commit a28adcd565
1 changed file with 9 additions and 0 deletions

ultralytics/nn/modules/block.py  +9 -0

@@ -1161,6 +1161,7 @@ try:
     from flash_attn.flash_attn_interface import flash_attn_func
 except Exception:
     assert False, "import FlashAttention error! Please install FlashAttention first."
+from timm.models.layers import trunc_normal_
 
 class AAttn(nn.Module):
     """
@@ -1280,6 +1281,14 @@ class ABlock(nn.Module):
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
 
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
     def forward(self, x):
         """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
         x = x + self.attn(x)