田运杰 10 месяцев назад
Родитель
Commit
a53c69e04a
1 измененный файл: 2 добавлено и 4 удалено
  1. 2 4
      ultralytics/nn/modules/block.py

+ 2 - 4
ultralytics/nn/modules/block.py

@@ -1173,8 +1173,6 @@ except Exception:
     from torch.nn.functional import scaled_dot_product_attention as sdpa
     logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
 
-from timm.models.layers import trunc_normal_
-
 class AAttn(nn.Module):
     """
     Area-attention module with the requirement of flash attention.
@@ -1301,8 +1299,8 @@ class ABlock(nn.Module):
     def _init_weights(self, m):
         """Initialize weights using a truncated normal distribution."""
         if isinstance(m, nn.Conv2d):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Conv2d) and m.bias is not None:
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
 
     def forward(self, x):