21
21
import torch .testing ._internal .hypothesis_utils as hu
22
22
hu .assert_deadline_disabled ()
23
23
24
- from torch .testing ._internal .common_cuda import SM80OrLater
25
24
from torch .testing ._internal .common_utils import TestCase
26
25
from torch .testing ._internal .common_utils import IS_PPC , TEST_WITH_UBSAN , IS_MACOS , BUILD_WITH_CAFFE2 , IS_SANDCASTLE
27
26
from torch .testing ._internal .common_quantization import skipIfNoFBGEMM , skipIfNoQNNPACK , skipIfNoONEDNN
32
31
qengine_is_onednn ,
33
32
)
34
33
from torch .ao .quantization import PerChannelMinMaxObserver
35
- from torch .testing ._internal .common_cuda import TEST_CUDNN , TEST_CUDNN_VERSION , TEST_CUDA
34
+ from torch .testing ._internal .common_cuda import TEST_CUDNN , TEST_CUDA
36
35
from torch .testing ._internal .optests import opcheck
37
36
import torch .backends .xnnpack
38
37
@@ -906,7 +905,9 @@ def test_qadd_relu_same_qparams(self):
906
905
"""Tests the correctness of the cudnn add and add_relu op
907
906
(Similar to test_qadd_relu_different_qparams, will probably merge in the future)"""
908
907
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
909
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
908
+ @unittest .skip ("Local only - currently the test_qadd_relu_cudnn op is bulid "
909
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
910
+ "after it is built by default" )
910
911
def test_qadd_relu_cudnn (self ):
911
912
dtype = torch .qint8
912
913
add_relu = torch .ops .quantized .add_relu
@@ -939,7 +940,9 @@ def test_qadd_relu_cudnn(self):
939
940
940
941
"""Tests the correctness of the cudnn add and add_relu op for nhwc format"""
941
942
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
942
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
943
+ @unittest .skip ("Local only - currently the test_qadd_relu_cudnn_nhwc op is bulid "
944
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
945
+ "after it is built by default" )
943
946
def test_qadd_relu_cudnn_nhwc (self ):
944
947
dtype = torch .qint8
945
948
add_relu = torch .ops .quantized .add_relu
@@ -1376,7 +1379,7 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode):
1376
1379
self .assertEqual (a_ref , a_hat .dequantize (),
1377
1380
msg = "ops.quantized.max_pool1d results are off" )
1378
1381
1379
- # TODO: merge this test with test_max_pool2d
1382
+ # TODO: merge this test with test_max_pool2d when USE_EXPERIMENTAL_CUDNN_V8_API flag is enabled in CI
1380
1383
"""Tests 2D cudnn max pool operation on quantized tensors."""
1381
1384
@given (X = hu .tensor (shapes = hu .array_shapes (min_dims = 3 , max_dims = 4 ,
1382
1385
min_side = 1 , max_side = 10 ),
@@ -1391,7 +1394,9 @@ def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode):
1391
1394
padding = st .integers (0 , 2 ),
1392
1395
ceil_mode = st .booleans ())
1393
1396
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
1394
- @unittest .skipIf (TEST_CUDNN_VERSION <= 90100 , "cuDNN maxpool2d mishandles -128 before v90100" )
1397
+ @unittest .skip ("Local only - currently the qconv2d_cudnn op is bulid "
1398
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
1399
+ "after it is built by default" )
1395
1400
def test_max_pool2d_cudnn (self , X , kernel , stride , dilation , padding , ceil_mode ):
1396
1401
X , (scale , zero_point , torch_type ) = X
1397
1402
assume (kernel // 2 >= padding ) # Kernel cannot be overhanging!
@@ -4045,7 +4050,9 @@ def test_qlinear_with_input_q_dq_qweight_dq_output_fp32(
4045
4050
use_channelwise = st .sampled_from ([False ])) # channelwise currently not supported for qlinear cudnn
4046
4051
@skipIfNoFBGEMM
4047
4052
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
4048
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
4053
+ @unittest .skip ("Local only - currently the qlinear_cudnn op is bulid "
4054
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
4055
+ "after it is built by default" )
4049
4056
# TODO: check with yang regarding CUDNN flags
4050
4057
def test_qlinear_cudnn (self , batch_size , input_channels , output_channels , use_bias ,
4051
4058
use_relu , use_multi_dim_input , use_channelwise ):
@@ -5420,7 +5427,9 @@ def test_qconv2d_add_relu(self):
5420
5427
use_channelwise = st .sampled_from ([False ]))
5421
5428
@skipIfNoFBGEMM
5422
5429
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
5423
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
5430
+ @unittest .skip ("Local only - currently the qconv2d_cudnn op is bulid "
5431
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
5432
+ "after it is built by default" )
5424
5433
def test_qconv2d_cudnn (
5425
5434
self ,
5426
5435
batch_size ,
@@ -5501,7 +5510,9 @@ def test_qconv2d_cudnn(
5501
5510
use_channelwise = st .sampled_from ([False ]))
5502
5511
@skipIfNoFBGEMM
5503
5512
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
5504
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
5513
+ @unittest .skip ("Local only - currently the qconv2d_cudnn op is bulid "
5514
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
5515
+ "after it is built by default" )
5505
5516
def test_qconv2d_relu_cudnn (
5506
5517
self ,
5507
5518
batch_size ,
@@ -6234,7 +6245,9 @@ def test_qconv1d_relu(
6234
6245
use_channelwise = st .sampled_from ([False ]))
6235
6246
@skipIfNoFBGEMM
6236
6247
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
6237
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
6248
+ @unittest .skip ("Local only - currently the qconv1d_cudnn op is bulid "
6249
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
6250
+ "after it is built by default" )
6238
6251
def test_qconv1d_cudnn (
6239
6252
self ,
6240
6253
batch_size ,
@@ -6306,7 +6319,9 @@ def test_qconv1d_cudnn(
6306
6319
use_channelwise = st .sampled_from ([False ]))
6307
6320
@skipIfNoFBGEMM
6308
6321
@unittest .skipIf (not TEST_CUDNN , "cudnn is not enabled." )
6309
- @unittest .skipIf (not SM80OrLater , "requires sm80 or later." )
6322
+ @unittest .skip ("Local only - currently the qconv1d_cudnn op is bulid "
6323
+ "with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
6324
+ "after it is built by default" )
6310
6325
def test_qconv1d_relu_cudnn (
6311
6326
self ,
6312
6327
batch_size ,
0 commit comments