@@ -5983,8 +5983,191 @@ def test_cuda_module_loading_env(self):
5983
5983
self .assertEqual (val , "LAZY" )
5984
5984
5985
5985
5986
class TestCompileKernel(TestCase):
    """Tests for ``torch.cuda._compile_kernel``.

    Each test compiles CUDA C source from a Python string, launches the
    resulting kernel on CUDA tensors and compares the output against the
    equivalent PyTorch computation. The tests are skipped on ROCm builds and
    when CUDA is unavailable.
    """

    @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc")
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel(self):
        """Compile and launch simple element-wise kernels.

        Covers a float kernel, an int32 kernel, a kernel that relies on
        ``header_code`` and the error raised for source that does not compile.
        """
        # Imported once for the whole test; every section below reuses it.
        from torch.cuda import _compile_kernel

        # Simple vector addition kernel
        kernel_source = """
        __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        """

        # Compile the kernel
        add_kernel = _compile_kernel(kernel_source, "add_tensors")

        # Prepare data
        N = 1024
        a = torch.rand(N, device="cuda")
        b = torch.rand(N, device="cuda")
        c = torch.empty_like(a)

        # Ceil-divide so every element is covered even when N is not a
        # multiple of the block size; the kernels guard with `i < n`.
        threads_per_block = 256
        blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

        # Launch kernel
        add_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[a, b, c, N],
        )

        # Verify results
        expected = a + b
        self.assertEqual(c, expected)

        # Test with different tensor types
        a_int = torch.randint(0, 100, (N,), device="cuda", dtype=torch.int32)
        b_int = torch.randint(0, 100, (N,), device="cuda", dtype=torch.int32)
        c_int = torch.empty_like(a_int)

        # Integer addition kernel
        int_kernel_source = """
        __global__ void add_int_tensors(const int* a, const int* b, int* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        """

        add_int_kernel = _compile_kernel(int_kernel_source, "add_int_tensors")

        # Launch kernel
        add_int_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[a_int, b_int, c_int, N],
        )

        # Verify results
        expected_int = a_int + b_int
        torch.testing.assert_close(c_int, expected_int)

        # Test with header code: the kernel below calls scale_value(), which
        # is only defined in this header snippet.
        header_code = """
        #define SCALE_FACTOR 2.0f

        __device__ float scale_value(float val) {
            return val * SCALE_FACTOR;
        }
        """

        scale_kernel_source = """
        __global__ void scale_tensors(const float* input, float* output, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                output[i] = scale_value(input[i]);
        }
        """

        scale_kernel = _compile_kernel(
            scale_kernel_source, "scale_tensors", header_code=header_code
        )

        input_tensor = torch.rand(N, device="cuda")
        output_tensor = torch.empty_like(input_tensor)

        scale_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[input_tensor, output_tensor, N],
        )

        # Verify scaling (SCALE_FACTOR in the header is 2.0f)
        expected_scaled = input_tensor * 2.0
        torch.testing.assert_close(output_tensor, expected_scaled)

        # Test error handling with invalid kernel
        invalid_kernel_source = """
        __global__ void invalid_kernel(float* a) {
            undeclared_variable = 10; // This will cause a compilation error
        }
        """

        with self.assertRaises(RuntimeError):
            _compile_kernel(invalid_kernel_source, "invalid_kernel")

    @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc")
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel_advanced(self):
        """Compile a 2D-grid matmul kernel, with and without an explicit
        ``compute_capability``, and compare against ``torch.matmul``.
        """
        from torch.cuda import _compile_kernel

        # Test matrix multiplication
        matmul_kernel_source = """
        __global__ void matrix_multiply(const float* A, const float* B, float* C, int M, int N, int K) {
            int row = blockIdx.y * blockDim.y + threadIdx.y;
            int col = blockIdx.x * blockDim.x + threadIdx.x;

            if (row < M && col < N) {
                float sum = 0.0f;
                for (int i = 0; i < K; i++) {
                    sum += A[row * K + i] * B[i * N + col];
                }
                C[row * N + col] = sum;
            }
        }
        """

        matmul_kernel = _compile_kernel(matmul_kernel_source, "matrix_multiply")

        # Matrix dimensions: A is (M, K), B is (K, N), C is (M, N)
        M, K, N = 64, 32, 48

        # Create matrices
        A = torch.rand((M, K), device="cuda")
        B = torch.rand((K, N), device="cuda")
        C = torch.zeros((M, N), device="cuda")

        # Calculate grid and block dimensions. The kernel maps x to columns
        # (N) and y to rows (M), so the grid is sized in that order.
        block_dim = (16, 16, 1)
        grid_dim = (
            (N + block_dim[0] - 1) // block_dim[0],
            (M + block_dim[1] - 1) // block_dim[1],
            1,
        )

        # Launch kernel
        matmul_kernel(
            grid=grid_dim,
            block=block_dim,
            args=[A.contiguous(), B.contiguous(), C, M, N, K],
        )

        # Verify results
        expected = torch.matmul(A, B)
        torch.testing.assert_close(C, expected, rtol=1e-5, atol=1e-5)

        # Recompile while passing the current device's compute capability
        # explicitly. The value is major and minor concatenated with no
        # separator (e.g. "80"); stray spaces would not form a valid value.
        device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
        compute_cap = f"{device_props.major}{device_props.minor}"

        matmul_kernel_explicit = _compile_kernel(
            matmul_kernel_source, "matrix_multiply", compute_capability=compute_cap
        )

        C_explicit = torch.zeros((M, N), device="cuda")

        # Launch kernel
        matmul_kernel_explicit(
            grid=grid_dim,
            block=block_dim,
            args=[A.contiguous(), B.contiguous(), C_explicit, M, N, K],
        )

        # Verify results
        torch.testing.assert_close(C_explicit, expected, rtol=1e-5, atol=1e-5)
6166
+
6167
+
5986
6168
# Module-level registration of the test classes defined in this file.
# NOTE(review): presumably these helpers expand parametrized / per-device
# test variants -- confirm against their definitions. TestCompileKernel, as
# visible here, defines no parametrized tests, so its registration below may
# be a no-op.
# NOTE(review): the bare integers and the leading "+" below look like
# line-number / added-line residue from a diff view, not intended code.
instantiate_parametrized_tests (TestCuda )
5987
6169
instantiate_parametrized_tests (TestCudaMallocAsync )
6170
+ instantiate_parametrized_tests (TestCompileKernel )
5988
6171
instantiate_device_type_tests (TestCudaOptims , globals ())
5989
6172
5990
6173
if __name__ == "__main__" :
0 commit comments