Commit d59c1a8
committed
[BugFix][TIR] Fix multi-grouped multi-warp allreduce
PR #15327 and #15373 introduced multi-warp allreduce implementation.
At the time of the introduction, I tested the correctness numerically
via the workload of "taking a matrix of ones as input, computing the
summation over each row". Both PR passed this numerical tess, while
I didn't realize that this test is not complete and cannot guarantee
the correctness.
The previous implementation has bug which can be tested by turning
the input matrix from ones to random floating-point numbers. This will
expose the issues of the previous implementation.
Therefore, this PR fixes the issues, and add the numerical tests
for multi-warp allreduce into `test_allreduce_cuda.py`. By reducing
some of the redundant tests in that file, we hope this can reduce the
testing time a bit while still guarantee the correctness.
Sorry for not testing the implementation completely before.1 parent d6407be commit d59c1a8
File tree
3 files changed
+38
-28
lines changed- src/tir/transforms
- tests/python/unittest
3 files changed
+38
-28
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
79 | | - | |
| 79 | + | |
80 | 80 | | |
81 | | - | |
82 | | - | |
83 | | - | |
84 | | - | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
85 | 90 | | |
86 | 91 | | |
87 | 92 | | |
| |||
344 | 349 | | |
345 | 350 | | |
346 | 351 | | |
347 | | - | |
| 352 | + | |
| 353 | + | |
348 | 354 | | |
349 | 355 | | |
350 | | - | |
| 356 | + | |
351 | 357 | | |
352 | 358 | | |
353 | 359 | | |
354 | | - | |
355 | | - | |
| 360 | + | |
356 | 361 | | |
357 | 362 | | |
358 | 363 | | |
| |||
365 | 370 | | |
366 | 371 | | |
367 | 372 | | |
368 | | - | |
| 373 | + | |
369 | 374 | | |
370 | | - | |
| 375 | + | |
371 | 376 | | |
372 | 377 | | |
373 | 378 | | |
| |||
382 | 387 | | |
383 | 388 | | |
384 | 389 | | |
385 | | - | |
| 390 | + | |
386 | 391 | | |
387 | 392 | | |
388 | 393 | | |
| |||
400 | 405 | | |
401 | 406 | | |
402 | 407 | | |
403 | | - | |
| 408 | + | |
| 409 | + | |
404 | 410 | | |
405 | 411 | | |
406 | 412 | | |
| |||
414 | 420 | | |
415 | 421 | | |
416 | 422 | | |
417 | | - | |
418 | | - | |
419 | | - | |
| 423 | + | |
420 | 424 | | |
421 | 425 | | |
422 | 426 | | |
| |||
772 | 776 | | |
773 | 777 | | |
774 | 778 | | |
775 | | - | |
| 779 | + | |
776 | 780 | | |
777 | 781 | | |
778 | 782 | | |
| |||
Lines changed: 3 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
95 | 95 | | |
96 | 96 | | |
97 | 97 | | |
98 | | - | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
99 | 101 | | |
100 | 102 | | |
101 | 103 | | |
| |||
Lines changed: 14 additions & 10 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
387 | 387 | | |
388 | 388 | | |
389 | 389 | | |
| 390 | + | |
390 | 391 | | |
391 | 392 | | |
392 | 393 | | |
| |||
463 | 464 | | |
464 | 465 | | |
465 | 466 | | |
| 467 | + | |
466 | 468 | | |
467 | 469 | | |
468 | 470 | | |
| |||
550 | 552 | | |
551 | 553 | | |
552 | 554 | | |
| 555 | + | |
553 | 556 | | |
554 | 557 | | |
555 | 558 | | |
| |||
585 | 588 | | |
586 | 589 | | |
587 | 590 | | |
588 | | - | |
589 | | - | |
| 591 | + | |
| 592 | + | |
590 | 593 | | |
591 | 594 | | |
592 | | - | |
| 595 | + | |
593 | 596 | | |
594 | 597 | | |
595 | 598 | | |
596 | 599 | | |
597 | 600 | | |
598 | 601 | | |
599 | 602 | | |
600 | | - | |
| 603 | + | |
601 | 604 | | |
602 | 605 | | |
603 | 606 | | |
604 | | - | |
| 607 | + | |
605 | 608 | | |
606 | 609 | | |
607 | 610 | | |
| |||
636 | 639 | | |
637 | 640 | | |
638 | 641 | | |
| 642 | + | |
639 | 643 | | |
640 | 644 | | |
641 | 645 | | |
| |||
675 | 679 | | |
676 | 680 | | |
677 | 681 | | |
678 | | - | |
679 | | - | |
| 682 | + | |
| 683 | + | |
680 | 684 | | |
681 | 685 | | |
682 | | - | |
| 686 | + | |
683 | 687 | | |
684 | 688 | | |
685 | 689 | | |
| |||
691 | 695 | | |
692 | 696 | | |
693 | 697 | | |
694 | | - | |
| 698 | + | |
695 | 699 | | |
696 | 700 | | |
697 | 701 | | |
698 | | - | |
| 702 | + | |
699 | 703 | | |
700 | 704 | | |
701 | 705 | | |
| |||
0 commit comments