Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ endif()

# Define included source files
set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(GPU_FILES csrc/ops.cu csrc/kernels.cu)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
Expand Down Expand Up @@ -225,7 +224,7 @@ if(BUILD_CUDA)
message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")

list(APPEND SRC_FILES ${CUDA_FILES})
list(APPEND SRC_FILES ${GPU_FILES})

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
add_compile_definitions(BUILD_CUDA)
Expand All @@ -244,7 +243,7 @@ elseif(BUILD_HIP)
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")

list(APPEND SRC_FILES ${HIP_FILES})
list(APPEND SRC_FILES ${GPU_FILES})

string(APPEND BNB_OUTPUT_NAME "_rocm")

Expand Down Expand Up @@ -389,7 +388,7 @@ if(BUILD_HIP)
endif()

target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP)
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

if(HIP_VERSION VERSION_LESS "6.1")
Expand Down
58 changes: 47 additions & 11 deletions csrc/common.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,32 @@
// common.cuh — Architecture constants and feature detection

#pragma once

// TODO: Let's make some of these constexpr and put in a namespace.
#include "compat.cuh"

// Warp size

#if BNB_HIP
// CDNA (gfx9xx) = 64, RDNA = 32.
#ifdef __AMDGCN_WAVEFRONT_SIZE
#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
#else
#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA)
#endif
#else
#define BNB_WARP_SIZE 32
#endif

// BF16 availability

#if BNB_HIP
// BF16 is available on all currently-supported ROCm architectures (CDNA2+, RDNA3+)
#define BNB_BF16_AVAILABLE true
#else
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#endif

// Compute capability constants

#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
Expand All @@ -14,31 +40,41 @@
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000

// Feature availability based on arch

#if BNB_HIP
// HIP: MMA not supported via mma.h; FP8 support varies by arch
#define BNB_FP16_MMA_AVAILABLE 0
#define BNB_INT8_MMA_AVAILABLE 0
#define BNB_FP8_AVAILABLE 0
#else
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#endif

#define BNB_WARP_SIZE 32
// Maximum threads per SM/CU

// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if BNB_HIP
// For currently supported ROCm architectures (CDNA2, RDNA3)
#define BNB_MAX_THREADS_PER_SM 2048
#else
// The maximum number of resident threads per SM varies by NVIDIA arch.
// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#endif
#endif

// Maximum resident warps per SM is always directly related to the number of threads.
// Maximum resident warps per SM/CU
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))

// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
// Maximum resident blocks per SM/CU
#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870)
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
Expand Down
11 changes: 0 additions & 11 deletions csrc/common_hip.cuh

This file was deleted.

181 changes: 181 additions & 0 deletions csrc/compat.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// compat.cuh — Platform abstraction layer for CUDA/HIP portability
//
// This header resolves ALL mechanical differences between CUDA and HIP.
// Kernel code should include this header and use the bnb_* types/macros
// instead of cuda*/hip* identifiers directly.
//
// The guard macro is BNB_HIP, which is defined when compiling for ROCm/HIP
// (set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)).

#pragma once

// Platform detection

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#define BNB_HIP 1
#else
#define BNB_HIP 0
#endif

// Runtime and FP16/BF16 headers

#if BNB_HIP

#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_math_constants.h>
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>

#else // CUDA

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#endif

// Stream and error types

#if BNB_HIP

using bnb_stream_t = hipStream_t;
using bnb_error_t = hipError_t;

#define BNB_SUCCESS hipSuccess
#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError()
#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e)
#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s)
#define BNB_DEVICE_FREE(p) hipFree(p)
#define BNB_DEVICE_MEMSET(p, v, s) hipMemset(p, v, s)

#else // CUDA

using bnb_stream_t = cudaStream_t;
using bnb_error_t = cudaError_t;

#define BNB_SUCCESS cudaSuccess
#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError()
#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e)
#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s)
#define BNB_DEVICE_FREE(p) cudaFree(p)
#define BNB_DEVICE_MEMSET(p, v, s) cudaMemset(p, v, s)

#endif

// Error checking

#define BNB_CHECK_RETURN(value) \
{ \
bnb_error_t _bnb_stat = value; \
if (_bnb_stat != BNB_SUCCESS) { \
fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \
exit(1); \
} \
}

// Keep backward compat for existing code during migration
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)

// Warp synchronization
//
// HIP warps are always in lockstep (no independent thread scheduling),
// so __syncwarp() is a no-op. CUDA needs it for warp convergence.

#if BNB_HIP
#define __syncwarp() \
do { \
} while (0)
#endif

// BFloat16 type alias

#if BNB_HIP
using bnb_bfloat16 = hip_bfloat16;
#else
using bnb_bfloat16 = __nv_bfloat16;
#endif

// Data type enum aliases for BLAS libraries

#if BNB_HIP

#define BNB_R_16F HIP_R_16F
#define BNB_R_32F HIP_R_32F
#define BNB_R_8I HIP_R_8I
#define BNB_R_32I HIP_R_32I

#else // CUDA

#define BNB_R_16F CUDA_R_16F
#define BNB_R_32F CUDA_R_32F
#define BNB_R_8I CUDA_R_8I
#define BNB_R_32I CUDA_R_32I

#endif

// BLAS Lt types and functions

#if BNB_HIP

#ifndef NO_HIPBLASLT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll get rid of NO_HIPBLASLT, was only defined for old ROCm versions. But we can leave in for now.

#include <hipblaslt/hipblaslt.h>
#endif

using bnb_blasLt_handle_t = hipblasLtHandle_t;
using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t;
using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t;
using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t;

#define BNB_BLASLT_OP_T HIPBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I

#define bnb_blasLtCreate hipblasLtCreate
#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy
#define bnb_blasLtMatmul hipblasLtMatmul
#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate
#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute
#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic

#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST

using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t;
using bnb_blas_status_t = hipblasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS

#else // CUDA

#include <cublasLt.h>
#include <cublas_v2.h>

using bnb_blasLt_handle_t = cublasLtHandle_t;
using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t;
using bnb_blasLt_layout_t = cublasLtMatrixLayout_t;

#define BNB_BLASLT_OP_T CUBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I

#define bnb_blasLtCreate cublasLtCreate
#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy
#define bnb_blasLtMatmul cublasLtMatmul

#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO

using bnb_blas_status_t = cublasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS

#endif
51 changes: 51 additions & 0 deletions csrc/compat_device.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// compat_device.cuh — Device-only portability layer (CUB, reduction ops, MMA)
//
// Include this from .cu kernel files only (compiled by nvcc/hipcc).
// Do NOT include from .cpp files — use compat.cuh instead for host-safe types.

#pragma once

#include "compat.cuh"

// CUB / hipCUB — namespace alias

#if BNB_HIP

#include <hipcub/hipcub.hpp>
namespace bnb_cub = hipcub;

#else // CUDA

#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <math_constants.h>
#include <mma.h>
namespace bnb_cub = cub;

#endif

// Reduction operators

#if BNB_HIP

#define BNB_MAX_OP hipcub::Max()
#define BNB_SUM_OP hipcub::Sum()

#else // CUDA

// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max()
#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002
#include <cuda/std/functional>
#define BNB_MAX_OP \
cuda::maximum<> {}
#else
#define BNB_MAX_OP cub::Max()
#endif
#define BNB_SUM_OP cub::Sum()

#endif
Loading