17 changes: 17 additions & 0 deletions extensions/csrc/common/mp_type_traits.h
@@ -4,6 +4,11 @@

#include "micros.h"

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif

namespace colossalAI {
namespace common {

@@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};

#if defined(COLOSSAL_WITH_CUDA)
template <>
struct MPTypeTrait<half> {
using Type = float;
};

template <>
struct MPTypeTrait<__nv_bfloat16> {
using Type = float;
};
#endif

template <bool high_precision, typename T>
struct ScalarTypeTrait {
using Type =
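For context, a minimal sketch of what the new specializations guarantee (the static_assert checks and the include path below are illustrative only, not part of this PR): MPTypeTrait maps a low-precision storage type to the type used for intermediate math, and the added specializations give the CUDA-native half and __nv_bfloat16 types the same float mapping that at::Half and at::BFloat16 already have.

#include <type_traits>
// Illustrative only -- the include path is an assumption about the build setup.
#include "common/mp_type_traits.h"

using colossalAI::common::MPTypeTrait;

// All low-precision storage types promote to float for intermediate math.
static_assert(std::is_same<MPTypeTrait<at::BFloat16>::Type, float>::value, "bf16 math type");
#if defined(COLOSSAL_WITH_CUDA)
static_assert(std::is_same<MPTypeTrait<half>::Type, float>::value, "half math type");
static_assert(std::is_same<MPTypeTrait<__nv_bfloat16>::Type, float>::value, "nv bf16 math type");
#endif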
19 changes: 19 additions & 0 deletions extensions/csrc/funcs/binary_functor.h
@@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
typename T)

#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs);
@@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
DEVICE, STMTS_WRAPPER({
return __hadd(lhs, rhs);
}))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
__nv_bfloat16, BinaryOpType::kMinus,
DEVICE, STMTS_WRAPPER({
return __hsub(lhs, rhs);
}))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
__nv_bfloat162, BinaryOpType::kAdd,
DEVICE, STMTS_WRAPPER({
@@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
}))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
}))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
STMTS_WRAPPER({
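Reduced to a standalone sketch, the new kMinus specializations boil down to the device code below. The wrapper names are mine, and treating the float round-trip variant as the fallback for architectures without native bf16 arithmetic is an assumption based on the matching kAdd pair in this file.

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Illustrative device helpers, not part of the PR.
__device__ half sub_half(half lhs, half rhs) { return __hsub(lhs, rhs); }

// Native bf16 path (mirrors the first __nv_bfloat16 kMinus specialization;
// requires a GPU architecture with native bf16 arithmetic).
__device__ __nv_bfloat16 sub_bf16(__nv_bfloat16 lhs, __nv_bfloat16 rhs) {
  return __hsub(lhs, rhs);
}

// Float round-trip path (mirrors the second specialization; presumably the
// fallback when native bf16 math is unavailable).
__device__ __nv_bfloat16 sub_bf16_via_float(__nv_bfloat16 lhs, __nv_bfloat16 rhs) {
  return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
}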
4 changes: 4 additions & 0 deletions extensions/csrc/funcs/cast_functor.h
@@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
STMTS_WRAPPER({
return __float2bfloat16_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
STMTS_WRAPPER({
return __bfloat162float(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;
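A minimal sketch of what the added specialization provides (the helper names are mine, not part of the PR): alongside the existing float -> __nv_bfloat16 cast, generic code can now go back from __nv_bfloat16 to float through the same functor mechanism, which on the device reduces to a single intrinsic.

#include <cuda_bf16.h>

// Illustrative device helpers, not part of the PR: the two cast
// specializations reduce to this pair of intrinsics.
__device__ __nv_bfloat16 float_to_bf16(float val) { return __float2bfloat16_rn(val); }
__device__ float bf16_to_float(__nv_bfloat16 val) { return __bfloat162float(val); }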
@@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
int max_seq_len_in_batch)
{

TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
"Dtype of key should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");


#define _(T, CacheT) \
apply_context_kv_cache_memcpy<T, CacheT>( \
key, \
@@ -380,12 +380,6 @@ void flash_decoding_attention(
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {


TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
"Dtype of query should be float, half or bfloat16!");
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
"Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");

if(key_cache.scalar_type() == at::ScalarType::Byte)
{
switch (query.scalar_type()) {
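Both kernel-launcher hunks above drop the standalone TORCH_CHECK dtype guards. A hedged sketch of why such guards are commonly redundant (this is my reading, not a rationale stated in the PR): the dtype switch that drives kernel dispatch can already reject unsupported dtypes in its default branch, so the separate checks duplicate that error handling.

// Illustrative only -- not the PR's actual dispatch code.
switch (query.scalar_type()) {
  case at::ScalarType::Float:
    // launch the float kernel
    break;
  case at::ScalarType::Half:
    // launch the half kernel
    break;
  case at::ScalarType::BFloat16:
    // launch the bfloat16 kernel
    break;
  default:
    TORCH_CHECK(false, "Unsupported dtype: ", query.scalar_type());
}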