@@ -1205,44 +1205,40 @@ class AAttn(nn.Module):
         """Processes the input tensor 'x' through the area-attention or attention mechanism."""
         B, C, H, W = x.shape
         N = H * W
+
+        qkv = self.qkv(x).flatten(2).transpose(1, 2)
+        if self.area > 1:
+            qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
+            B, N, _ = qkv.shape
+        q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
+            [self.head_dim, self.head_dim, self.head_dim], dim=3
+        )
+
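
Reviewer note: the block above unifies the qkv computation for both branches, folding the `area` partitions into the batch dimension before the head split. A minimal standalone shape-flow sketch (the toy sizes and the `qkv_proj` stand-in for `self.qkv` are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Toy sizes, assumed for illustration only.
B, C, H, W = 2, 64, 8, 8
area, num_heads = 4, 4
head_dim = C // num_heads
N = H * W

qkv_proj = nn.Conv2d(C, C * 3, 1)  # stand-in for self.qkv
x = torch.randn(B, C, H, W)

qkv = qkv_proj(x).flatten(2).transpose(1, 2)   # (B, N, 3C)
qkv = qkv.reshape(B * area, N // area, C * 3)  # areas folded into batch
b, n, _ = qkv.shape
q, k, v = qkv.view(b, n, num_heads, head_dim * 3).split(
    [head_dim, head_dim, head_dim], dim=3
)
print(q.shape)  # torch.Size([8, 16, 4, 16]) == (B*area, N//area, heads, head_dim)
```
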
         if x.is_cuda:
-            qkv = self.qkv(x).flatten(2).transpose(1, 2)
-            if self.area > 1:
-                qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
-                B, N, _ = qkv.shape
-            q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split(
-                [self.head_dim, self.head_dim, self.head_dim], dim=3
-            )
             x = flash_attn_func(
                 q.contiguous().half(),
                 k.contiguous().half(),
                 v.contiguous().half()
             ).to(q.dtype)
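
Reviewer note: `flash_attn_func` from the `flash-attn` package expects `(batch, seqlen, nheads, headdim)` tensors in fp16/bf16, hence the `.half()` casts and the `.to(q.dtype)` restore. Where flash-attn is unavailable, a comparable result can be sketched with PyTorch's built-in SDPA, which wants `(batch, nheads, seqlen, headdim)` instead (an assumption for illustration, not part of this diff):

```python
import torch.nn.functional as F

# q, k, v: (B, N, num_heads, head_dim), as produced by the split above.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)  # back to (B, N, num_heads, head_dim), like flash_attn_func
```
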
-            if self.area > 1:
-                x = x.reshape(B // self.area, N * self.area, C)
-                v = v.reshape(B // self.area, N * self.area, C)
-                B, N, _ = x.shape
-            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
-            v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
         else:
-            qkv = self.qkv(x).flatten(2)
-            if self.area > 1:
-                qkv = qkv.reshape(B * self.area, C * 3, N // self.area)
-                B, _, N = qkv.shape
-            q, k, v = qkv.view(B, self.num_heads, self.head_dim * 3, N).split(
-                [self.head_dim, self.head_dim, self.head_dim], dim=2
-            )
-            attn = (q.transpose(-2, -1) @ k) * (self.num_heads ** -0.5)
+            q = q.permute(0, 2, 3, 1)
+            k = k.permute(0, 2, 3, 1)
+            v = v.permute(0, 2, 3, 1)
+            attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
             max_attn = attn.max(dim=-1, keepdim=True).values
             exp_attn = torch.exp(attn - max_attn)
             attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
             x = (v @ attn.transpose(-2, -1))
-            if self.area > 1:
-                x = x.reshape(B // self.area, C, N * self.area)
-                v = v.reshape(B // self.area, C, N * self.area)
-                B, _, N = x.shape
-            x = x.reshape(B, C, H, W)
-            v = v.reshape(B, C, H, W)
+            x = x.permute(0, 3, 1, 2)
+            v = v.permute(0, 3, 1, 2)
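
Reviewer note: after the permutes, `q`, `k`, `v` are `(B, heads, head_dim, N)`, so `q.transpose(-2, -1) @ k` yields an `(N, N)` score matrix per head, and the max-subtraction lines are a numerically stable softmax. Replacing `self.num_heads ** -0.5` with `self.head_dim ** -0.5` also fixes the attention scale, which should be 1/sqrt(d_k). A quick standalone equivalence check (toy shapes assumed):

```python
import torch

heads, head_dim, N = 4, 16, 64
q = torch.randn(1, heads, head_dim, N)
k = torch.randn(1, heads, head_dim, N)

attn = (q.transpose(-2, -1) @ k) * (head_dim ** -0.5)
max_attn = attn.max(dim=-1, keepdim=True).values
exp_attn = torch.exp(attn - max_attn)
manual = exp_attn / exp_attn.sum(dim=-1, keepdim=True)

# The manual max-subtracted form matches the built-in softmax.
assert torch.allclose(manual, attn.softmax(dim=-1), atol=1e-6)
```
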
+
+        if self.area > 1:
+            x = x.reshape(B // self.area, N * self.area, C)
+            v = v.reshape(B // self.area, N * self.area, C)
+            B, N, _ = x.shape
+
+        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
+        v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)

         x = x + self.pe(v)
         x = self.proj(x)
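
Reviewer note: both branches now land on the same `(B*area, N//area, heads, head_dim)` layout (the else branch via its final permutes), so the area merge and feature-map restore can be shared instead of duplicated per branch. A standalone check of that reshape arithmetic (toy sizes assumed):

```python
import torch

B, C, H, W, area, heads = 2, 64, 8, 8, 4, 4
head_dim = C // heads
n = (H * W) // area  # per-area sequence length after the earlier fold

x = torch.randn(B * area, n, heads, head_dim)  # attention output layout
x = x.reshape(B, n * area, C)                  # undo the area fold
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)  # back to feature-map layout
print(x.shape)  # torch.Size([2, 64, 8, 8]) == (B, C, H, W)
```
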