@@ -1350,5 +1350,5 @@ class A2C2f(nn.Module):
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in self.m)
if self.gamma is not None:
- return x + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + self.gamma.view(1, -1, 1, 1) * self.cv2(torch.cat(y, 1))
return self.cv2(torch.cat(y, 1))