@@ -1160,7 +1160,7 @@ class TorchVision(nn.Module):
 try:
     from flash_attn.flash_attn_interface import flash_attn_func
 except Exception:
-    assert False, "import FlashAttention Error! Please install FlashAttention first."
+    assert False, "import FlashAttention error! Please install FlashAttention first."
 
 class AAttn(nn.Module):
     """