conv.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Convolution modules."""
  3. import math
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. __all__ = (
  8. "Conv",
  9. "Conv2",
  10. "LightConv",
  11. "DWConv",
  12. "DWConvTranspose2d",
  13. "ConvTranspose",
  14. "Focus",
  15. "GhostConv",
  16. "ChannelAttention",
  17. "SpatialAttention",
  18. "CBAM",
  19. "Concat",
  20. "RepConv",
  21. "Index",
  22. )
  23. def autopad(k, p=None, d=1): # kernel, padding, dilation
  24. """Pad to 'same' shape outputs."""
  25. if d > 1:
  26. k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
  27. if p is None:
  28. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  29. return p
  30. class Conv(nn.Module):
  31. """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
  32. default_act = nn.SiLU() # default activation
  33. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
  34. """Initialize Conv layer with given arguments including activation."""
  35. super().__init__()
  36. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
  37. self.bn = nn.BatchNorm2d(c2)
  38. self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
  39. def forward(self, x):
  40. """Apply convolution, batch normalization and activation to input tensor."""
  41. return self.act(self.bn(self.conv(x)))
  42. def forward_fuse(self, x):
  43. """Apply convolution and activation without batch normalization."""
  44. return self.act(self.conv(x))
  45. class Conv2(Conv):
  46. """Simplified RepConv module with Conv fusing."""
  47. def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
  48. """Initialize Conv layer with given arguments including activation."""
  49. super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
  50. self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
  51. def forward(self, x):
  52. """Apply convolution, batch normalization and activation to input tensor."""
  53. return self.act(self.bn(self.conv(x) + self.cv2(x)))
  54. def forward_fuse(self, x):
  55. """Apply fused convolution, batch normalization and activation to input tensor."""
  56. return self.act(self.bn(self.conv(x)))
  57. def fuse_convs(self):
  58. """Fuse parallel convolutions."""
  59. w = torch.zeros_like(self.conv.weight.data)
  60. i = [x // 2 for x in w.shape[2:]]
  61. w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()
  62. self.conv.weight.data += w
  63. self.__delattr__("cv2")
  64. self.forward = self.forward_fuse
  65. class LightConv(nn.Module):
  66. """
  67. Light convolution with args(ch_in, ch_out, kernel).
  68. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  69. """
  70. def __init__(self, c1, c2, k=1, act=nn.ReLU()):
  71. """Initialize Conv layer with given arguments including activation."""
  72. super().__init__()
  73. self.conv1 = Conv(c1, c2, 1, act=False)
  74. self.conv2 = DWConv(c2, c2, k, act=act)
  75. def forward(self, x):
  76. """Apply 2 convolutions to input tensor."""
  77. return self.conv2(self.conv1(x))
  78. class DWConv(Conv):
  79. """Depth-wise convolution."""
  80. def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
  81. """Initialize Depth-wise convolution with given parameters."""
  82. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
  83. class DWConvTranspose2d(nn.ConvTranspose2d):
  84. """Depth-wise transpose convolution."""
  85. def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
  86. """Initialize DWConvTranspose2d class with given parameters."""
  87. super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
  88. class ConvTranspose(nn.Module):
  89. """Convolution transpose 2d layer."""
  90. default_act = nn.SiLU() # default activation
  91. def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
  92. """Initialize ConvTranspose2d layer with batch normalization and activation function."""
  93. super().__init__()
  94. self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
  95. self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
  96. self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
  97. def forward(self, x):
  98. """Applies transposed convolutions, batch normalization and activation to input."""
  99. return self.act(self.bn(self.conv_transpose(x)))
  100. def forward_fuse(self, x):
  101. """Applies activation and convolution transpose operation to input."""
  102. return self.act(self.conv_transpose(x))
  103. class Focus(nn.Module):
  104. """Focus wh information into c-space."""
  105. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
  106. """Initializes Focus object with user defined channel, convolution, padding, group and activation values."""
  107. super().__init__()
  108. self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
  109. # self.contract = Contract(gain=2)
  110. def forward(self, x):
  111. """
  112. Applies convolution to concatenated tensor and returns the output.
  113. Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2).
  114. """
  115. return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
  116. # return self.conv(self.contract(x))
  117. class GhostConv(nn.Module):
  118. """Ghost Convolution https://github.com/huawei-noah/ghostnet."""
  119. def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
  120. """Initializes Ghost Convolution module with primary and cheap operations for efficient feature learning."""
  121. super().__init__()
  122. c_ = c2 // 2 # hidden channels
  123. self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
  124. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
  125. def forward(self, x):
  126. """Forward propagation through a Ghost Bottleneck layer with skip connection."""
  127. y = self.cv1(x)
  128. return torch.cat((y, self.cv2(y)), 1)
  129. class RepConv(nn.Module):
  130. """
  131. RepConv is a basic rep-style block, including training and deploy status.
  132. This module is used in RT-DETR.
  133. Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
  134. """
  135. default_act = nn.SiLU() # default activation
  136. def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
  137. """Initializes Light Convolution layer with inputs, outputs & optional activation function."""
  138. super().__init__()
  139. assert k == 3 and p == 1
  140. self.g = g
  141. self.c1 = c1
  142. self.c2 = c2
  143. self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
  144. self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
  145. self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
  146. self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
  147. def forward_fuse(self, x):
  148. """Forward process."""
  149. return self.act(self.conv(x))
  150. def forward(self, x):
  151. """Forward process."""
  152. id_out = 0 if self.bn is None else self.bn(x)
  153. return self.act(self.conv1(x) + self.conv2(x) + id_out)
  154. def get_equivalent_kernel_bias(self):
  155. """Returns equivalent kernel and bias by adding 3x3 kernel, 1x1 kernel and identity kernel with their biases."""
  156. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
  157. kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
  158. kernelid, biasid = self._fuse_bn_tensor(self.bn)
  159. return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
  160. @staticmethod
  161. def _pad_1x1_to_3x3_tensor(kernel1x1):
  162. """Pads a 1x1 tensor to a 3x3 tensor."""
  163. if kernel1x1 is None:
  164. return 0
  165. else:
  166. return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  167. def _fuse_bn_tensor(self, branch):
  168. """Generates appropriate kernels and biases for convolution by fusing branches of the neural network."""
  169. if branch is None:
  170. return 0, 0
  171. if isinstance(branch, Conv):
  172. kernel = branch.conv.weight
  173. running_mean = branch.bn.running_mean
  174. running_var = branch.bn.running_var
  175. gamma = branch.bn.weight
  176. beta = branch.bn.bias
  177. eps = branch.bn.eps
  178. elif isinstance(branch, nn.BatchNorm2d):
  179. if not hasattr(self, "id_tensor"):
  180. input_dim = self.c1 // self.g
  181. kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
  182. for i in range(self.c1):
  183. kernel_value[i, i % input_dim, 1, 1] = 1
  184. self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
  185. kernel = self.id_tensor
  186. running_mean = branch.running_mean
  187. running_var = branch.running_var
  188. gamma = branch.weight
  189. beta = branch.bias
  190. eps = branch.eps
  191. std = (running_var + eps).sqrt()
  192. t = (gamma / std).reshape(-1, 1, 1, 1)
  193. return kernel * t, beta - running_mean * gamma / std
  194. def fuse_convs(self):
  195. """Combines two convolution layers into a single layer and removes unused attributes from the class."""
  196. if hasattr(self, "conv"):
  197. return
  198. kernel, bias = self.get_equivalent_kernel_bias()
  199. self.conv = nn.Conv2d(
  200. in_channels=self.conv1.conv.in_channels,
  201. out_channels=self.conv1.conv.out_channels,
  202. kernel_size=self.conv1.conv.kernel_size,
  203. stride=self.conv1.conv.stride,
  204. padding=self.conv1.conv.padding,
  205. dilation=self.conv1.conv.dilation,
  206. groups=self.conv1.conv.groups,
  207. bias=True,
  208. ).requires_grad_(False)
  209. self.conv.weight.data = kernel
  210. self.conv.bias.data = bias
  211. for para in self.parameters():
  212. para.detach_()
  213. self.__delattr__("conv1")
  214. self.__delattr__("conv2")
  215. if hasattr(self, "nm"):
  216. self.__delattr__("nm")
  217. if hasattr(self, "bn"):
  218. self.__delattr__("bn")
  219. if hasattr(self, "id_tensor"):
  220. self.__delattr__("id_tensor")
  221. class ChannelAttention(nn.Module):
  222. """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
  223. def __init__(self, channels: int) -> None:
  224. """Initializes the class and sets the basic configurations and instance variables required."""
  225. super().__init__()
  226. self.pool = nn.AdaptiveAvgPool2d(1)
  227. self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
  228. self.act = nn.Sigmoid()
  229. def forward(self, x: torch.Tensor) -> torch.Tensor:
  230. """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
  231. return x * self.act(self.fc(self.pool(x)))
  232. class SpatialAttention(nn.Module):
  233. """Spatial-attention module."""
  234. def __init__(self, kernel_size=7):
  235. """Initialize Spatial-attention module with kernel size argument."""
  236. super().__init__()
  237. assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
  238. padding = 3 if kernel_size == 7 else 1
  239. self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
  240. self.act = nn.Sigmoid()
  241. def forward(self, x):
  242. """Apply channel and spatial attention on input for feature recalibration."""
  243. return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
  244. class CBAM(nn.Module):
  245. """Convolutional Block Attention Module."""
  246. def __init__(self, c1, kernel_size=7):
  247. """Initialize CBAM with given input channel (c1) and kernel size."""
  248. super().__init__()
  249. self.channel_attention = ChannelAttention(c1)
  250. self.spatial_attention = SpatialAttention(kernel_size)
  251. def forward(self, x):
  252. """Applies the forward pass through C1 module."""
  253. return self.spatial_attention(self.channel_attention(x))
  254. class Concat(nn.Module):
  255. """Concatenate a list of tensors along dimension."""
  256. def __init__(self, dimension=1):
  257. """Concatenates a list of tensors along a specified dimension."""
  258. super().__init__()
  259. self.d = dimension
  260. def forward(self, x):
  261. """Forward pass for the YOLOv8 mask Proto module."""
  262. return torch.cat(x, self.d)
  263. class Index(nn.Module):
  264. """Returns a particular index of the input."""
  265. def __init__(self, c1, c2, index=0):
  266. """Returns a particular index of the input."""
  267. super().__init__()
  268. self.index = index
  269. def forward(self, x):
  270. """
  271. Forward pass.
  272. Expects a list of tensors as input.
  273. """
  274. return x[self.index]