block.py 51 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Block modules."""
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from ultralytics.utils.torch_utils import fuse_conv_and_bn
  7. from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
  8. from .transformer import TransformerBlock
  9. __all__ = (
  10. "DFL",
  11. "HGBlock",
  12. "HGStem",
  13. "SPP",
  14. "SPPF",
  15. "C1",
  16. "C2",
  17. "C3",
  18. "C2f",
  19. "C2fAttn",
  20. "ImagePoolingAttn",
  21. "ContrastiveHead",
  22. "BNContrastiveHead",
  23. "C3x",
  24. "C3TR",
  25. "C3Ghost",
  26. "GhostBottleneck",
  27. "Bottleneck",
  28. "BottleneckCSP",
  29. "Proto",
  30. "RepC3",
  31. "ResNetLayer",
  32. "RepNCSPELAN4",
  33. "ELAN1",
  34. "ADown",
  35. "AConv",
  36. "SPPELAN",
  37. "CBFuse",
  38. "CBLinear",
  39. "C3k2",
  40. "C2fPSA",
  41. "C2PSA",
  42. "RepVGGDW",
  43. "CIB",
  44. "C2fCIB",
  45. "Attention",
  46. "PSA",
  47. "SCDown",
  48. "TorchVision",
  49. )
  50. class DFL(nn.Module):
  51. """
  52. Integral module of Distribution Focal Loss (DFL).
  53. Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
  54. """
  55. def __init__(self, c1=16):
  56. """Initialize a convolutional layer with a given number of input channels."""
  57. super().__init__()
  58. self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
  59. x = torch.arange(c1, dtype=torch.float)
  60. self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
  61. self.c1 = c1
  62. def forward(self, x):
  63. """Applies a transformer layer on input tensor 'x' and returns a tensor."""
  64. b, _, a = x.shape # batch, channels, anchors
  65. return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
  66. # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
  67. class Proto(nn.Module):
  68. """YOLOv8 mask Proto module for segmentation models."""
  69. def __init__(self, c1, c_=256, c2=32):
  70. """
  71. Initializes the YOLOv8 mask Proto module with specified number of protos and masks.
  72. Input arguments are ch_in, number of protos, number of masks.
  73. """
  74. super().__init__()
  75. self.cv1 = Conv(c1, c_, k=3)
  76. self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
  77. self.cv2 = Conv(c_, c_, k=3)
  78. self.cv3 = Conv(c_, c2)
  79. def forward(self, x):
  80. """Performs a forward pass through layers using an upsampled input image."""
  81. return self.cv3(self.cv2(self.upsample(self.cv1(x))))
  82. class HGStem(nn.Module):
  83. """
  84. StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
  85. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  86. """
  87. def __init__(self, c1, cm, c2):
  88. """Initialize the SPP layer with input/output channels and specified kernel sizes for max pooling."""
  89. super().__init__()
  90. self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
  91. self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
  92. self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
  93. self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
  94. self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
  95. self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
  96. def forward(self, x):
  97. """Forward pass of a PPHGNetV2 backbone layer."""
  98. x = self.stem1(x)
  99. x = F.pad(x, [0, 1, 0, 1])
  100. x2 = self.stem2a(x)
  101. x2 = F.pad(x2, [0, 1, 0, 1])
  102. x2 = self.stem2b(x2)
  103. x1 = self.pool(x)
  104. x = torch.cat([x1, x2], dim=1)
  105. x = self.stem3(x)
  106. x = self.stem4(x)
  107. return x
  108. class HGBlock(nn.Module):
  109. """
  110. HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
  111. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  112. """
  113. def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
  114. """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels."""
  115. super().__init__()
  116. block = LightConv if lightconv else Conv
  117. self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
  118. self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
  119. self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
  120. self.add = shortcut and c1 == c2
  121. def forward(self, x):
  122. """Forward pass of a PPHGNetV2 backbone layer."""
  123. y = [x]
  124. y.extend(m(y[-1]) for m in self.m)
  125. y = self.ec(self.sc(torch.cat(y, 1)))
  126. return y + x if self.add else y
  127. class SPP(nn.Module):
  128. """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
  129. def __init__(self, c1, c2, k=(5, 9, 13)):
  130. """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
  131. super().__init__()
  132. c_ = c1 // 2 # hidden channels
  133. self.cv1 = Conv(c1, c_, 1, 1)
  134. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  135. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  136. def forward(self, x):
  137. """Forward pass of the SPP layer, performing spatial pyramid pooling."""
  138. x = self.cv1(x)
  139. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  140. class SPPF(nn.Module):
  141. """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
  142. def __init__(self, c1, c2, k=5):
  143. """
  144. Initializes the SPPF layer with given input/output channels and kernel size.
  145. This module is equivalent to SPP(k=(5, 9, 13)).
  146. """
  147. super().__init__()
  148. c_ = c1 // 2 # hidden channels
  149. self.cv1 = Conv(c1, c_, 1, 1)
  150. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  151. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  152. def forward(self, x):
  153. """Forward pass through Ghost Convolution block."""
  154. y = [self.cv1(x)]
  155. y.extend(self.m(y[-1]) for _ in range(3))
  156. return self.cv2(torch.cat(y, 1))
  157. class C1(nn.Module):
  158. """CSP Bottleneck with 1 convolution."""
  159. def __init__(self, c1, c2, n=1):
  160. """Initializes the CSP Bottleneck with configurations for 1 convolution with arguments ch_in, ch_out, number."""
  161. super().__init__()
  162. self.cv1 = Conv(c1, c2, 1, 1)
  163. self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
  164. def forward(self, x):
  165. """Applies cross-convolutions to input in the C3 module."""
  166. y = self.cv1(x)
  167. return self.m(y) + y
  168. class C2(nn.Module):
  169. """CSP Bottleneck with 2 convolutions."""
  170. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  171. """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection."""
  172. super().__init__()
  173. self.c = int(c2 * e) # hidden channels
  174. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  175. self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
  176. # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
  177. self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
  178. def forward(self, x):
  179. """Forward pass through the CSP bottleneck with 2 convolutions."""
  180. a, b = self.cv1(x).chunk(2, 1)
  181. return self.cv2(torch.cat((self.m(a), b), 1))
  182. class C2f(nn.Module):
  183. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  184. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  185. """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
  186. super().__init__()
  187. self.c = int(c2 * e) # hidden channels
  188. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  189. self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  190. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  191. def forward(self, x):
  192. """Forward pass through C2f layer."""
  193. y = list(self.cv1(x).chunk(2, 1))
  194. y.extend(m(y[-1]) for m in self.m)
  195. return self.cv2(torch.cat(y, 1))
  196. def forward_split(self, x):
  197. """Forward pass using split() instead of chunk()."""
  198. y = self.cv1(x).split((self.c, self.c), 1)
  199. y = [y[0], y[1]]
  200. y.extend(m(y[-1]) for m in self.m)
  201. return self.cv2(torch.cat(y, 1))
  202. class C3(nn.Module):
  203. """CSP Bottleneck with 3 convolutions."""
  204. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  205. """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values."""
  206. super().__init__()
  207. c_ = int(c2 * e) # hidden channels
  208. self.cv1 = Conv(c1, c_, 1, 1)
  209. self.cv2 = Conv(c1, c_, 1, 1)
  210. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  211. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
  212. def forward(self, x):
  213. """Forward pass through the CSP bottleneck with 2 convolutions."""
  214. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  215. class C3x(C3):
  216. """C3 module with cross-convolutions."""
  217. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  218. """Initialize C3TR instance and set default parameters."""
  219. super().__init__(c1, c2, n, shortcut, g, e)
  220. self.c_ = int(c2 * e)
  221. self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
  222. class RepC3(nn.Module):
  223. """Rep C3."""
  224. def __init__(self, c1, c2, n=3, e=1.0):
  225. """Initialize CSP Bottleneck with a single convolution using input channels, output channels, and number."""
  226. super().__init__()
  227. c_ = int(c2 * e) # hidden channels
  228. self.cv1 = Conv(c1, c_, 1, 1)
  229. self.cv2 = Conv(c1, c_, 1, 1)
  230. self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
  231. self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
  232. def forward(self, x):
  233. """Forward pass of RT-DETR neck layer."""
  234. return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
  235. class C3TR(C3):
  236. """C3 module with TransformerBlock()."""
  237. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  238. """Initialize C3Ghost module with GhostBottleneck()."""
  239. super().__init__(c1, c2, n, shortcut, g, e)
  240. c_ = int(c2 * e)
  241. self.m = TransformerBlock(c_, c_, 4, n)
  242. class C3Ghost(C3):
  243. """C3 module with GhostBottleneck()."""
  244. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  245. """Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
  246. super().__init__(c1, c2, n, shortcut, g, e)
  247. c_ = int(c2 * e) # hidden channels
  248. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  249. class GhostBottleneck(nn.Module):
  250. """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
  251. def __init__(self, c1, c2, k=3, s=1):
  252. """Initializes GhostBottleneck module with arguments ch_in, ch_out, kernel, stride."""
  253. super().__init__()
  254. c_ = c2 // 2
  255. self.conv = nn.Sequential(
  256. GhostConv(c1, c_, 1, 1), # pw
  257. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  258. GhostConv(c_, c2, 1, 1, act=False), # pw-linear
  259. )
  260. self.shortcut = (
  261. nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  262. )
  263. def forward(self, x):
  264. """Applies skip connection and concatenation to input tensor."""
  265. return self.conv(x) + self.shortcut(x)
  266. class Bottleneck(nn.Module):
  267. """Standard bottleneck."""
  268. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
  269. """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
  270. super().__init__()
  271. c_ = int(c2 * e) # hidden channels
  272. self.cv1 = Conv(c1, c_, k[0], 1)
  273. self.cv2 = Conv(c_, c2, k[1], 1, g=g)
  274. self.add = shortcut and c1 == c2
  275. def forward(self, x):
  276. """Applies the YOLO FPN to input data."""
  277. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  278. class BottleneckCSP(nn.Module):
  279. """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
  280. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  281. """Initializes the CSP Bottleneck given arguments for ch_in, ch_out, number, shortcut, groups, expansion."""
  282. super().__init__()
  283. c_ = int(c2 * e) # hidden channels
  284. self.cv1 = Conv(c1, c_, 1, 1)
  285. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  286. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  287. self.cv4 = Conv(2 * c_, c2, 1, 1)
  288. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  289. self.act = nn.SiLU()
  290. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  291. def forward(self, x):
  292. """Applies a CSP bottleneck with 3 convolutions."""
  293. y1 = self.cv3(self.m(self.cv1(x)))
  294. y2 = self.cv2(x)
  295. return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
  296. class ResNetBlock(nn.Module):
  297. """ResNet block with standard convolution layers."""
  298. def __init__(self, c1, c2, s=1, e=4):
  299. """Initialize convolution with given parameters."""
  300. super().__init__()
  301. c3 = e * c2
  302. self.cv1 = Conv(c1, c2, k=1, s=1, act=True)
  303. self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True)
  304. self.cv3 = Conv(c2, c3, k=1, act=False)
  305. self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity()
  306. def forward(self, x):
  307. """Forward pass through the ResNet block."""
  308. return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x))
  309. class ResNetLayer(nn.Module):
  310. """ResNet layer with multiple ResNet blocks."""
  311. def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4):
  312. """Initializes the ResNetLayer given arguments."""
  313. super().__init__()
  314. self.is_first = is_first
  315. if self.is_first:
  316. self.layer = nn.Sequential(
  317. Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  318. )
  319. else:
  320. blocks = [ResNetBlock(c1, c2, s, e=e)]
  321. blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])
  322. self.layer = nn.Sequential(*blocks)
  323. def forward(self, x):
  324. """Forward pass through the ResNet layer."""
  325. return self.layer(x)
  326. class MaxSigmoidAttnBlock(nn.Module):
  327. """Max Sigmoid attention block."""
  328. def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
  329. """Initializes MaxSigmoidAttnBlock with specified arguments."""
  330. super().__init__()
  331. self.nh = nh
  332. self.hc = c2 // nh
  333. self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
  334. self.gl = nn.Linear(gc, ec)
  335. self.bias = nn.Parameter(torch.zeros(nh))
  336. self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
  337. self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0
  338. def forward(self, x, guide):
  339. """Forward process."""
  340. bs, _, h, w = x.shape
  341. guide = self.gl(guide)
  342. guide = guide.view(bs, -1, self.nh, self.hc)
  343. embed = self.ec(x) if self.ec is not None else x
  344. embed = embed.view(bs, self.nh, self.hc, h, w)
  345. aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide)
  346. aw = aw.max(dim=-1)[0]
  347. aw = aw / (self.hc**0.5)
  348. aw = aw + self.bias[None, :, None, None]
  349. aw = aw.sigmoid() * self.scale
  350. x = self.proj_conv(x)
  351. x = x.view(bs, self.nh, -1, h, w)
  352. x = x * aw.unsqueeze(2)
  353. return x.view(bs, -1, h, w)
  354. class C2fAttn(nn.Module):
  355. """C2f module with an additional attn module."""
  356. def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
  357. """Initializes C2f module with attention mechanism for enhanced feature extraction and processing."""
  358. super().__init__()
  359. self.c = int(c2 * e) # hidden channels
  360. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  361. self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  362. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  363. self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)
  364. def forward(self, x, guide):
  365. """Forward pass through C2f layer."""
  366. y = list(self.cv1(x).chunk(2, 1))
  367. y.extend(m(y[-1]) for m in self.m)
  368. y.append(self.attn(y[-1], guide))
  369. return self.cv2(torch.cat(y, 1))
  370. def forward_split(self, x, guide):
  371. """Forward pass using split() instead of chunk()."""
  372. y = list(self.cv1(x).split((self.c, self.c), 1))
  373. y.extend(m(y[-1]) for m in self.m)
  374. y.append(self.attn(y[-1], guide))
  375. return self.cv2(torch.cat(y, 1))
  376. class ImagePoolingAttn(nn.Module):
  377. """ImagePoolingAttn: Enhance the text embeddings with image-aware information."""
  378. def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False):
  379. """Initializes ImagePoolingAttn with specified arguments."""
  380. super().__init__()
  381. nf = len(ch)
  382. self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))
  383. self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
  384. self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
  385. self.proj = nn.Linear(ec, ct)
  386. self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0
  387. self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])
  388. self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])
  389. self.ec = ec
  390. self.nh = nh
  391. self.nf = nf
  392. self.hc = ec // nh
  393. self.k = k
  394. def forward(self, x, text):
  395. """Executes attention mechanism on input tensor x and guide tensor."""
  396. bs = x[0].shape[0]
  397. assert len(x) == self.nf
  398. num_patches = self.k**2
  399. x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)]
  400. x = torch.cat(x, dim=-1).transpose(1, 2)
  401. q = self.query(text)
  402. k = self.key(x)
  403. v = self.value(x)
  404. # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1)
  405. q = q.reshape(bs, -1, self.nh, self.hc)
  406. k = k.reshape(bs, -1, self.nh, self.hc)
  407. v = v.reshape(bs, -1, self.nh, self.hc)
  408. aw = torch.einsum("bnmc,bkmc->bmnk", q, k)
  409. aw = aw / (self.hc**0.5)
  410. aw = F.softmax(aw, dim=-1)
  411. x = torch.einsum("bmnk,bkmc->bnmc", aw, v)
  412. x = self.proj(x.reshape(bs, -1, self.ec))
  413. return x * self.scale + text
  414. class ContrastiveHead(nn.Module):
  415. """Implements contrastive learning head for region-text similarity in vision-language models."""
  416. def __init__(self):
  417. """Initializes ContrastiveHead with specified region-text similarity parameters."""
  418. super().__init__()
  419. # NOTE: use -10.0 to keep the init cls loss consistency with other losses
  420. self.bias = nn.Parameter(torch.tensor([-10.0]))
  421. self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
  422. def forward(self, x, w):
  423. """Forward function of contrastive learning."""
  424. x = F.normalize(x, dim=1, p=2)
  425. w = F.normalize(w, dim=-1, p=2)
  426. x = torch.einsum("bchw,bkc->bkhw", x, w)
  427. return x * self.logit_scale.exp() + self.bias
  428. class BNContrastiveHead(nn.Module):
  429. """
  430. Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization.
  431. Args:
  432. embed_dims (int): Embed dimensions of text and image features.
  433. """
  434. def __init__(self, embed_dims: int):
  435. """Initialize ContrastiveHead with region-text similarity parameters."""
  436. super().__init__()
  437. self.norm = nn.BatchNorm2d(embed_dims)
  438. # NOTE: use -10.0 to keep the init cls loss consistency with other losses
  439. self.bias = nn.Parameter(torch.tensor([-10.0]))
  440. # use -1.0 is more stable
  441. self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
  442. def forward(self, x, w):
  443. """Forward function of contrastive learning."""
  444. x = self.norm(x)
  445. w = F.normalize(w, dim=-1, p=2)
  446. x = torch.einsum("bchw,bkc->bkhw", x, w)
  447. return x * self.logit_scale.exp() + self.bias
  448. class RepBottleneck(Bottleneck):
  449. """Rep bottleneck."""
  450. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
  451. """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion."""
  452. super().__init__(c1, c2, shortcut, g, k, e)
  453. c_ = int(c2 * e) # hidden channels
  454. self.cv1 = RepConv(c1, c_, k[0], 1)
  455. class RepCSP(C3):
  456. """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
  457. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  458. """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
  459. super().__init__(c1, c2, n, shortcut, g, e)
  460. c_ = int(c2 * e) # hidden channels
  461. self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  462. class RepNCSPELAN4(nn.Module):
  463. """CSP-ELAN."""
  464. def __init__(self, c1, c2, c3, c4, n=1):
  465. """Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions."""
  466. super().__init__()
  467. self.c = c3 // 2
  468. self.cv1 = Conv(c1, c3, 1, 1)
  469. self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1))
  470. self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))
  471. self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
  472. def forward(self, x):
  473. """Forward pass through RepNCSPELAN4 layer."""
  474. y = list(self.cv1(x).chunk(2, 1))
  475. y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
  476. return self.cv4(torch.cat(y, 1))
  477. def forward_split(self, x):
  478. """Forward pass using split() instead of chunk()."""
  479. y = list(self.cv1(x).split((self.c, self.c), 1))
  480. y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
  481. return self.cv4(torch.cat(y, 1))
  482. class ELAN1(RepNCSPELAN4):
  483. """ELAN1 module with 4 convolutions."""
  484. def __init__(self, c1, c2, c3, c4):
  485. """Initializes ELAN1 layer with specified channel sizes."""
  486. super().__init__(c1, c2, c3, c4)
  487. self.c = c3 // 2
  488. self.cv1 = Conv(c1, c3, 1, 1)
  489. self.cv2 = Conv(c3 // 2, c4, 3, 1)
  490. self.cv3 = Conv(c4, c4, 3, 1)
  491. self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
  492. class AConv(nn.Module):
  493. """AConv."""
  494. def __init__(self, c1, c2):
  495. """Initializes AConv module with convolution layers."""
  496. super().__init__()
  497. self.cv1 = Conv(c1, c2, 3, 2, 1)
  498. def forward(self, x):
  499. """Forward pass through AConv layer."""
  500. x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
  501. return self.cv1(x)
  502. class ADown(nn.Module):
  503. """ADown."""
  504. def __init__(self, c1, c2):
  505. """Initializes ADown module with convolution layers to downsample input from channels c1 to c2."""
  506. super().__init__()
  507. self.c = c2 // 2
  508. self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
  509. self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
  510. def forward(self, x):
  511. """Forward pass through ADown layer."""
  512. x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
  513. x1, x2 = x.chunk(2, 1)
  514. x1 = self.cv1(x1)
  515. x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
  516. x2 = self.cv2(x2)
  517. return torch.cat((x1, x2), 1)
  518. class SPPELAN(nn.Module):
  519. """SPP-ELAN."""
  520. def __init__(self, c1, c2, c3, k=5):
  521. """Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling."""
  522. super().__init__()
  523. self.c = c3
  524. self.cv1 = Conv(c1, c3, 1, 1)
  525. self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  526. self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  527. self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  528. self.cv5 = Conv(4 * c3, c2, 1, 1)
  529. def forward(self, x):
  530. """Forward pass through SPPELAN layer."""
  531. y = [self.cv1(x)]
  532. y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
  533. return self.cv5(torch.cat(y, 1))
  534. class CBLinear(nn.Module):
  535. """CBLinear."""
  536. def __init__(self, c1, c2s, k=1, s=1, p=None, g=1):
  537. """Initializes the CBLinear module, passing inputs unchanged."""
  538. super().__init__()
  539. self.c2s = c2s
  540. self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
  541. def forward(self, x):
  542. """Forward pass through CBLinear layer."""
  543. return self.conv(x).split(self.c2s, dim=1)
  544. class CBFuse(nn.Module):
  545. """CBFuse."""
  546. def __init__(self, idx):
  547. """Initializes CBFuse module with layer index for selective feature fusion."""
  548. super().__init__()
  549. self.idx = idx
  550. def forward(self, xs):
  551. """Forward pass through CBFuse layer."""
  552. target_size = xs[-1].shape[2:]
  553. res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
  554. return torch.sum(torch.stack(res + xs[-1:]), dim=0)
  555. class C3f(nn.Module):
  556. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  557. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  558. """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
  559. expansion.
  560. """
  561. super().__init__()
  562. c_ = int(c2 * e) # hidden channels
  563. self.cv1 = Conv(c1, c_, 1, 1)
  564. self.cv2 = Conv(c1, c_, 1, 1)
  565. self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2)
  566. self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  567. def forward(self, x):
  568. """Forward pass through C2f layer."""
  569. y = [self.cv2(x), self.cv1(x)]
  570. y.extend(m(y[-1]) for m in self.m)
  571. return self.cv3(torch.cat(y, 1))
  572. class C3k2(C2f):
  573. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  574. def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
  575. """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks."""
  576. super().__init__(c1, c2, n, shortcut, g, e)
  577. self.m = nn.ModuleList(
  578. C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
  579. )
  580. class C3k(C3):
  581. """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
  582. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
  583. """Initializes the C3k module with specified channels, number of layers, and configurations."""
  584. super().__init__(c1, c2, n, shortcut, g, e)
  585. c_ = int(c2 * e) # hidden channels
  586. # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  587. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  588. class RepVGGDW(torch.nn.Module):
  589. """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
  590. def __init__(self, ed) -> None:
  591. """Initializes RepVGGDW with depthwise separable convolutional layers for efficient processing."""
  592. super().__init__()
  593. self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
  594. self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
  595. self.dim = ed
  596. self.act = nn.SiLU()
  597. def forward(self, x):
  598. """
  599. Performs a forward pass of the RepVGGDW block.
  600. Args:
  601. x (torch.Tensor): Input tensor.
  602. Returns:
  603. (torch.Tensor): Output tensor after applying the depth wise separable convolution.
  604. """
  605. return self.act(self.conv(x) + self.conv1(x))
  606. def forward_fuse(self, x):
  607. """
  608. Performs a forward pass of the RepVGGDW block without fusing the convolutions.
  609. Args:
  610. x (torch.Tensor): Input tensor.
  611. Returns:
  612. (torch.Tensor): Output tensor after applying the depth wise separable convolution.
  613. """
  614. return self.act(self.conv(x))
  615. @torch.no_grad()
  616. def fuse(self):
  617. """
  618. Fuses the convolutional layers in the RepVGGDW block.
  619. This method fuses the convolutional layers and updates the weights and biases accordingly.
  620. """
  621. conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
  622. conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
  623. conv_w = conv.weight
  624. conv_b = conv.bias
  625. conv1_w = conv1.weight
  626. conv1_b = conv1.bias
  627. conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
  628. final_conv_w = conv_w + conv1_w
  629. final_conv_b = conv_b + conv1_b
  630. conv.weight.data.copy_(final_conv_w)
  631. conv.bias.data.copy_(final_conv_b)
  632. self.conv = conv
  633. del self.conv1
  634. class CIB(nn.Module):
  635. """
  636. Conditional Identity Block (CIB) module.
  637. Args:
  638. c1 (int): Number of input channels.
  639. c2 (int): Number of output channels.
  640. shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
  641. e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
  642. lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
  643. """
  644. def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
  645. """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
  646. super().__init__()
  647. c_ = int(c2 * e) # hidden channels
  648. self.cv1 = nn.Sequential(
  649. Conv(c1, c1, 3, g=c1),
  650. Conv(c1, 2 * c_, 1),
  651. RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_),
  652. Conv(2 * c_, c2, 1),
  653. Conv(c2, c2, 3, g=c2),
  654. )
  655. self.add = shortcut and c1 == c2
  656. def forward(self, x):
  657. """
  658. Forward pass of the CIB module.
  659. Args:
  660. x (torch.Tensor): Input tensor.
  661. Returns:
  662. (torch.Tensor): Output tensor.
  663. """
  664. return x + self.cv1(x) if self.add else self.cv1(x)
  665. class C2fCIB(C2f):
  666. """
  667. C2fCIB class represents a convolutional block with C2f and CIB modules.
  668. Args:
  669. c1 (int): Number of input channels.
  670. c2 (int): Number of output channels.
  671. n (int, optional): Number of CIB modules to stack. Defaults to 1.
  672. shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
  673. lk (bool, optional): Whether to use local key connection. Defaults to False.
  674. g (int, optional): Number of groups for grouped convolution. Defaults to 1.
  675. e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
  676. """
  677. def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
  678. """Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
  679. super().__init__(c1, c2, n, shortcut, g, e)
  680. self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
  681. class Attention(nn.Module):
  682. """
  683. Attention module that performs self-attention on the input tensor.
  684. Args:
  685. dim (int): The input tensor dimension.
  686. num_heads (int): The number of attention heads.
  687. attn_ratio (float): The ratio of the attention key dimension to the head dimension.
  688. Attributes:
  689. num_heads (int): The number of attention heads.
  690. head_dim (int): The dimension of each attention head.
  691. key_dim (int): The dimension of the attention key.
  692. scale (float): The scaling factor for the attention scores.
  693. qkv (Conv): Convolutional layer for computing the query, key, and value.
  694. proj (Conv): Convolutional layer for projecting the attended values.
  695. pe (Conv): Convolutional layer for positional encoding.
  696. """
  697. def __init__(self, dim, num_heads=8, attn_ratio=0.5):
  698. """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
  699. super().__init__()
  700. self.num_heads = num_heads
  701. self.head_dim = dim // num_heads
  702. self.key_dim = int(self.head_dim * attn_ratio)
  703. self.scale = self.key_dim**-0.5
  704. nh_kd = self.key_dim * num_heads
  705. h = dim + nh_kd * 2
  706. self.qkv = Conv(dim, h, 1, act=False)
  707. self.proj = Conv(dim, dim, 1, act=False)
  708. self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
  709. def forward(self, x):
  710. """
  711. Forward pass of the Attention module.
  712. Args:
  713. x (torch.Tensor): The input tensor.
  714. Returns:
  715. (torch.Tensor): The output tensor after self-attention.
  716. """
  717. B, C, H, W = x.shape
  718. N = H * W
  719. qkv = self.qkv(x)
  720. q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
  721. [self.key_dim, self.key_dim, self.head_dim], dim=2
  722. )
  723. attn = (q.transpose(-2, -1) @ k) * self.scale
  724. attn = attn.softmax(dim=-1)
  725. x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
  726. x = self.proj(x)
  727. return x
  728. class PSABlock(nn.Module):
  729. """
  730. PSABlock class implementing a Position-Sensitive Attention block for neural networks.
  731. This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
  732. with optional shortcut connections.
  733. Attributes:
  734. attn (Attention): Multi-head attention module.
  735. ffn (nn.Sequential): Feed-forward neural network module.
  736. add (bool): Flag indicating whether to add shortcut connections.
  737. Methods:
  738. forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.
  739. Examples:
  740. Create a PSABlock and perform a forward pass
  741. >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
  742. >>> input_tensor = torch.randn(1, 128, 32, 32)
  743. >>> output_tensor = psablock(input_tensor)
  744. """
  745. def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
  746. """Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction."""
  747. super().__init__()
  748. self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
  749. self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
  750. self.add = shortcut
  751. def forward(self, x):
  752. """Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor."""
  753. x = x + self.attn(x) if self.add else self.attn(x)
  754. x = x + self.ffn(x) if self.add else self.ffn(x)
  755. return x
  756. class PSA(nn.Module):
  757. """
  758. PSA class for implementing Position-Sensitive Attention in neural networks.
  759. This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
  760. input tensors, enhancing feature extraction and processing capabilities.
  761. Attributes:
  762. c (int): Number of hidden channels after applying the initial convolution.
  763. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  764. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  765. attn (Attention): Attention module for position-sensitive attention.
  766. ffn (nn.Sequential): Feed-forward network for further processing.
  767. Methods:
  768. forward: Applies position-sensitive attention and feed-forward network to the input tensor.
  769. Examples:
  770. Create a PSA module and apply it to an input tensor
  771. >>> psa = PSA(c1=128, c2=128, e=0.5)
  772. >>> input_tensor = torch.randn(1, 128, 64, 64)
  773. >>> output_tensor = psa.forward(input_tensor)
  774. """
  775. def __init__(self, c1, c2, e=0.5):
  776. """Initializes the PSA module with input/output channels and attention mechanism for feature extraction."""
  777. super().__init__()
  778. assert c1 == c2
  779. self.c = int(c1 * e)
  780. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  781. self.cv2 = Conv(2 * self.c, c1, 1)
  782. self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
  783. self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
  784. def forward(self, x):
  785. """Executes forward pass in PSA module, applying attention and feed-forward layers to the input tensor."""
  786. a, b = self.cv1(x).split((self.c, self.c), dim=1)
  787. b = b + self.attn(b)
  788. b = b + self.ffn(b)
  789. return self.cv2(torch.cat((a, b), 1))
  790. class C2PSA(nn.Module):
  791. """
  792. C2PSA module with attention mechanism for enhanced feature extraction and processing.
  793. This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
  794. capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
  795. Attributes:
  796. c (int): Number of hidden channels.
  797. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  798. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  799. m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.
  800. Methods:
  801. forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.
  802. Notes:
  803. This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.
  804. Examples:
  805. >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
  806. >>> input_tensor = torch.randn(1, 256, 64, 64)
  807. >>> output_tensor = c2psa(input_tensor)
  808. """
  809. def __init__(self, c1, c2, n=1, e=0.5):
  810. """Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio."""
  811. super().__init__()
  812. assert c1 == c2
  813. self.c = int(c1 * e)
  814. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  815. self.cv2 = Conv(2 * self.c, c1, 1)
  816. self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
  817. def forward(self, x):
  818. """Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor."""
  819. a, b = self.cv1(x).split((self.c, self.c), dim=1)
  820. b = self.m(b)
  821. return self.cv2(torch.cat((a, b), 1))
  822. class C2fPSA(C2f):
  823. """
  824. C2fPSA module with enhanced feature extraction using PSA blocks.
  825. This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.
  826. Attributes:
  827. c (int): Number of hidden channels.
  828. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  829. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  830. m (nn.ModuleList): List of PSA blocks for feature extraction.
  831. Methods:
  832. forward: Performs a forward pass through the C2fPSA module.
  833. forward_split: Performs a forward pass using split() instead of chunk().
  834. Examples:
  835. >>> import torch
  836. >>> from ultralytics.models.common import C2fPSA
  837. >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)
  838. >>> x = torch.randn(1, 64, 128, 128)
  839. >>> output = model(x)
  840. >>> print(output.shape)
  841. """
  842. def __init__(self, c1, c2, n=1, e=0.5):
  843. """Initializes the C2fPSA module, a variant of C2f with PSA blocks for enhanced feature extraction."""
  844. assert c1 == c2
  845. super().__init__(c1, c2, n=n, e=e)
  846. self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))
  847. class SCDown(nn.Module):
  848. """
  849. SCDown module for downsampling with separable convolutions.
  850. This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in
  851. efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.
  852. Attributes:
  853. cv1 (Conv): Pointwise convolution layer that reduces the number of channels.
  854. cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.
  855. Methods:
  856. forward: Applies the SCDown module to the input tensor.
  857. Examples:
  858. >>> import torch
  859. >>> from ultralytics import SCDown
  860. >>> model = SCDown(c1=64, c2=128, k=3, s=2)
  861. >>> x = torch.randn(1, 64, 128, 128)
  862. >>> y = model(x)
  863. >>> print(y.shape)
  864. torch.Size([1, 128, 64, 64])
  865. """
  866. def __init__(self, c1, c2, k, s):
  867. """Initializes the SCDown module with specified input/output channels, kernel size, and stride."""
  868. super().__init__()
  869. self.cv1 = Conv(c1, c2, 1, 1)
  870. self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
  871. def forward(self, x):
  872. """Applies convolution and downsampling to the input tensor in the SCDown module."""
  873. return self.cv2(self.cv1(x))
  874. class TorchVision(nn.Module):
  875. """
  876. TorchVision module to allow loading any torchvision model.
  877. This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.
  878. Attributes:
  879. m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
  880. Args:
  881. c1 (int): Input channels.
  882. c2 (): Output channels.
  883. model (str): Name of the torchvision model to load.
  884. weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
  885. unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
  886. truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
  887. split (bool, optional): Returns output from intermediate child modules as list. Default is False.
  888. """
  889. def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
  890. """Load the model and weights from torchvision."""
  891. import torchvision # scope for faster 'import ultralytics'
  892. super().__init__()
  893. if hasattr(torchvision.models, "get_model"):
  894. self.m = torchvision.models.get_model(model, weights=weights)
  895. else:
  896. self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
  897. if unwrap:
  898. layers = list(self.m.children())[:-truncate]
  899. if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin
  900. layers = [*list(layers[0].children()), *layers[1:]]
  901. self.m = nn.Sequential(*layers)
  902. self.split = split
  903. else:
  904. self.split = False
  905. self.m.head = self.m.heads = nn.Identity()
  906. def forward(self, x):
  907. """Forward pass through the model."""
  908. if self.split:
  909. y = [x]
  910. y.extend(m(y[-1]) for m in self.m)
  911. else:
  912. y = self.m(x)
  913. return y
  914. import logging
  915. logger = logging.getLogger(__name__)
  916. USE_FLASH_ATTN = False
  917. try:
  918. import torch
  919. if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: # Ampere or newer
  920. from flash_attn.flash_attn_interface import flash_attn_func
  921. USE_FLASH_ATTN = True
  922. else:
  923. from torch.nn.functional import scaled_dot_product_attention as sdpa
  924. logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
  925. except Exception:
  926. from torch.nn.functional import scaled_dot_product_attention as sdpa
  927. logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
  928. class AAttn(nn.Module):
  929. """
  930. Area-attention module with the requirement of flash attention.
  931. Attributes:
  932. dim (int): Number of hidden channels;
  933. num_heads (int): Number of heads into which the attention mechanism is divided;
  934. area (int, optional): Number of areas the feature map is divided. Defaults to 1.
  935. Methods:
  936. forward: Performs a forward process of input tensor and outputs a tensor after the execution of the area attention mechanism.
  937. Examples:
  938. >>> import torch
  939. >>> from ultralytics.nn.modules import AAttn
  940. >>> model = AAttn(dim=64, num_heads=2, area=4)
  941. >>> x = torch.randn(2, 64, 128, 128)
  942. >>> output = model(x)
  943. >>> print(output.shape)
  944. Notes:
  945. recommend that dim//num_heads be a multiple of 32 or 64.
  946. """
  947. def __init__(self, dim, num_heads, area=1):
  948. """Initializes the area-attention module, a simple yet efficient attention module for YOLO."""
  949. super().__init__()
  950. self.area = area
  951. self.num_heads = num_heads
  952. self.head_dim = head_dim = dim // num_heads
  953. all_head_dim = head_dim * self.num_heads
  954. self.qk = Conv(dim, all_head_dim * 2, 1, act=False)
  955. self.v = Conv(dim, all_head_dim, 1, act=False)
  956. self.proj = Conv(all_head_dim, dim, 1, act=False)
  957. self.pe = Conv(all_head_dim, dim, 5, 1, 2, g=dim, act=False)
  958. def forward(self, x):
  959. """Processes the input tensor 'x' through the area-attention"""
  960. B, C, H, W = x.shape
  961. N = H * W
  962. qk = self.qk(x).flatten(2).transpose(1, 2)
  963. v = self.v(x)
  964. pp = self.pe(v)
  965. v = v.flatten(2).transpose(1, 2)
  966. if self.area > 1:
  967. qk = qk.reshape(B * self.area, N // self.area, C * 2)
  968. v = v.reshape(B * self.area, N // self.area, C)
  969. B, N, _ = qk.shape
  970. q, k = qk.split([C, C], dim=2)
  971. if x.is_cuda and USE_FLASH_ATTN:
  972. q = q.view(B, N, self.num_heads, self.head_dim)
  973. k = k.view(B, N, self.num_heads, self.head_dim)
  974. v = v.view(B, N, self.num_heads, self.head_dim)
  975. x = flash_attn_func(
  976. q.contiguous().half(),
  977. k.contiguous().half(),
  978. v.contiguous().half()
  979. ).to(q.dtype)
  980. else:
  981. q = q.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
  982. k = k.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
  983. v = v.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
  984. attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
  985. max_attn = attn.max(dim=-1, keepdim=True).values
  986. exp_attn = torch.exp(attn - max_attn)
  987. attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
  988. x = (v @ attn.transpose(-2, -1))
  989. x = x.permute(0, 3, 1, 2)
  990. if self.area > 1:
  991. x = x.reshape(B // self.area, N * self.area, C)
  992. B, N, _ = x.shape
  993. x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
  994. return self.proj(x + pp)
  995. class ABlock(nn.Module):
  996. """
  997. ABlock class implementing a Area-Attention block with effective feature extraction.
  998. This class encapsulates the functionality for applying multi-head attention with feature map are dividing into areas
  999. and feed-forward neural network layers.
  1000. Attributes:
  1001. dim (int): Number of hidden channels;
  1002. num_heads (int): Number of heads into which the attention mechanism is divided;
  1003. mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
  1004. area (int, optional): Number of areas the feature map is divided. Defaults to 1.
  1005. Methods:
  1006. forward: Performs a forward pass through the ABlock, applying area-attention and feed-forward layers.
  1007. Examples:
  1008. Create a ABlock and perform a forward pass
  1009. >>> model = ABlock(dim=64, num_heads=2, mlp_ratio=1.2, area=4)
  1010. >>> x = torch.randn(2, 64, 128, 128)
  1011. >>> output = model(x)
  1012. >>> print(output.shape)
  1013. Notes:
  1014. recommend that dim//num_heads be a multiple of 32 or 64.
  1015. """
  1016. def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1):
  1017. """Initializes the ABlock with area-attention and feed-forward layers for faster feature extraction."""
  1018. super().__init__()
  1019. self.attn = AAttn(dim, num_heads=num_heads, area=area)
  1020. mlp_hidden_dim = int(dim * mlp_ratio)
  1021. self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
  1022. self.apply(self._init_weights)
  1023. def _init_weights(self, m):
  1024. """Initialize weights using a truncated normal distribution."""
  1025. if isinstance(m, nn.Conv2d):
  1026. nn.init.trunc_normal_(m.weight, std=0.02)
  1027. if m.bias is not None:
  1028. nn.init.constant_(m.bias, 0)
  1029. def forward(self, x):
  1030. """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
  1031. x = x + self.attn(x)
  1032. x = x + self.mlp(x)
  1033. return x
  1034. class A2C2f(nn.Module):
  1035. """
  1036. A2C2f module with residual enhanced feature extraction using ABlock blocks with area-attention. Also known as R-ELAN
  1037. This class extends the C2f module by incorporating ABlock blocks for fast attention mechanisms and feature extraction.
  1038. Attributes:
  1039. c1 (int): Number of input channels;
  1040. c2 (int): Number of output channels;
  1041. n (int, optional): Number of 2xABlock modules to stack. Defaults to 1;
  1042. a2 (bool, optional): Whether use area-attention. Defaults to True;
  1043. area (int, optional): Number of areas the feature map is divided. Defaults to 1;
  1044. residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
  1045. mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
  1046. e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
  1047. g (int, optional): Number of groups for grouped convolution. Defaults to 1;
  1048. shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;
  1049. Methods:
  1050. forward: Performs a forward pass through the A2C2f module.
  1051. Examples:
  1052. >>> import torch
  1053. >>> from ultralytics.nn.modules import A2C2f
  1054. >>> model = A2C2f(c1=64, c2=64, n=2, a2=True, area=4, residual=True, e=0.5)
  1055. >>> x = torch.randn(2, 64, 128, 128)
  1056. >>> output = model(x)
  1057. >>> print(output.shape)
  1058. """
  1059. def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
  1060. super().__init__()
  1061. c_ = int(c2 * e) # hidden channels
  1062. assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32."
  1063. # num_heads = c_ // 64 if c_ // 64 >= 2 else c_ // 32
  1064. num_heads = c_ // 32
  1065. self.cv1 = Conv(c1, c_, 1, 1)
  1066. self.cv2 = Conv((1 + n) * c_, c2, 1) # optional act=FReLU(c2)
  1067. init_values = 0.01 # or smaller
  1068. self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True) if a2 and residual else None
  1069. self.m = nn.ModuleList(
  1070. nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else C3k(c_, c_, 2, shortcut, g) for _ in range(n)
  1071. )
  1072. def forward(self, x):
  1073. """Forward pass through R-ELAN layer."""
  1074. y = [self.cv1(x)]
  1075. y.extend(m(y[-1]) for m in self.m)
  1076. if self.gamma is not None:
  1077. return x + self.gamma.view(1, -1, 1, 1) * self.cv2(torch.cat(y, 1))
  1078. return self.cv2(torch.cat(y, 1))