|
|
@@ -1284,6 +1284,7 @@ class ABlock(nn.Module):
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
+ """Initialize weights using a truncated normal distribution."""
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
if isinstance(m, nn.Conv2d) and m.bias is not None:
|