@@ -1207,9 +1207,11 @@ class AAttn(nn.Module):
         self.head_dim = head_dim = dim // num_heads
         all_head_dim = head_dim * self.num_heads

-        self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
+        self.qk = Conv(dim, all_head_dim * 2, 1, act=False)
+        self.v = Conv(dim, all_head_dim, 1, act=False)
         self.proj = Conv(all_head_dim, dim, 1, act=False)
-        self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
+
+        self.pe = Conv(all_head_dim, dim, 5, 1, 2, g=dim, act=False)


     def forward(self, x):
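
In the hunk above, the fused qkv 1x1 projection is split into a qk projection plus a separate v projection, and the depthwise positional-encoding conv shrinks from 7x7 (pad 3) to 5x5 (pad 2). Below is a minimal stand-alone sketch of the new layer layout, with plain nn.Conv2d standing in for the repo's Conv wrapper (which also adds a BatchNorm; act=False disables its activation) and example sizes that are not part of the patch:

# Sketch only: nn.Conv2d stands in for the repo's Conv wrapper; sizes are illustrative.
import torch
import torch.nn as nn

dim, num_heads = 256, 8
head_dim = dim // num_heads
all_head_dim = head_dim * num_heads

qk = nn.Conv2d(dim, all_head_dim * 2, 1, bias=False)                # fused query/key 1x1 projection
v = nn.Conv2d(dim, all_head_dim, 1, bias=False)                     # separate value 1x1 projection
proj = nn.Conv2d(all_head_dim, dim, 1, bias=False)                  # output 1x1 projection
pe = nn.Conv2d(all_head_dim, dim, 5, 1, 2, groups=dim, bias=False)  # 5x5 depthwise positional encoding (was 7x7, pad 3)

x = torch.randn(1, dim, 32, 32)
print(qk(x).shape, v(x).shape, pe(v(x)).shape)
# torch.Size([1, 512, 32, 32]) torch.Size([1, 256, 32, 32]) torch.Size([1, 256, 32, 32])

Projecting v separately is what lets the forward pass in the next hunk compute the positional term pp = self.pe(v) directly from the value map, before any attention reshaping.
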
@@ -1217,53 +1219,57 @@ class AAttn(nn.Module):
         B, C, H, W = x.shape
         N = H * W

-        qkv = self.qkv(x).flatten(2).transpose(1, 2)
-        if self.area > 1:
-            qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
-            B, N, _ = qkv.shape
-        q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
-            [self.head_dim, self.head_dim, self.head_dim], dim=3
-        )
-
         if x.is_cuda and USE_FLASH_ATTN:
+            qk = self.qk(x).flatten(2).transpose(1, 2)
+            v = self.v(x)
+            pp = self.pe(v)
+            v = v.flatten(2).transpose(1, 2)
+
+            if self.area > 1:
+                qk = qk.reshape(B * self.area, N // self.area, C * 2)
+                v = v.reshape(B * self.area, N // self.area, C)
+                B, N, _ = qk.shape
+            q, k = qk.split([C, C], dim=2)
+            q = q.view(B, N, self.num_heads, self.head_dim)
+            k = k.view(B, N, self.num_heads, self.head_dim)
+            v = v.view(B, N, self.num_heads, self.head_dim)
+
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
-        elif x.is_cuda and not USE_FLASH_ATTN:
-            x = sdpa(
-                q.permute(0, 2, 1, 3).contiguous(),
-                k.permute(0, 2, 1, 3).contiguous(),
-                v.permute(0, 2, 1, 3).contiguous(),
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-            x = x.permute(0, 2, 1, 3)
+
+            if self.area > 1:
+                x = x.reshape(B // self.area, N * self.area, C)
+                B, N, _ = x.shape
+            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
         else:
-            q = q.permute(0, 2, 3, 1)
-            k = k.permute(0, 2, 3, 1)
-            v = v.permute(0, 2, 3, 1)
+            qk = self.qk(x).flatten(2)
+            v = self.v(x)
+            pp = self.pe(v)
+            v = v.flatten(2)
+            if self.area > 1:
+                qk = qk.reshape(B * self.area, C * 2, N // self.area)
+                v = v.reshape(B * self.area, C, N // self.area)
+                B, _, N = qk.shape
+
+            q, k = qk.split([C, C], dim=1)
+            q = q.view(B, self.num_heads, self.head_dim, N)
+            k = k.view(B, self.num_heads, self.head_dim, N)
+            v = v.view(B, self.num_heads, self.head_dim, N)
             attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
-            max_attn = attn.max(dim=-1, keepdim=True).values
+            max_attn = attn.max(dim=-1, keepdim=True).values
             exp_attn = torch.exp(attn - max_attn)
             attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
             x = (v @ attn.transpose(-2, -1))
-            x = x.permute(0, 3, 1, 2)
-            v = v.permute(0, 3, 1, 2)
-
-        if self.area > 1:
-            x = x.reshape(B // self.area, N * self.area, C)
-            v = v.reshape(B // self.area, N * self.area, C)
-            B, N, _ = x.shape
-
-        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
-        v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
-
-        x = x + self.pe(v)
-        x = self.proj(x)
-        return x
+
+            if self.area > 1:
+                x = x.reshape(B // self.area, C, N * self.area)
+                B, _, N = x.shape
+            x = x.reshape(B, C, H, W)
+
+        return self.proj(x + pp)


 class ABlock(nn.Module):
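
For reference, a rough stand-alone sketch of the non-flash (else) path in the hunk above: tensors stay channel-first, the token sequence is split into area contiguous chunks by a reshape, and the softmax is normalized manually with max-subtraction for numerical stability. Names and sizes here are illustrative, not part of the patch.

# Sketch only: mirrors the shapes used in the else branch above; q, k, v are
# plain feature maps standing in for the module's qk/v conv outputs.
import torch

def area_attention(q, k, v, num_heads, area):
    """q, k, v: (B, C, H, W) feature maps; returns (B, C, H, W)."""
    B, C, H, W = q.shape
    N = H * W
    head_dim = C // num_heads

    q, k, v = (t.flatten(2) for t in (q, k, v))                  # (B, C, N)
    if area > 1:                                                 # split tokens into area contiguous chunks
        q, k, v = (t.reshape(B * area, C, N // area) for t in (q, k, v))
        B, N = B * area, N // area

    q, k, v = (t.view(B, num_heads, head_dim, N) for t in (q, k, v))
    attn = (q.transpose(-2, -1) @ k) * head_dim ** -0.5          # (B, heads, N, N)
    attn = attn - attn.max(dim=-1, keepdim=True).values          # subtract row max before exp
    attn = attn.exp()
    attn = attn / attn.sum(dim=-1, keepdim=True)                 # softmax over keys
    x = v @ attn.transpose(-2, -1)                               # (B, heads, head_dim, N)

    if area > 1:                                                 # merge the chunks back
        x = x.reshape(B // area, C, N * area)
    return x.reshape(-1, C, H, W)

feat = torch.randn(2, 256, 16, 16)
out = area_attention(feat, feat, feat, num_heads=8, area=4)
print(out.shape)  # torch.Size([2, 256, 16, 16])

Reshaping (B, C, N) into (B * area, C, N // area) makes each query attend only within its own chunk of N // area tokens, so the attention matrix per chunk is (N // area)^2 instead of N^2.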