
Commit 5b368fa

msaroufim authored and pytorchmergebot committed
Add torch.cuda._compile_kernel() (pytorch#151484)
Follow-up work on top of pytorch#149480. This is a wrapper on top of NVRTC, inspired by https://gist.github.com/malfet/2c9a25976dd7396430c38af603f791da from @malfet. Compiling toy kernels with this setup takes 0.01s, versus 90s using `load_inline()` on my local H100. This was primarily motivated by the timeouts I was seeing on the popcorn leaderboard, but it would also be useful to integrate into KernelBench. This PR is in the same spirit as pytorch#148972, which introduced a similar UX for Metal.

For now we are planning on landing this as a private function, because we expect to iterate on both the user-facing API and the internal implementation. A separate issue will be opened to discuss the path towards making this work public and to give a broader overview of the state of custom CUDA kernel authoring in PyTorch.

Future work, as a prerequisite to making this public:
* divup primitive
* support for multiple kernels
* expose _get_nvrtc_version from native code
* interop with torch.compile
* AMD support

Pull Request resolved: pytorch#151484
Approved by: https://github.com/malfet
1 parent 78953ee commit 5b368fa
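For reference, here is a minimal end-to-end sketch of the new private API, assembled from the test and docstring in the diff below. `_compile_kernel` is private and subject to change, so treat this as illustrative rather than a stable recipe:

import torch
from torch.cuda import _compile_kernel

# The kernel mirrors the first test below; _compile_kernel runs NVRTC at
# call time and returns a launchable callable.
kernel_source = """
__global__ void add_tensors(const float* a, const float* b, float* c, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n)
        c[i] = a[i] + b[i];
}
"""

add_kernel = _compile_kernel(kernel_source, "add_tensors")

n = 1024
a = torch.rand(n, device="cuda")
b = torch.rand(n, device="cuda")
c = torch.empty_like(a)

threads_per_block = 256
# Ceil division; the "divup primitive" under future work would wrap this idiom.
blocks_per_grid = (n + threads_per_block - 1) // threads_per_block

add_kernel(
    grid=(blocks_per_grid, 1, 1),
    block=(threads_per_block, 1, 1),
    args=[a, b, c, n],
)
torch.testing.assert_close(c, a + b)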

3 files changed: +642, -2 lines changed


test/test_cuda.py

Lines changed: 183 additions & 0 deletions
@@ -5983,8 +5983,191 @@ def test_cuda_module_loading_env(self):
        self.assertEqual(val, "LAZY")


class TestCompileKernel(TestCase):
    @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc")
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel(self):
        # Simple vector addition kernel
        kernel_source = """
        __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        """

        # Compile the kernel
        from torch.cuda import _compile_kernel

        add_kernel = _compile_kernel(kernel_source, "add_tensors")

        # Prepare data
        N = 1024
        a = torch.rand(N, device="cuda")
        b = torch.rand(N, device="cuda")
        c = torch.empty_like(a)

        # Calculate grid and block dimensions
        threads_per_block = 256
        blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

        # Launch kernel
        add_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[a, b, c, N],
        )

        # Verify results
        expected = a + b
        self.assertEqual(c, expected)

        # Test with different tensor types
        a_int = torch.randint(0, 100, (N,), device="cuda", dtype=torch.int32)
        b_int = torch.randint(0, 100, (N,), device="cuda", dtype=torch.int32)
        c_int = torch.empty_like(a_int)

        # Integer addition kernel
        int_kernel_source = """
        __global__ void add_int_tensors(const int* a, const int* b, int* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        """
        from torch.cuda import _compile_kernel

        add_int_kernel = _compile_kernel(int_kernel_source, "add_int_tensors")

        # Launch kernel
        add_int_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[a_int, b_int, c_int, N],
        )

        # Verify results
        expected_int = a_int + b_int
        torch.testing.assert_close(c_int, expected_int)

        # Test with header code
        header_code = """
        #define SCALE_FACTOR 2.0f

        __device__ float scale_value(float val) {
            return val * SCALE_FACTOR;
        }
        """

        scale_kernel_source = """
        __global__ void scale_tensors(const float* input, float* output, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                output[i] = scale_value(input[i]);
        }
        """

        scale_kernel = _compile_kernel(
            scale_kernel_source, "scale_tensors", header_code=header_code
        )

        input_tensor = torch.rand(N, device="cuda")
        output_tensor = torch.empty_like(input_tensor)

        scale_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[input_tensor, output_tensor, N],
        )

        # Verify scaling
        expected_scaled = input_tensor * 2.0
        torch.testing.assert_close(output_tensor, expected_scaled)

        # Test error handling with invalid kernel
        invalid_kernel_source = """
        __global__ void invalid_kernel(float* a) {
            undeclared_variable = 10; // This will cause a compilation error
        }
        """

        with self.assertRaises(RuntimeError):
            _compile_kernel(invalid_kernel_source, "invalid_kernel")

    @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc")
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel_advanced(self):
        # Test matrix multiplication
        matmul_kernel_source = """
        __global__ void matrix_multiply(const float* A, const float* B, float* C, int M, int N, int K) {
            int row = blockIdx.y * blockDim.y + threadIdx.y;
            int col = blockIdx.x * blockDim.x + threadIdx.x;

            if (row < M && col < N) {
                float sum = 0.0f;
                for (int i = 0; i < K; i++) {
                    sum += A[row * K + i] * B[i * N + col];
                }
                C[row * N + col] = sum;
            }
        }
        """
        from torch.cuda import _compile_kernel

        matmul_kernel = _compile_kernel(matmul_kernel_source, "matrix_multiply")

        # Matrix dimensions
        M, K, N = 64, 32, 48

        # Create matrices
        A = torch.rand((M, K), device="cuda")
        B = torch.rand((K, N), device="cuda")
        C = torch.zeros((M, N), device="cuda")

        # Calculate grid and block dimensions
        block_dim = (16, 16, 1)
        grid_dim = (
            (N + block_dim[0] - 1) // block_dim[0],
            (M + block_dim[1] - 1) // block_dim[1],
            1,
        )

        # Launch kernel
        matmul_kernel(
            grid=grid_dim,
            block=block_dim,
            args=[A.contiguous(), B.contiguous(), C, M, N, K],
        )

        # Verify results
        expected = torch.matmul(A, B)
        torch.testing.assert_close(C, expected, rtol=1e-5, atol=1e-5)

        # Test with different compute capability if specified
        device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
        compute_cap = f"{device_props.major}{device_props.minor}"

        # Recompile with explicit compute capability
        matmul_kernel_explicit = _compile_kernel(
            matmul_kernel_source, "matrix_multiply", compute_capability=compute_cap
        )

        C_explicit = torch.zeros((M, N), device="cuda")

        # Launch kernel
        matmul_kernel_explicit(
            grid=grid_dim,
            block=block_dim,
            args=[A.contiguous(), B.contiguous(), C_explicit, M, N, K],
        )

        # Verify results
        torch.testing.assert_close(C_explicit, expected, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestCuda)
instantiate_parametrized_tests(TestCudaMallocAsync)
instantiate_parametrized_tests(TestCompileKernel)
instantiate_device_type_tests(TestCudaOptims, globals())


if __name__ == "__main__":
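Aside: the compile-latency numbers quoted in the commit message come from timing the NVRTC path against `load_inline()`. A rough sketch of how one might measure the NVRTC side locally (the kernel and timing harness here are illustrative, not from this PR; absolute numbers will vary by machine):

import time

from torch.cuda import _compile_kernel

# Time only the runtime compilation step; kernel launch cost is separate.
t0 = time.perf_counter()
noop_kernel = _compile_kernel("__global__ void noop() {}", "noop")
print(f"_compile_kernel via NVRTC took {time.perf_counter() - t0:.3f}s")

# For comparison, torch.utils.cpp_extension.load_inline() invokes nvcc and
# builds/loads a full C++ extension, which accounts for the ~90s figure.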

torch/cuda/__init__.py

Lines changed: 71 additions & 1 deletion
@@ -13,6 +13,7 @@

import importlib
import os
+import sys
import threading
import traceback
import warnings
@@ -1038,7 +1039,7 @@ def current_device() -> int:
    return torch._C._cuda_getDevice()


-def synchronize(device: _device_t = None) -> None:
+def synchronize(device: Optional[_device_t] = None) -> None:
    r"""Wait for all kernels in all streams on a CUDA device to complete.

    Args:
@@ -1693,6 +1694,75 @@ def addmm_kernel_impl(*args, **kwargs):
_lazy_call(_register_triton_kernels)


def _compile_kernel(
    kernel_source: str,
    kernel_name: str,
    compute_capability: Optional[str] = None,
    header_code: str = "",
    cuda_include_dirs: Optional[list] = None,
    nvcc_options: Optional[list] = None,
):
    """
    Compiles a CUDA kernel using NVRTC and returns a callable function.

    This function is a wrapper for NVRTC that enables runtime compilation of CUDA kernels.
    Note that this returns a raw CUDA kernel that operates on raw memory pointers.
    To use this kernel as a proper PyTorch operator, you should wrap it following the guide at:
    pytorch.org/tutorials/advanced/python_custom_ops.html

    Args:
        kernel_source (str): The CUDA kernel source code as a string
        kernel_name (str): The name of the kernel function to compile
        compute_capability (str, optional): The compute capability to target (e.g., "86").
            If None, will detect from current device.
        header_code (str, optional): Additional header code to prepend to the kernel source
        cuda_include_dirs (list, optional): List of directories containing CUDA headers
        nvcc_options (list, optional): Additional options to pass to NVRTC

    Returns:
        callable: A Python function that can be called with PyTorch tensor arguments to execute the kernel

    Example:
        >>> # xdoctest: +SKIP
        >>> kernel_code = '''
        extern "C"
        __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        '''
        >>> add_kernel = torch.cuda._compile_kernel(kernel_code, "add_tensors")
        >>> a = torch.randn(1024, device="cuda")
        >>> b = torch.randn(1024, device="cuda")
        >>> c = torch.empty_like(a)
        >>> add_kernel(grid=(4, 1, 1), block=(256, 1, 1), args=[a, b, c, a.numel()])
    """
    import ctypes

    from torch.cuda._utils import _cuda_load_module, _nvrtc_compile

    # Compile the kernel to PTX
    ptx = _nvrtc_compile(
        kernel_source,
        kernel_name,
        compute_capability,
        header_code,
        cuda_include_dirs,
        nvcc_options,
    )

    # Load the module and get the kernel
    result = _cuda_load_module(ptx, [kernel_name])

    if isinstance(result, dict):
        return result[kernel_name]
    else:
        # This branch shouldn't be executed if kernel_names is provided,
        # but MyPy needs this to understand type narrowing
        return getattr(result, kernel_name)


from . import amp, jiterator, nvtx, profiler, sparse, tunable
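The docstring above defers to the custom-ops guide for turning the raw kernel into a proper PyTorch operator. As one possible illustration (not part of this commit: the `mylib::add_tensors` op name, the fixed launch configuration, and the wrapper itself are assumptions), the returned callable could be registered via torch.library.custom_op roughly like so:

import torch
from torch.cuda import _compile_kernel

# Hypothetical wrapper, not from this PR: registers the raw NVRTC kernel as a
# PyTorch operator per pytorch.org/tutorials/advanced/python_custom_ops.html.
_add_kernel = _compile_kernel(
    """
    __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
        int i = threadIdx.x + blockIdx.x * blockDim.x;
        if (i < n)
            c[i] = a[i] + b[i];
    }
    """,
    "add_tensors",
)


@torch.library.custom_op("mylib::add_tensors", mutates_args=(), device_types="cuda")
def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Allocate the output inside the op so no inputs are mutated.
    a, b = a.contiguous(), b.contiguous()
    c = torch.empty_like(a)
    n = a.numel()
    threads = 256
    blocks = (n + threads - 1) // threads
    _add_kernel(grid=(blocks, 1, 1), block=(threads, 1, 1), args=[a, b, c, n])
    return c


@add_tensors.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation so the op composes with torch.compile.
    return torch.empty_like(a)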
