12 changes: 12 additions & 0 deletions src/Tiled-MM/CMakeLists.txt
@@ -5,6 +5,18 @@ add_library(Tiled-MM gpu_context.cpp
tiled_mm.cpp
tile_coord.cpp)

# Add BF16 conversion kernels if support is enabled
if(TILED_MM_HAS_BF16_SUPPORT)
message(STATUS "Adding BF16 conversion kernels to Tiled-MM")
if(${TILED_MM_CUDA})
message(STATUS " - Using CUDA backend: bf16_convert.cu")
target_sources(Tiled-MM PRIVATE bf16_convert.cu)
elseif(${TILED_MM_ROCM})
message(STATUS " - Using ROCm backend: bf16_convert.hip")
target_sources(Tiled-MM PRIVATE bf16_convert.hip)
endif()
endif()

target_link_libraries(Tiled-MM PUBLIC
$<$<BOOL:${TILED_MM_ROCM}>:roc::rocblas>
$<$<BOOL:${TILED_MM_CUDA}>:CUDA::cublas CUDA::cudart>
100 changes: 100 additions & 0 deletions src/Tiled-MM/bf16_convert.cu
@@ -0,0 +1,100 @@
/*
* BFloat16 Conversion Kernels for CUDA
*
* Implements efficient FP32 ↔ BF16 conversion using CUDA intrinsics.
*/
#include "bf16_convert.hpp"

#if defined(TILED_MM_CUDA)

namespace gpu {
namespace bf16_convert {

// Kernel configuration
constexpr int THREADS_PER_BLOCK = 256;

/**
* @brief CUDA kernel to convert FP32 → BF16
*
* Uses __float2bfloat16 intrinsic for hardware-accelerated conversion.
* Applies round-to-nearest-even (RNE) rounding.
*/
__global__ void fp32_to_bf16_kernel(
const float* __restrict__ input,
__nv_bfloat16* __restrict__ output,
size_t n) {

size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx < n) {
// CUDA intrinsic: converts FP32 to BF16 with RNE rounding
// Available on all GPUs (software fallback on pre-Ampere)
output[idx] = __float2bfloat16(input[idx]);
}
}

/**
* @brief CUDA kernel to convert BF16 → FP32
*
* Uses __bfloat162float intrinsic for hardware-accelerated conversion.
* Conversion is lossless (zero-extends mantissa).
*/
__global__ void bf16_to_fp32_kernel(
const __nv_bfloat16* __restrict__ input,
float* __restrict__ output,
size_t n) {

size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx < n) {
// CUDA intrinsic: converts BF16 to FP32 (lossless)
output[idx] = __bfloat162float(input[idx]);
}
}

// Host-side wrapper functions

void convert_fp32_to_bf16(
const float* d_input,
BF16Type* d_output,
size_t n,
StreamType stream) {

if (n == 0) return;

// Calculate grid dimensions
int blocks = static_cast<int>((n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

// Launch kernel
fp32_to_bf16_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
d_input,
reinterpret_cast<__nv_bfloat16*>(d_output),
n);

// Note: Kernel is asynchronous. Caller should sync stream if needed.
}

void convert_bf16_to_fp32(
const BF16Type* d_input,
float* d_output,
size_t n,
StreamType stream) {

if (n == 0) return;

// Calculate grid dimensions
int blocks = static_cast<int>((n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

// Launch kernel
bf16_to_fp32_kernel<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(d_input),
d_output,
n);

// Note: Kernel is asynchronous. Caller should sync stream if needed.
}

} // namespace bf16_convert
} // namespace gpu

#endif // TILED_MM_CUDA
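
For reference, a minimal host-side round trip through these wrappers on the CUDA backend could look like the sketch below. It is illustrative only: error checking is elided, the buffer names are made up, and TILED_MM_CUDA is assumed to be defined.

#include "bf16_convert.hpp"
#include <cuda_runtime.h>
#include <vector>

int main() {
    const size_t n = 1 << 20;
    std::vector<float> host(n, 1.0f);

    float* d_fp32 = nullptr;
    gpu::bf16_convert::BF16Type* d_bf16 = nullptr;
    cudaMalloc(&d_fp32, n * sizeof(float));
    cudaMalloc(&d_bf16, n * sizeof(gpu::bf16_convert::BF16Type));
    cudaMemcpy(d_fp32, host.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    // Narrow FP32 -> BF16, then widen back, on the default stream.
    gpu::bf16_convert::convert_fp32_to_bf16(d_fp32, d_bf16, n);
    gpu::bf16_convert::convert_bf16_to_fp32(d_bf16, d_fp32, n);

    // Launches are asynchronous: synchronize before touching the result.
    cudaStreamSynchronize(0);
    cudaMemcpy(host.data(), d_fp32, n * sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(d_bf16);
    cudaFree(d_fp32);
    return 0;
}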
110 changes: 110 additions & 0 deletions src/Tiled-MM/bf16_convert.hip
@@ -0,0 +1,110 @@
/*
* BFloat16 Conversion Kernels for ROCm/HIP
*
* Implements efficient FP32 ↔ BF16 conversion using HIP intrinsics.
*/
#include "bf16_convert.hpp"

#if defined(TILED_MM_ROCM)

namespace gpu {
namespace bf16_convert {

// Kernel configuration
constexpr int THREADS_PER_BLOCK = 256;

/**
* @brief HIP kernel to convert FP32 → BF16
*
* Uses the hip_bfloat16(float) constructor for the conversion.
* Applies round-to-nearest-even (RNE) rounding.
*/
__global__ void fp32_to_bf16_kernel(
const float* __restrict__ input,
hip_bfloat16* __restrict__ output,
size_t n) {

size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx < n) {
// hip_bfloat16's explicit float constructor converts FP32 to BF16
// with RNE rounding; hardware-accelerated on CDNA2+ (MI200 series)
output[idx] = hip_bfloat16(input[idx]);
}
}

/**
* @brief HIP kernel to convert BF16 → FP32
*
* Uses hip_bfloat16's operator float() for the conversion.
* Conversion is lossless (zero-extends mantissa).
*/
__global__ void bf16_to_fp32_kernel(
const hip_bfloat16* __restrict__ input,
float* __restrict__ output,
size_t n) {

size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx < n) {
// hip_bfloat16's operator float() converts BF16 to FP32 (lossless)
output[idx] = static_cast<float>(input[idx]);
}
}

// Host-side wrapper functions

void convert_fp32_to_bf16(
const float* d_input,
BF16Type* d_output,
size_t n,
StreamType stream) {

if (n == 0) return;

// Calculate grid dimensions
int blocks = static_cast<int>((n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

// Launch kernel
hipLaunchKernelGGL(
fp32_to_bf16_kernel,
dim3(blocks),
dim3(THREADS_PER_BLOCK),
0,
stream,
d_input,
reinterpret_cast<hip_bfloat16*>(d_output),
n);

// Note: Kernel is asynchronous. Caller should sync stream if needed.
}

void convert_bf16_to_fp32(
const BF16Type* d_input,
float* d_output,
size_t n,
StreamType stream) {

if (n == 0) return;

// Calculate grid dimensions
int blocks = static_cast<int>((n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);

// Launch kernel
hipLaunchKernelGGL(
bf16_to_fp32_kernel,
dim3(blocks),
dim3(THREADS_PER_BLOCK),
0,
stream,
reinterpret_cast<const hip_bfloat16*>(d_input),
d_output,
n);

// Note: Kernel is asynchronous. Caller should sync stream if needed.
}

} // namespace bf16_convert
} // namespace gpu

#endif // TILED_MM_ROCM
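
The HIP path is used the same way; one detail worth showing is launch-error checking, since kernel launches return no status. A sketch under the same assumptions as above (the function name round_trip and the scratch buffer are hypothetical):

#include "bf16_convert.hpp"
#include <hip/hip_runtime.h>

void round_trip(const float* d_in, float* d_out,
                gpu::bf16_convert::BF16Type* d_scratch,
                size_t n, hipStream_t stream) {
    gpu::bf16_convert::convert_fp32_to_bf16(d_in, d_scratch, n, stream);
    gpu::bf16_convert::convert_bf16_to_fp32(d_scratch, d_out, n, stream);

    // Launches are asynchronous and return no error code, so poll for
    // launch failures explicitly, then synchronize before d_out is read.
    if (hipGetLastError() != hipSuccess) {
        // handle/report the launch failure
    }
    hipStreamSynchronize(stream);
}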
73 changes: 73 additions & 0 deletions src/Tiled-MM/bf16_convert.hpp
@@ -0,0 +1,73 @@
/*
* BFloat16 Conversion Utilities for GPU
*
* Provides device-side conversion kernels between FP32 and BF16.
* Supports both CUDA (NVIDIA) and ROCm (AMD) backends.
*/
#pragma once

#include <cstddef>

#if defined(TILED_MM_CUDA)
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#elif defined(TILED_MM_ROCM)
#include <hip/hip_runtime.h>
#include <hip/hip_bfloat16.h>
#else
#error Either TILED_MM_CUDA or TILED_MM_ROCM must be defined!
#endif

namespace gpu {
namespace bf16_convert {

#if defined(TILED_MM_CUDA)
using StreamType = cudaStream_t;
using BF16Type = __nv_bfloat16;
#elif defined(TILED_MM_ROCM)
using StreamType = hipStream_t;
using BF16Type = hip_bfloat16;
#endif

/**
* @brief Convert FP32 array to BF16 on device
*
* Launches a GPU kernel to convert a device-allocated FP32 array to BF16.
* Uses hardware intrinsics for efficient conversion.
*
* @param d_input Device pointer to FP32 input array
* @param d_output Device pointer to BF16 output array
* @param n Number of elements to convert
* @param stream CUDA/HIP stream for asynchronous execution
*
* @note Both pointers must be device-allocated (cudaMalloc/hipMalloc)
* @note Kernel is launched asynchronously; use stream synchronization if needed
*/
void convert_fp32_to_bf16(
const float* d_input,
BF16Type* d_output,
size_t n,
StreamType stream = 0);

/**
* @brief Convert BF16 array to FP32 on device
*
* Launches a GPU kernel to convert a device-allocated BF16 array to FP32.
* Conversion is lossless (a BF16 value is the top half of an FP32 value, so widening is exact).
*
* @param d_input Device pointer to BF16 input array
* @param d_output Device pointer to FP32 output array
* @param n Number of elements to convert
* @param stream CUDA/HIP stream for asynchronous execution
*
* @note Both pointers must be device-allocated (cudaMalloc/hipMalloc)
* @note Kernel is launched asynchronously; use stream synchronization if needed
*/
void convert_bf16_to_fp32(
const BF16Type* d_input,
float* d_output,
size_t n,
StreamType stream = 0);

} // namespace bf16_convert
} // namespace gpu
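
The rounding and losslessness claims above follow from the BF16 bit layout: a BF16 value is the top 16 bits of an FP32 value (same sign bit and 8-bit exponent, 7-bit mantissa). A host-only C++ sketch of the same semantics, for illustration only (not part of the library; NaN handling omitted):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Widen BF16 -> FP32: place the 16 stored bits in the high half and
// zero-fill the low 16 mantissa bits. Exact for every BF16 value.
float bf16_bits_to_float(uint16_t b) {
    uint32_t u = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Narrow FP32 -> BF16 with round-to-nearest-even on the discarded bits.
uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    uint32_t rounding_bias = 0x7FFF + ((u >> 16) & 1);  // RNE tie-break
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}

int main() {
    float x = 1.00390625f;  // 1 + 2^-8: exactly halfway between BF16 neighbors
    uint16_t b = float_to_bf16_bits(x);
    std::printf("%.8f -> %.8f\n", x, bf16_bits_to_float(b));  // rounds to 1.0
}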
54 changes: 54 additions & 0 deletions src/Tiled-MM/gpu_blas_api.hpp
@@ -32,9 +32,11 @@

#if defined(TILED_MM_CUDA)
#include <cublas_v2.h>
#include <cuda_bf16.h> // For __nv_bfloat16 type

#elif defined(TILED_MM_ROCM)
#include <rocblas/rocblas.h>
#include <hip/hip_bfloat16.h> // For hip_bfloat16 type

#else
#error Either TILED_MM_CUDA or TILED_MM_ROCM must be defined!
@@ -251,6 +253,58 @@ inline auto zgemm(ARGS... args) -> StatusType {
#endif // TILED_MM_CUDA
}

// BFloat16 GEMM (mixed precision: BF16 × BF16 → FP32)
// Requires CUDA 11.0+ with Ampere (SM 80+) or ROCm 4.5+ with CDNA2 (gfx90a)
inline auto gemm_bf16(
HandleType handle,
OperationType trans_a,
OperationType trans_b,
int m, int n, int k,
const float* alpha, // FP32 scalar
const void* A, // BF16 matrix (device pointer)
int lda,
const void* B, // BF16 matrix (device pointer)
int ldb,
const float* beta, // FP32 scalar
float* C, // FP32 matrix (device pointer)
int ldc
) -> StatusType {
#if defined(TILED_MM_CUDA)
// Use cublasGemmEx for mixed-precision BF16 × BF16 → FP32
// Requires CUDA 11.0+ and Ampere GPU (SM 80+)
return cublasGemmEx(
handle,
trans_a, trans_b,
m, n, k,
alpha,
A, CUDA_R_16BF, lda, // BF16 input A
B, CUDA_R_16BF, ldb, // BF16 input B
beta,
C, CUDA_R_32F, ldc, // FP32 output C
CUBLAS_COMPUTE_32F, // FP32 accumulation
CUBLAS_GEMM_DEFAULT_TENSOR_OP // Use Tensor Cores if available
);
#elif defined(TILED_MM_ROCM)
// Use rocblas_gemm_ex for mixed-precision BF16 × BF16 → FP32
// Requires ROCm 4.5+ and CDNA2 GPU (gfx90a)
return rocblas_gemm_ex(
handle,
trans_a, trans_b,
m, n, k,
alpha,
A, rocblas_datatype_bf16_r, lda, // BF16 input A
B, rocblas_datatype_bf16_r, ldb, // BF16 input B
beta,
C, rocblas_datatype_f32_r, ldc, // FP32 output C (in)
C, rocblas_datatype_f32_r, ldc, // FP32 output C (out)
rocblas_datatype_f32_r, // FP32 compute type
rocblas_gemm_algo_standard,
0, // solution_index
0 // flags
);
#endif
}

} // namespace blas_api
} // namespace gpu
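
Putting the pieces together, a caller could narrow FP32 operands on a stream and feed them to gemm_bf16. The following is a hypothetical sketch for the CUDA backend: gemm_fp32_via_bf16 is not part of this PR, the BF16 buffers are assumed to be caller-allocated device memory, column-major storage with lda = m, ldb = k, ldc = m is assumed, and OperationType is assumed to map to cublasOperation_t when TILED_MM_CUDA is defined.

#include "bf16_convert.hpp"
#include "gpu_blas_api.hpp"

void gemm_fp32_via_bf16(cublasHandle_t handle, cudaStream_t stream,
                        int m, int n, int k,
                        const float* a_fp32, const float* b_fp32, float* c_fp32,
                        gpu::bf16_convert::BF16Type* a_bf16,
                        gpu::bf16_convert::BF16Type* b_bf16) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    // Narrow both inputs to BF16 on the caller's stream.
    gpu::bf16_convert::convert_fp32_to_bf16(a_fp32, a_bf16, size_t(m) * k, stream);
    gpu::bf16_convert::convert_fp32_to_bf16(b_fp32, b_bf16, size_t(k) * n, stream);

    // Run the GEMM on the same stream so it orders after the conversions.
    cublasSetStream(handle, stream);
    gpu::blas_api::gemm_bf16(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                             m, n, k,
                             &alpha, a_bf16, m, b_bf16, k,
                             &beta, c_fp32, m);
}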
