@@ -1161,6 +1161,7 @@ try:
     from flash_attn.flash_attn_interface import flash_attn_func
 except Exception:
     assert False, "import FlashAttention error! Please install FlashAttention first."
+from timm.models.layers import trunc_normal_
 
 class AAttn(nn.Module):
     """
@@ -1280,6 +1281,14 @@ class ABlock(nn.Module):
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
 
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2d):
+            trunc_normal_(m.weight, std=.02)
+        if isinstance(m, nn.Conv2d) and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
     def forward(self, x):
         """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
         x = x + self.attn(x)
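The second hunk wires truncated-normal weight initialization into ABlock by registering `_init_weights` through `nn.Module.apply`, which recursively visits every submodule. Below is a minimal sketch of how that hook behaves; `ToyBlock` and its `proj` layer are hypothetical stand-ins for illustration only, not part of the patch.

```python
import torch.nn as nn
from timm.models.layers import trunc_normal_


class ToyBlock(nn.Module):
    """Hypothetical stand-in for ABlock, used only to illustrate the init hook."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(64, 64, 1)
        # nn.Module.apply() walks self and every submodule and calls the hook on
        # each one, so every nn.Conv2d created above gets re-initialized here.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)   # truncated-normal weights, std 0.02
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)    # zero the biases


block = ToyBlock()
print(float(block.proj.weight.std()))  # roughly 0.02 after initialization
```

The nested `if m.bias is not None:` used in the sketch is equivalent to the patch's flat `isinstance(m, nn.Conv2d) and m.bias is not None` check, since the bias branch only matters for Conv2d modules either way.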