@@ -6,7 +6,7 @@
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
 
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
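For context on the renamed hooks: load() now filters each registered extension with is_available() (a soft probe) and then calls assert_compatible() (a hard check expected to raise with a clear message). A minimal sketch of an extension satisfying that contract, using only public PyTorch APIs; the class name and its standalone form are illustrative, not part of this PR:

import torch


class MySdpaExtension:  # hypothetical, for illustration only
    def is_available(self) -> bool:
        # Soft probe used by KernelLoader.load() to filter candidates.
        return torch.cuda.is_available()

    def assert_compatible(self) -> None:
        # Hard check at kernel-load time: fail fast with a clear error.
        if not torch.cuda.is_available():
            raise RuntimeError("MySdpaExtension requires a CUDA device")

    def load(self):
        # Hand back the actual kernel callable.
        return torch.nn.functional.scaled_dot_product_attention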
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 
 
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
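Taken together, each registry encodes a kernel preference order per mask type: the custom-mask loader, for instance, omits the Dao extension. A usage sketch, under the assumption (not shown in this hunk) that load() returns the selected attention kernel:

# NPU is preferred, then the Dao CUDA kernel, then the SDPA fallback;
# load() keeps only extensions passing is_available() and
# assert_compatible() on the current machine.
flash_attention = FlashAttentionLoader().load()

# Custom (e.g. additive float) masks go through a narrower registry
# that excludes the Dao CUDA extension.
custom_mask_attention = FlashAttentionWithCustomMaskLoader().load()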