@@ -327,14 +327,86 @@ static const std::map<int, RowwiseKernel> N_5120_K_640_dispatch_table = {
327
327
{ 5984 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
328
328
};
329
329
330
+ static const std::map<int , RowwiseKernel> N_4096_K_5120_dispatch_table = {
331
+ { 16 , fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2},
332
+ { 32 , fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2},
333
+ { 48 , fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
334
+ { 128 , fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
335
+ { 256 , fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
336
+ { 288 , fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
337
+ { 576 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
338
+ { 896 , fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
339
+ { 1152 , fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
340
+ { 1392 , fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
341
+ { 1440 , fp8_rowwise_256x160x128x128_16x16_5x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
342
+ { 1776 , fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
343
+ { 1824 , fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
344
+ { 2240 , fp8_rowwise_256x160x96x128_16x16_5x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
345
+ { 2496 , fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
346
+ { 2816 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
347
+ { 2896 , fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
348
+ { 3040 , fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
349
+ { 3072 , fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
350
+ { 3328 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
351
+ { 3648 , fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
352
+ { 4096 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
353
+ { 4256 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
354
+ { 4832 , fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4},
355
+ { 4864 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
356
+ { 5152 , fp8_rowwise_256x224x160x128_16x16_7x5_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
357
+ { 5184 , fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
358
+ { 5888 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
359
+ { 5920 , fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
360
+ { 5984 , fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
361
+ };
362
+
363
+ static const std::map<int , RowwiseKernel> N_5120_K_2048_dispatch_table = {
364
+ { 48 , fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v2},
365
+ { 96 , fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
366
+ { 192 , fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
367
+ { 224 , fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
368
+ { 384 , fp8_rowwise_256x128x64x256_32x32_2x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
369
+ { 448 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
370
+ { 560 , fp8_rowwise_256x80x128x256_16x16_5x2_16x16x1_16x16x1_1x16x1x16_8x8x1_1x2_intrawave_v3},
371
+ { 608 , fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
372
+ { 672 , fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
373
+ { 896 , fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
374
+ { 1008 , fp8_rowwise_256x128x160x128_16x16_4x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
375
+ { 1120 , fp8_rowwise_256x160x128x128_16x16_5x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
376
+ { 1408 , fp8_rowwise_256x128x96x128_16x16_4x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
377
+ { 1440 , fp8_rowwise_256x96x128x128_16x16_3x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
378
+ { 1536 , fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
379
+ { 1600 , fp8_rowwise_256x160x96x128_16x16_5x3_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v3},
380
+ { 1920 , fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
381
+ { 2112 , fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
382
+ { 2400 , fp8_rowwise_256x160x256x128_16x16_5x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
383
+ { 2464 , fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
384
+ { 2496 , fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
385
+ { 2816 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
386
+ { 2880 , fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
387
+ { 3328 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
388
+ { 3360 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
389
+ { 3840 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
390
+ { 4224 , fp8_rowwise_256x192x192x128_16x16_6x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
391
+ { 4736 , fp8_rowwise_256x128x128x128_16x16_4x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
392
+ { 4864 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
393
+ { 4928 , fp8_rowwise_256x224x192x128_16x16_7x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
394
+ { 4992 , fp8_rowwise_256x192x224x128_16x16_6x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
395
+ { 5632 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
396
+ { 5760 , fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
397
+ { 5984 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
398
+ };
399
+
330
400
static const std::unordered_map<std::tuple<int , int >, NKLookupTableType, IntTupleHash> NK_lookup_table = {
331
401
{{7168 , 8192 }, N_7168_K_8192_dispatch_table},
332
402
{{8192 , 3584 }, N_8192_K_3584_dispatch_table},
333
403
{{1024 , 5120 }, N_1024_K_5120_dispatch_table},
334
404
{{5120 , 1024 }, N_5120_K_1024_dispatch_table},
335
405
{{2048 , 5120 }, N_2048_K_5120_dispatch_table},
336
406
{{896 , 5120 }, N_896_K_5120_dispatch_table},
337
- {{5120 , 640 }, N_5120_K_640_dispatch_table}
407
+ {{5120 , 640 }, N_5120_K_640_dispatch_table},
408
+ {{4096 , 5120 }, N_4096_K_5120_dispatch_table},
409
+ {{5120 , 2048 }, N_5120_K_2048_dispatch_table}
338
410
};
339
411
340
412
RowwiseKernel rowwise_nk_lookup (int M, const NKLookupTableType& table) {
0 commit comments