Merged
127 changes: 127 additions & 0 deletions extensions/csrc/kernel/cuda/convert_fp8_kernel.cu
@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"


using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
{
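  // `numel` counts whole vectors of VecSize elements; `tail` counts the leftover scalars at the
  // end of the buffer, which are handled separately below.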
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;

for(int64_t i = idx; i < numel; i += grid_size) {
copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
}
  // Tail process: a single thread casts the remaining (tail < VecSize) scalars one by one
  if (blockIdx.x == 0 && threadIdx.x == 0)
  {
for(int i = 0; i < tail; ++i)
{
outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
}
}
}

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  const int kVecSize = get_vec_size<InT>(input);
  const int kNumel = torch::numel(input);

  // Split the element count into whole vectors of kVecSize elements plus a scalar tail
  // (kVecSize is a power of two, so the mask extracts the remainder).
  const int kVecNumel = kNumel / kVecSize;
  const int kTail = kNumel & (kVecSize - 1);
  int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 grid(grid_size);
dim3 block(256);

#define _(VEC_SIZE) \
convert_fp8_kernel<InT, OutT, VEC_SIZE> \
<<<grid, block, 0, stream>>> \
(reinterpret_cast<const InT*>(input.data_ptr()), \
reinterpret_cast<OutT*>(output.data_ptr()), \
kVecNumel, \
kTail)

switch (kVecSize)
{
case 1:
_(1);
break;
case 2:
_(2);
break;
    case 4:
      _(4);
      break;
    default:
      TORCH_CHECK(false, "Unsupported vectorized size for convert_fp8: ", kVecSize);
  }
#undef _
AT_CUDA_CHECK(cudaGetLastError());
}

void convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
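  // Exactly one of input/output must be uint8 (the raw fp8 byte representation); the other side
  // carries fp32/fp16/bf16 values, and both tensors must have the same shape.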
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte, "Either the input or the output of convert_fp8 must be torch.uint8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(), "Input and output of convert_fp8 must have different data types!");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
input.scalar_type() == at::ScalarType::Float ||
input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of input!");
TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
output.scalar_type() == at::ScalarType::Float ||
output.scalar_type() == at::ScalarType::Half ||
output.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of output!");
TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!");

#define _(InT, OutT) \
apply_convert_fp8<InT, OutT>(input, output)


if(input.scalar_type() == at::ScalarType::Byte)
{
if(output.scalar_type() == at::ScalarType::Float)
{
_(uint8_t, float);
}
else if(output.scalar_type() == at::ScalarType::Half)
{
_(uint8_t, half);
}
else if(output.scalar_type() == at::ScalarType::BFloat16)
{
_(uint8_t, __nv_bfloat16);
}
}
else
{
if(input.scalar_type() == at::ScalarType::Float)
{
_(float, uint8_t);
}
else if(input.scalar_type() == at::ScalarType::Half)
{
_(half, uint8_t);
}
else if(input.scalar_type() == at::ScalarType::BFloat16)
{
_(__nv_bfloat16, uint8_t);
}
}

#undef _
}
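As a quick sanity check of the vector/tail split performed by apply_convert_fp8 above, here is a minimal host-only sketch with made-up numbers (plain C++, independent of the kernel and of PyTorch):

#include <cassert>

int main() {
  // 1003 elements copied with a vector width of 4:
  const int numel = 1003, vec_size = 4;
  const int vec_numel = numel / vec_size;   // 250 whole vectors, handled by the grid-stride loop
  const int tail = numel & (vec_size - 1);  // 3 leftover scalars, cast one by one
  assert(vec_numel * vec_size + tail == numel);
  return 0;
}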
17 changes: 7 additions & 10 deletions extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -1,9 +1,6 @@

#pragma once

-#include <cuda_fp16.h>
-#include <stdint.h>

#include "common/vec_type_traits.h"
#include "funcs/cast_functor.h"

@@ -12,9 +9,9 @@ namespace cuda {
namespace utils {

// Note(LiuYang): Depreciated
-template <typename T, int vec_size>
+template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

@@ -34,17 +31,17 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
*(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
}

-template <typename SrcT, typename DstT, int vec_size>
+template <typename SrcT, typename DstT, int VecSize>
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
-  using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
-  using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
+  using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
+  using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
*(reinterpret_cast<const SrcVT *>(src)));
}

-template <typename T, int vec_size>
+template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

5 changes: 5 additions & 0 deletions extensions/pybind/inference/inference.cpp
@@ -75,6 +75,8 @@ void flash_decoding_attention(
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
const c10::optional<torch::Tensor>& alibi_slopes, float scale);

void convert_fp8(torch::Tensor& input, torch::Tensor& output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
@@ -102,4 +104,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("flash_decoding_attention", &flash_decoding_attention,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");

m.def("convert_fp8", &convert_fp8,
"Convert input to fp8 output or convert fp8 input to output.");
}
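For orientation, a minimal sketch of how the bound function can be driven from C++ (the shapes, dtypes, and tensor names are illustrative only, and it assumes a libtorch program linked against convert_fp8_kernel.cu; in practice the op is reached from Python through inference_ops.convert_fp8, as the test below shows):

#include <torch/torch.h>

// Same declaration as exposed above.
void convert_fp8(torch::Tensor& input, torch::Tensor& output);

int main() {
  // fp32 -> fp8: the uint8 tensor receives the raw fp8 byte pattern.
  auto src = torch::rand({8, 64}, torch::dtype(torch::kFloat).device(torch::kCUDA));
  auto fp8 = torch::empty_like(src, torch::dtype(torch::kUInt8));
  convert_fp8(src, fp8);

  // fp8 -> fp32: same entry point, the direction is inferred from the dtypes.
  auto dst = torch::empty_like(src);
  convert_fp8(fp8, dst);
  return 0;
}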
1 change: 1 addition & 0 deletions extensions/pybind/inference/inference_ops_cuda.py
@@ -17,6 +17,7 @@ def sources_files(self):
"kernel/cuda/rms_layernorm_kernel.cu",
"kernel/cuda/get_cos_and_sin_kernel.cu",
"kernel/cuda/flash_decoding_attention_kernel.cu",
"kernel/cuda/convert_fp8_kernel.cu",
]
] + [self.pybind_abs_path("inference/inference.cpp")]
return ret
57 changes: 57 additions & 0 deletions tests/test_infer/test_kernels/cuda/test_convert_fp8.py
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so its related test is skipped for now!")
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
@pytest.mark.parametrize("block_size", [8, 16, 32])
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = get_current_device()

low = -224.0
high = 224.0
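    # NOTE: +/-224 keeps the random values well inside the finite range of common fp8 formats
    # (e4m3 tops out at 448), so the fp8 round trip below should only lose precision, not range.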
shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high)

cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
inference_ops.convert_fp8(cache, cache_fp8)

converted_cache = torch.empty_like(cache)
inference_ops.convert_fp8(cache_fp8, converted_cache)

assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)