|
@@ -1376,17 +1376,3 @@ class A2C2f(nn.Module):
|
|
|
y = list(self.cv1(x).chunk(2, 1))
|
|
y = list(self.cv1(x).chunk(2, 1))
|
|
|
y.extend(m(y[-1]) for m in self.m)
|
|
y.extend(m(y[-1]) for m in self.m)
|
|
|
return self.cv2(torch.cat(y, 1))
|
|
return self.cv2(torch.cat(y, 1))
|
|
|
-
|
|
|
|
|
- def forward_split(self, x):
|
|
|
|
|
- """Forward pass using split() instead of chunk()."""
|
|
|
|
|
- if self.a2:
|
|
|
|
|
- y = [self.cv1(x)]
|
|
|
|
|
- y.extend(m(y[-1]) for m in self.m)
|
|
|
|
|
- if self.residual:
|
|
|
|
|
- return self.align(x) + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
- else:
|
|
|
|
|
- return self.cv2(torch.cat(y, 1))
|
|
|
|
|
- else:
|
|
|
|
|
- y = list(self.cv1(x).chunk(2, 1))
|
|
|
|
|
- y.extend(m(y[-1]) for m in self.m)
|
|
|
|
|
- return self.cv2(torch.cat(y, 1))
|
|
|