|
|
@@ -1157,10 +1157,10 @@ class TorchVision(nn.Module):
|
|
|
y = self.m(x)
|
|
|
return y
|
|
|
|
|
|
-
|
|
|
-from flash_attn.flash_attn_interface import flash_attn_func
|
|
|
-from timm.models.layers import drop_path, trunc_normal_
|
|
|
-
|
|
|
+try:
|
|
|
+ from flash_attn.flash_attn_interface import flash_attn_func
|
|
|
+except Exception:
|
|
|
+ assert False, "import FlashAttention Error! Please install FlashAttention first."
|
|
|
|
|
|
class AAttn(nn.Module):
|
|
|
"""
|
|
|
@@ -1190,7 +1190,6 @@ class AAttn(nn.Module):
|
|
|
def __init__(self, dim, num_heads, area=1):
|
|
|
"""Initializes the area-attention module, a simple yet efficient attention module for YOLO."""
|
|
|
super().__init__()
|
|
|
-
|
|
|
self.area = area
|
|
|
|
|
|
self.num_heads = num_heads
|
|
|
@@ -1199,7 +1198,7 @@ class AAttn(nn.Module):
|
|
|
|
|
|
self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
|
|
|
self.proj = Conv(all_head_dim, dim, 1, act=False)
|
|
|
- self.pe = Conv(all_head_dim, dim, 9, 1, 4, g=dim, act=False)
|
|
|
+ self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
@@ -1285,20 +1284,10 @@ class ABlock(nn.Module):
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
|
|
|
|
|
|
- self.apply(self._init_weights)
|
|
|
-
|
|
|
- def _init_weights(self, m):
|
|
|
- """Initialize the layers."""
|
|
|
- if isinstance(m, nn.Conv2d):
|
|
|
- trunc_normal_(m.weight, std=.02)
|
|
|
- if isinstance(m, nn.Conv2d) and m.bias is not None:
|
|
|
- nn.init.constant_(m.bias, 0)
|
|
|
-
|
|
|
def forward(self, x):
|
|
|
"""Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
|
|
|
x = x + self.attn(x)
|
|
|
x = x + self.mlp(x)
|
|
|
-
|
|
|
return x
|
|
|
|
|
|
|
|
|
@@ -1334,45 +1323,28 @@ class A2C2f(nn.Module):
|
|
|
>>> print(output.shape)
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, c1, c2, n=1, a2=True, area=1, align=False, residual=False, e=0.5, mlp_ratio=1.2, g=1, shortcut=True):
|
|
|
+ def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
|
|
|
super().__init__()
|
|
|
-
|
|
|
- self.a2 = a2
|
|
|
- self.residual = residual
|
|
|
c_ = int(c2 * e) # hidden channels
|
|
|
assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32."
|
|
|
|
|
|
# num_heads = c_ // 64 if c_ // 64 >= 2 else c_ // 32
|
|
|
num_heads = c_ // 32
|
|
|
|
|
|
- if self.a2:
|
|
|
- self.cv1 = Conv(c1, c_, 1, 1)
|
|
|
- self.cv2 = Conv((1 + n) * c_, c2, 1) # optional act=FReLU(c2)
|
|
|
+ self.cv1 = Conv(c1, c_, 1, 1)
|
|
|
+ self.cv2 = Conv((1 + n) * c_, c2, 1) # optional act=FReLU(c2)
|
|
|
|
|
|
- if residual:
|
|
|
- self.align = Conv(c1, c2, 1, 1) if align else nn.Identity()
|
|
|
- init_values = 0.01 # or smaller
|
|
|
- self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True)
|
|
|
- else:
|
|
|
- self.align, self.gamma = None, None
|
|
|
- else:
|
|
|
- self.cv1 = Conv(c1, 2 * c_, 1, 1)
|
|
|
- self.cv2 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2)
|
|
|
+ init_values = 0.01 # or smaller
|
|
|
+ self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True) if a2 and residual else None
|
|
|
|
|
|
self.m = nn.ModuleList(
|
|
|
- nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else Bottleneck(c_, c_, shortcut, g) for _ in range(n)
|
|
|
+ nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else C3k(c_, c_, 2, shortcut, g) for _ in range(n)
|
|
|
)
|
|
|
|
|
|
def forward(self, x):
|
|
|
"""Forward pass through R-ELAN layer."""
|
|
|
- if self.a2:
|
|
|
- y = [self.cv1(x)]
|
|
|
- y.extend(m(y[-1]) for m in self.m)
|
|
|
- if self.residual:
|
|
|
- return self.align(x) + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
- else:
|
|
|
- return self.cv2(torch.cat(y, 1))
|
|
|
- else:
|
|
|
- y = list(self.cv1(x).chunk(2, 1))
|
|
|
- y.extend(m(y[-1]) for m in self.m)
|
|
|
- return self.cv2(torch.cat(y, 1))
|
|
|
+ y = [self.cv1(x)]
|
|
|
+ y.extend(m(y[-1]) for m in self.m)
|
|
|
+ if self.gamma is not None:
|
|
|
+ return x + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
+ return self.cv2(torch.cat(y, 1))
|