田运杰 10 months ago
Parent
Current commit
a53c69e04a
1 changed file with 2 additions and 4 deletions

+ 2 - 4
ultralytics/nn/modules/block.py

@@ -1173,8 +1173,6 @@ except Exception:
     from torch.nn.functional import scaled_dot_product_attention as sdpa
     logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")

-from timm.models.layers import trunc_normal_
-
 class AAttn(nn.Module):
     """
     Area-attention module with the requirement of flash attention.
@@ -1301,8 +1299,8 @@ class ABlock(nn.Module):
     def _init_weights(self, m):
         """Initialize weights using a truncated normal distribution."""
         if isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Conv2d) and m.bias is not None:
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)

     def forward(self, x):
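
Note: after this commit the initializer uses only PyTorch's built-in torch.nn.init.trunc_normal_, so the timm import is no longer needed, and the redundant isinstance check on the bias branch is dropped. A minimal standalone sketch of the post-commit behavior (the surrounding ABlock class and its self.apply(...) call are assumed, not shown in this hunk):

import torch.nn as nn

# Sketch of the initializer after this change: PyTorch's built-in
# truncated-normal init replaces timm's trunc_normal_, and the bias
# branch only needs the m.bias is not None check.
def _init_weights(m):
    """Initialize weights using a truncated normal distribution."""
    if isinstance(m, nn.Conv2d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Hypothetical usage: in ABlock this would be applied to every submodule
# via self.apply(self._init_weights); here a single conv is initialized.
layer = nn.Conv2d(16, 32, kernel_size=3)
_init_weights(layer)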