田运杰 10 달 전
부모
커밋
e970dcd265
1개의 변경된 파일8개의 추가작업 그리고 1개의 파일을 삭제
  1. 8 1
      ultralytics/nn/modules/block.py

+ 8 - 1
ultralytics/nn/modules/block.py

@@ -1232,7 +1232,14 @@ class AAttn(nn.Module):
                 v.contiguous().half()
             ).to(q.dtype)
         elif x.is_cuda and not USE_FLASH_ATTN:
-            x = sdpa(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), attn_mask=None, dropout_p=0.0, is_causal=False)
+            x = sdpa(
+                q.permute(0, 2, 1, 3).contiguous(), 
+                k.permute(0, 2, 1, 3).contiguous(), 
+                v.permute(0, 2, 1, 3).contiguous(), 
+                attn_mask=None, 
+                dropout_p=0.0, 
+                is_causal=False
+            )
             x = x.permute(0, 2, 1, 3)
         else:
             q = q.permute(0, 2, 3, 1)