@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
114
114
115
115
at::cuda::OptionalCUDAGuard const device_guard (device_of (a));
116
116
int32_t version_num = get_sm_version_num ();
117
- if (version_num >= 90 ) {
118
- // Hopper
117
+ // Hopper
119
118
120
- // Guard against compilation issues for sm90 kernels
119
+ // Guard against compilation issues for sm90 kernels
121
120
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
121
+ if (version_num >= 90 ) {
122
122
cutlass_scaled_mm_sm90 (c, a, b, a_scales, b_scales, bias);
123
- # else
124
- cutlass_scaled_mm_sm80 (c, a, b, a_scales, b_scales, bias);
123
+ return ;
124
+ }
125
125
#endif
126
- } else if (version_num == 89 ) {
126
+
127
+ #if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
128
+ if (version_num == 89 ) {
127
129
// Ada Lovelace
128
130
cutlass_scaled_mm_sm89 (c, a, b, a_scales, b_scales, bias);
129
- } else if (version_num >= 80 ) {
131
+ return ;
132
+ }
133
+
134
+ if (version_num >= 80 ) {
130
135
// Ampere
131
136
cutlass_scaled_mm_sm80 (c, a, b, a_scales, b_scales, bias);
132
- } else {
133
- // Turing
134
- TORCH_CHECK (version_num >= 75 );
135
- cutlass_scaled_mm_sm75 (c, a, b, a_scales, b_scales, bias);
137
+ return ;
136
138
}
139
+
140
+ // Turing
141
+ TORCH_CHECK (version_num >= 75 );
142
+ cutlass_scaled_mm_sm75 (c, a, b, a_scales, b_scales, bias);
+ // Return after dispatching so control does not fall through to the
+ // TORCH_CHECK_NOT_IMPLEMENTED below (matches cutlass_scaled_mm_azp).
+ return ;
143
+ #endif
144
+
145
+ TORCH_CHECK_NOT_IMPLEMENTED (
146
+ false ,
147
+ " No compiled cutlass_scaled_mm for a compute capability less than "
148
+ " CUDA device capability: " ,
149
+ version_num);
137
150
}
138
151
139
152
void cutlass_scaled_mm_azp (torch::Tensor& c, torch::Tensor const & a,
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
174
187
" currently bias dtype must match output dtype " , c.dtype ());
175
188
176
189
at::cuda::OptionalCUDAGuard const device_guard (device_of (a));
190
+
177
191
int32_t version_num = get_sm_version_num ();
178
- if (version_num >= 90 ) {
179
- // Hopper
180
192
181
- // Guard against compilation issues for sm90 kernels
182
193
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
194
+ if (version_num >= 90 ) {
183
195
cutlass_scaled_mm_azp_sm90 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
184
- # else
185
- cutlass_scaled_mm_azp_sm80 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
196
+ return ;
197
+ }
186
198
#endif
187
- } else if (version_num == 89 ) {
199
+
200
+ #if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
201
+ if (version_num == 89 ) {
188
202
// Ada Lovelace
189
203
cutlass_scaled_mm_azp_sm89 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
190
- } else if (version_num >= 80 ) {
204
+ return ;
205
+ }
206
+
207
+ if (version_num >= 80 ) {
191
208
// Ampere
192
209
cutlass_scaled_mm_azp_sm80 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
193
- } else {
194
- // Turing
195
- TORCH_CHECK (version_num >= 75 );
196
- cutlass_scaled_mm_azp_sm75 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
210
+ return ;
197
211
}
212
+
213
+ // Turing
214
+ TORCH_CHECK (version_num >= 75 );
215
+ cutlass_scaled_mm_azp_sm75 (c, a, b, a_scales, b_scales, azp_adj, azp, bias);
216
+ return ;
217
+ #endif
218
+
219
+ TORCH_CHECK_NOT_IMPLEMENTED (
220
+ false ,
221
+ " No compiled cutlass_scaled_mm_azp for a compute capability less than "
222
+ " CUDA device capability: " ,
223
+ version_num);
198
224
}
0 commit comments