@@ -156,6 +156,7 @@ def execute_forward_( # noqa C901
156
156
pooling_mode : PoolingMode ,
157
157
use_cpu : bool ,
158
158
output_dtype : SparseType ,
159
+ use_experimental_tbe : bool ,
159
160
) -> None :
160
161
# NOTE: cache is not applicable to CPU version.
161
162
assume (not use_cpu or not use_cache )
@@ -324,6 +325,7 @@ def execute_forward_( # noqa C901
324
325
cache_algorithm = cache_algorithm ,
325
326
pooling_mode = pooling_mode ,
326
327
output_dtype = output_dtype ,
328
+ use_experimental_tbe = use_experimental_tbe ,
327
329
)
328
330
# NOTE: test TorchScript-compatible!
329
331
cc = torch .jit .script (cc )
@@ -412,6 +414,7 @@ def test_forward_cpu_int8(
412
414
pooling_mode ,
413
415
use_cpu ,
414
416
SparseType .FP32 ,
417
+ False , # use_experimental_tbe
415
418
)
416
419
417
420
def test_forward_cpu_fp32 (
@@ -456,6 +459,7 @@ def test_forward_cpu_fp32(
456
459
pooling_mode ,
457
460
use_cpu ,
458
461
SparseType .FP32 ,
462
+ False , # use_experimental_tbe
459
463
)
460
464
461
465
@unittest .skipIf (* gpu_unavailable )
@@ -505,11 +509,22 @@ def test_forward_gpu_no_cache_int8(
505
509
pooling_mode ,
506
510
use_cpu ,
507
511
SparseType .FP32 ,
512
+ False , # use_experimental_tbe
508
513
)
509
514
510
515
@unittest .skipIf (* gpu_unavailable )
516
+ @given (
517
+ use_experimental_tbe = st .booleans () if not TEST_WITH_ROCM else st .just (False ),
518
+ )
519
+ @settings (
520
+ verbosity = Verbosity .verbose ,
521
+ max_examples = MAX_EXAMPLES_LONG_RUNNING ,
522
+ deadline = None ,
523
+ suppress_health_check = [HealthCheck .filter_too_much , HealthCheck .data_too_large ],
524
+ )
511
525
def test_forward_gpu_no_cache_fp16 (
512
526
self ,
527
+ use_experimental_tbe : bool ,
513
528
) -> None :
514
529
weights_precision = SparseType .FP16
515
530
use_cpu = False
@@ -527,15 +542,17 @@ def test_forward_gpu_no_cache_fp16(
527
542
[
528
543
PoolingMode .SUM ,
529
544
PoolingMode .MEAN ,
530
- PoolingMode .NONE ,
531
545
]
546
+ + ([PoolingMode .NONE ] if not use_experimental_tbe else [])
532
547
)
533
548
if pooling_mode == PoolingMode .NONE :
534
549
mixed = False
535
550
mixed_B = False
536
551
else :
537
552
mixed = random .choice ([True , False ])
538
- mixed_B = random .choice ([True , False ])
553
+ mixed_B = (
554
+ random .choice ([True , False ]) if not use_experimental_tbe else False
555
+ )
539
556
if pooling_mode == PoolingMode .SUM :
540
557
weighted = random .choice ([True , False ])
541
558
else :
@@ -555,11 +572,22 @@ def test_forward_gpu_no_cache_fp16(
555
572
pooling_mode ,
556
573
use_cpu ,
557
574
SparseType .FP32 ,
575
+ use_experimental_tbe ,
558
576
)
559
577
560
578
@unittest .skipIf (* gpu_unavailable )
579
+ @given (
580
+ use_experimental_tbe = st .booleans () if not TEST_WITH_ROCM else st .just (False ),
581
+ )
582
+ @settings (
583
+ verbosity = Verbosity .verbose ,
584
+ max_examples = MAX_EXAMPLES_LONG_RUNNING ,
585
+ deadline = None ,
586
+ suppress_health_check = [HealthCheck .filter_too_much , HealthCheck .data_too_large ],
587
+ )
561
588
def test_forward_gpu_no_cache_fp32 (
562
589
self ,
590
+ use_experimental_tbe : bool ,
563
591
) -> None :
564
592
weights_precision = SparseType .FP32
565
593
use_cpu = False
@@ -577,15 +605,17 @@ def test_forward_gpu_no_cache_fp32(
577
605
[
578
606
PoolingMode .SUM ,
579
607
PoolingMode .MEAN ,
580
- PoolingMode .NONE ,
581
608
]
609
+ + ([PoolingMode .NONE ] if not use_experimental_tbe else [])
582
610
)
583
611
if pooling_mode == PoolingMode .NONE :
584
612
mixed = False
585
613
mixed_B = False
586
614
else :
587
615
mixed = random .choice ([True , False ])
588
- mixed_B = random .choice ([True , False ])
616
+ mixed_B = (
617
+ random .choice ([True , False ]) if not use_experimental_tbe else False
618
+ )
589
619
if pooling_mode == PoolingMode .SUM :
590
620
weighted = random .choice ([True , False ])
591
621
else :
@@ -605,6 +635,7 @@ def test_forward_gpu_no_cache_fp32(
605
635
pooling_mode ,
606
636
use_cpu ,
607
637
SparseType .FP32 ,
638
+ use_experimental_tbe ,
608
639
)
609
640
610
641
@unittest .skipIf (* gpu_unavailable )
@@ -668,11 +699,13 @@ def test_forward_gpu_uvm_cache_int8(
668
699
pooling_mode ,
669
700
use_cpu ,
670
701
output_dtype ,
702
+ False , # use_experimental_tbe
671
703
)
672
704
673
705
@unittest .skipIf (* gpu_unavailable )
674
706
@given (
675
707
cache_algorithm = st .sampled_from (CacheAlgorithm ),
708
+ use_experimental_tbe = st .booleans () if not TEST_WITH_ROCM else st .just (False ),
676
709
)
677
710
@settings (
678
711
verbosity = Verbosity .verbose ,
@@ -683,6 +716,7 @@ def test_forward_gpu_uvm_cache_int8(
683
716
def test_forward_gpu_uvm_cache_fp16 (
684
717
self ,
685
718
cache_algorithm : CacheAlgorithm ,
719
+ use_experimental_tbe : bool ,
686
720
) -> None :
687
721
weights_precision = SparseType .FP16
688
722
use_cpu = False
@@ -698,8 +732,8 @@ def test_forward_gpu_uvm_cache_fp16(
698
732
[
699
733
PoolingMode .SUM ,
700
734
PoolingMode .MEAN ,
701
- PoolingMode .NONE ,
702
735
]
736
+ + ([PoolingMode .NONE ] if not use_experimental_tbe else [])
703
737
)
704
738
output_dtype = random .choice (
705
739
[
@@ -731,11 +765,13 @@ def test_forward_gpu_uvm_cache_fp16(
731
765
pooling_mode ,
732
766
use_cpu ,
733
767
output_dtype ,
768
+ use_experimental_tbe ,
734
769
)
735
770
736
771
@unittest .skipIf (* gpu_unavailable )
737
772
@given (
738
773
cache_algorithm = st .sampled_from (CacheAlgorithm ),
774
+ use_experimental_tbe = st .booleans () if not TEST_WITH_ROCM else st .just (False ),
739
775
)
740
776
@settings (
741
777
verbosity = Verbosity .verbose ,
@@ -746,6 +782,7 @@ def test_forward_gpu_uvm_cache_fp16(
746
782
def test_forward_gpu_uvm_cache_fp32 (
747
783
self ,
748
784
cache_algorithm : CacheAlgorithm ,
785
+ use_experimental_tbe : bool ,
749
786
) -> None :
750
787
weights_precision = SparseType .FP32
751
788
use_cpu = False
@@ -761,8 +798,8 @@ def test_forward_gpu_uvm_cache_fp32(
761
798
[
762
799
PoolingMode .SUM ,
763
800
PoolingMode .MEAN ,
764
- PoolingMode .NONE ,
765
801
]
802
+ + ([PoolingMode .NONE ] if not use_experimental_tbe else [])
766
803
)
767
804
output_dtype = random .choice (
768
805
[
@@ -794,6 +831,7 @@ def test_forward_gpu_uvm_cache_fp32(
794
831
pooling_mode ,
795
832
use_cpu ,
796
833
output_dtype ,
834
+ use_experimental_tbe ,
797
835
)
798
836
799
837
@unittest .skipIf (* gpu_unavailable )
0 commit comments