|
|
@@ -1173,8 +1173,6 @@ except Exception:
|
|
|
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
|
|
logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
|
|
|
|
|
|
-from timm.models.layers import trunc_normal_
|
|
|
-
|
|
|
class AAttn(nn.Module):
|
|
|
"""
|
|
|
Area-attention module with the requirement of flash attention.
|
|
|
@@ -1301,8 +1299,8 @@ class ABlock(nn.Module):
|
|
|
def _init_weights(self, m):
|
|
|
"""Initialize weights using a truncated normal distribution."""
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
- trunc_normal_(m.weight, std=.02)
|
|
|
- if isinstance(m, nn.Conv2d) and m.bias is not None:
|
|
|
+ nn.init.trunc_normal_(m.weight, std=0.02)
|
|
|
+ if m.bias is not None:
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
def forward(self, x):
|