-
-
Notifications
You must be signed in to change notification settings - Fork 827
Unify CUDA and HIP kernel sources via compat.cuh portability layer #1877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
matthewdouglas
merged 16 commits into
bitsandbytes-foundation:main
from
Abdennacer-Badaoui:merge-cuda-hip
Feb 24, 2026
+644
−3,023
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
83216e0
first commit
Abdennacer-Badaoui 9c69888
update
Abdennacer-Badaoui c268706
Merge branch 'main' into merge-cuda-hip
Abdennacer-Badaoui b96843c
Merge branch 'main' into merge-cuda-hip
Abdennacer-Badaoui d7f3e15
Activate unified CUDA/HIP kernel files from csrc/examples/
TimDettmers bef474a
Fix HIP build errors in unified kernel files
TimDettmers 214943c
Restore common.h include in ops.cuh for DataType_t enum
TimDettmers 4fae098
Guard blocksize=64 quantize instantiations for warp size compatibility
TimDettmers c538ced
Guard all blocksize=64 quantize instantiations for warp size compat
TimDettmers 0b33411
Use conditional load/store algo for warp size compatibility
TimDettmers 50cef42
Fix BNB_WARP_SIZE detection for HIP host compilation pass
TimDettmers 32cd056
Remove blocksize=64 instantiation guards
TimDettmers 9133f46
Apply clang-format formatting fixes
TimDettmers ebdda00
merge unified-hip-validation & code cleaning
Abdennacer-Badaoui f887942
Merge upstream/main into merge-cuda-hip
Abdennacer-Badaoui 57156ee
Merge remote-tracking branch 'upstream/main' into merge-cuda-hip
Abdennacer-Badaoui File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
// compat.cuh — Platform abstraction layer for CUDA/HIP portability
//
// This header resolves ALL mechanical differences between CUDA and HIP.
// Kernel code should include this header and use the bnb_* types/macros
// instead of cuda*/hip* identifiers directly.
//
// The guard macro is BNB_HIP, which is defined when compiling for ROCm/HIP
// (set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)).

#pragma once

#include <cstdio>  // fprintf() used by BNB_CHECK_RETURN
#include <cstdlib> // exit() used by BNB_CHECK_RETURN

// Platform detection
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define BNB_HIP 1
#else
#define BNB_HIP 0
#endif

// Runtime and FP16/BF16 headers

#if BNB_HIP

#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_math_constants.h>
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>

#else // CUDA

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#endif

// Stream and error types

#if BNB_HIP

using bnb_stream_t = hipStream_t;
using bnb_error_t = hipError_t;

#define BNB_SUCCESS hipSuccess
#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError()
#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e)
#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s)
#define BNB_DEVICE_FREE(p) hipFree(p)
#define BNB_DEVICE_MEMSET(p, v, s) hipMemset(p, v, s)

#else // CUDA

using bnb_stream_t = cudaStream_t;
using bnb_error_t = cudaError_t;

#define BNB_SUCCESS cudaSuccess
#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError()
#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e)
#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s)
#define BNB_DEVICE_FREE(p) cudaFree(p)
#define BNB_DEVICE_MEMSET(p, v, s) cudaMemset(p, v, s)

#endif

// Error checking
//
// Prints the runtime error string with source location and aborts.
// Wrapped in do { } while (0) so the macro expands to exactly one
// statement — a bare { } block here would break unbraced if/else chains
// (the classic `if (x) BNB_CHECK_RETURN(..); else ...` pitfall).
// The argument is parenthesized so expression arguments bind correctly.
#define BNB_CHECK_RETURN(value)                                                                                        \
    do {                                                                                                               \
        bnb_error_t _bnb_stat = (value);                                                                               \
        if (_bnb_stat != BNB_SUCCESS) {                                                                                \
            fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__);  \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)

// Keep backward compat for existing code during migration
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)

// Warp synchronization
//
// HIP warps are always in lockstep (no independent thread scheduling),
// so __syncwarp() is a no-op. CUDA needs it for warp convergence.
//
// NOTE(review): defining a macro named __syncwarp shadows the builtin that
// newer ROCm versions provide (also a no-op on AMD hardware today), and
// double-underscore names are reserved for the implementation — confirm this
// does not collide with the targeted ROCm versions.
#if BNB_HIP
#define __syncwarp()                                                                                                   \
    do {                                                                                                               \
    } while (0)
#endif

// BFloat16 type alias

#if BNB_HIP
using bnb_bfloat16 = hip_bfloat16;
#else
using bnb_bfloat16 = __nv_bfloat16;
#endif
|
|
||
// Data type enum aliases for BLAS libraries
//
// Maps the bnb_* data-type tags onto the platform's BLAS enumerators
// (HIP_R_* on ROCm, CUDA_R_* on CUDA) so call sites stay identical.

#if BNB_HIP
#define BNB_R_8I HIP_R_8I
#define BNB_R_16F HIP_R_16F
#define BNB_R_32I HIP_R_32I
#define BNB_R_32F HIP_R_32F
#else // CUDA
#define BNB_R_8I CUDA_R_8I
#define BNB_R_16F CUDA_R_16F
#define BNB_R_32I CUDA_R_32I
#define BNB_R_32F CUDA_R_32F
#endif
|
|
||
// BLAS Lt types and functions
//
// Both branches expose the same bnb_blasLt* surface so shared kernel/host
// sources compile unchanged on either platform.

#if BNB_HIP

// NO_HIPBLASLT is set for old ROCm versions that ship without hipBLASLt.
// The Lt aliases must live inside this guard: with the include disabled,
// the hipblasLt* types do not exist and unconditional using-declarations
// would fail to compile.
#ifndef NO_HIPBLASLT
#include <hipblaslt/hipblaslt.h>

using bnb_blasLt_handle_t = hipblasLtHandle_t;
using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t;
using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t;
using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t;
using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t;

#define BNB_BLASLT_OP_T HIPBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I

#define bnb_blasLtCreate hipblasLtCreate
#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy
#define bnb_blasLtMatmul hipblasLtMatmul
#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate
#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute
#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic

#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
// NOTE(review): the HIP pointer mode is ..._BETA_HOST while the CUDA branch
// uses ..._BETA_ZERO — confirm this asymmetry is intentional (hipBLASLt may
// not expose an exact BETA_ZERO equivalent).
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST
#endif // NO_HIPBLASLT

// Plain hipBLAS status names come from hipblas.h and are available even
// when hipBLASLt is not.
using bnb_blas_status_t = hipblasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS

#else // CUDA

#include <cublasLt.h>
#include <cublas_v2.h>

using bnb_blasLt_handle_t = cublasLtHandle_t;
using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t;
using bnb_blasLt_layout_t = cublasLtMatrixLayout_t;
// Preference/heuristic aliases mirror the HIP branch; without them any
// shared source using bnb_blasLtPrefCreate/bnb_blasLtAlgoGetHeuristic
// would fail to compile on CUDA.
using bnb_blasLt_preference_t = cublasLtMatmulPreference_t;
using bnb_blasLt_heuristic_t = cublasLtMatmulHeuristicResult_t;

#define BNB_BLASLT_OP_T CUBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I

#define bnb_blasLtCreate cublasLtCreate
#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy
#define bnb_blasLtMatmul cublasLtMatmul
#define bnb_blasLtPrefCreate cublasLtMatmulPreferenceCreate
#define bnb_blasLtPrefSetAttr cublasLtMatmulPreferenceSetAttribute
#define bnb_blasLtAlgoGetHeuristic cublasLtMatmulAlgoGetHeuristic

#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PREF_MAX_WORKSPACE CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO

using bnb_blas_status_t = cublasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS

#endif
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
// compat_device.cuh — Device-only portability layer (CUB, reduction ops, MMA)
//
// Include this from .cu kernel files only (compiled by nvcc/hipcc).
// Do NOT include from .cpp files — use compat.cuh instead for host-safe types.

#pragma once

#include "compat.cuh"

// CUB / hipCUB — expose the platform's block/warp primitives under the
// single namespace alias bnb_cub.

#if BNB_HIP

#include <hipcub/hipcub.hpp>

namespace bnb_cub = hipcub;

#else // CUDA

#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <math_constants.h>
#include <mma.h>

namespace bnb_cub = cub;

#endif

// Reduction operators — functor expressions for bnb_cub reductions.

#if BNB_HIP

#define BNB_MAX_OP hipcub::Max()
#define BNB_SUM_OP hipcub::Sum()

#else // CUDA

// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max()
#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002
#include <cuda/std/functional>
#define BNB_MAX_OP \
    cuda::maximum<> {}
#else
#define BNB_MAX_OP cub::Max()
#endif
#define BNB_SUM_OP cub::Sum()

#endif
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we'll get rid of
NO_HIPBLASLT; it was only defined for old ROCm versions. But we can leave it in for now.