[DCU] New features for LLM (PaddlePaddle#65398)
yuguo-Jack authored Jul 3, 2024
1 parent d2eb934 commit f561b06
Showing 28 changed files with 1,394 additions and 306 deletions.
242 changes: 154 additions & 88 deletions cmake/external/flashattn.cmake
@@ -14,108 +14,174 @@

 include(ExternalProject)
-
-add_definitions(-DPADDLE_WITH_FLASHATTN)
-
-set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
-set(FLASHATTN_SOURCE_SUBDIR csrc)
-set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
-set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
-set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)
-
-set(FLASHATTN_INCLUDE_DIR
-    "${FLASHATTN_INSTALL_DIR}/include"
-    CACHE PATH "flash-attn Directory" FORCE)
-set(FLASHATTN_LIB_DIR
-    "${FLASHATTN_INSTALL_DIR}/lib"
-    CACHE PATH "flash-attn Library Directory" FORCE)
-
-if(WIN32)
-  set(FLASHATTN_LIBRARIES
-      "${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
-      CACHE FILEPATH "flash-attn Library" FORCE)
-else()
-  set(FLASHATTN_LIBRARIES
-      "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
-      CACHE FILEPATH "flash-attn Library" FORCE)
-endif()
-
-if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
-   OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
-   OR WIN32)
-  set(USE_OMP OFF)
-else()
-  set(USE_OMP ON)
-endif()
-
-if(WIN32)
-  set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
-  set(FLASHATTN_C_FLAGS_DEBUG
-      $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
-  set(FLASHATTN_C_FLAGS_RELEASE
-      $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
-  set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
-  set(FLASHATTN_CXX_FLAGS_RELEASE
-      $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
-  set(FLASHATTN_CXX_FLAGS_DEBUG
-      $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
-else()
-  set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
-  set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
-  set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
-  set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
-  set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
-endif()
-
-set(FA_NVCC_ARCH_BIN "")
-foreach(arch ${NVCC_ARCH_BIN})
-  string(STRIP ${arch} arch)
-  if(arch STREQUAL "")
-    continue()
-  endif()
-
-  if(FA_NVCC_ARCH_BIN STREQUAL "")
-    set(FA_NVCC_ARCH_BIN "${arch}")
-  else()
-    set(FA_NVCC_ARCH_BIN "${FA_NVCC_ARCH_BIN}-${arch}")
-  endif()
-endforeach()
-
-ExternalProject_Add(
-  extern_flashattn
-  ${EXTERNAL_PROJECT_LOG_ARGS}
-  SOURCE_DIR ${SOURCE_DIR}
-  PREFIX ${FLASHATTN_PREFIX_DIR}
-  SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
-  UPDATE_COMMAND ""
-  PATCH_COMMAND ""
-  #BUILD_ALWAYS 1
-  CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-             -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-             -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-             -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-             -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-             -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-             -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-             -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-             -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER}
-             -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-             -DWITH_GPU=${WITH_GPU}
-             -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-             -DWITH_ROCM=${WITH_ROCM}
-             -DWITH_OMP=${USE_OMP}
-             -DBUILD_SHARED=ON
-             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-             -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-             -DCMAKE_JOB_POOL_COMPILE:STRING=compile
-             -DCMAKE_JOB_POOLS:STRING=compile=4
-             -DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
-             ${EXTERNAL_OPTIONAL_ARGS}
-  CMAKE_CACHE_ARGS
-    -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-    -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
-  BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
+
+if(WITH_ROCM)
+  add_definitions(-DPADDLE_WITH_FLASHATTN)
+
+  set(FA_REPOSITORY https://github.com/PaddlePaddle/flash-attention.git)
+  set(FA_TAG "dcu")
+  set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_hip)
+  set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
+  set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn_hip)
+
+  set(FLASHATTN_INCLUDE_DIR
+      "${FLASHATTN_INSTALL_DIR}/include"
+      CACHE PATH "flash-attn Directory" FORCE)
+  set(FLASHATTN_LIB_DIR
+      "${FLASHATTN_INSTALL_DIR}/lib"
+      CACHE PATH "flash-attn Library Directory" FORCE)
+  set(FLASHATTN_LIBRARIES
+      "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
+      CACHE FILEPATH "flash-attn Library" FORCE)
+
+  set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
+  set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
+  set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
+  set(FLASHATTN_CXX_FLAGS
+      "${CMAKE_CXX_FLAGS} -w -Wno-deprecated-builtins -Wno-deprecated -DNDEBUG -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -fPIC -O3 -std=c++17 -D__HIP_PLATFORM_HCC__=1 --offload-arch=gfx928 -D__gfx940__ -mllvm -enable-num-vgprs-512=true"
+  )
+  set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
+  set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
+
+  ExternalProject_Add(
+    extern_flashattn
+    GIT_REPOSITORY ${FA_REPOSITORY}
+    GIT_TAG ${FA_TAG}
+    SOURCE_DIR ${SOURCE_DIR}
+    PREFIX ${FLASHATTN_PREFIX_DIR}
+    UPDATE_COMMAND ""
+    PATCH_COMMAND ""
+    #BUILD_ALWAYS 1
+    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc
+               -DAMDGPU_TARGETS=gfx928
+               -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+               -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
+               -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
+               -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
+               -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
+               -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
+               -DWITH_GPU=${WITH_GPU}
+               -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
+               -DWITH_ROCM=${WITH_ROCM}
+               -DWITH_OMP=${USE_OMP}
+               -DBUILD_SHARED=ON
+               -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+               -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+               -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+               -DCMAKE_JOB_POOLS:STRING=compile=4
+               ${EXTERNAL_OPTIONAL_ARGS}
+    CMAKE_CACHE_ARGS
+      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
+      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
+      -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
+    BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
+else()
+  add_definitions(-DPADDLE_WITH_FLASHATTN)
+
+  set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
+  set(FLASHATTN_SOURCE_SUBDIR csrc)
+  set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
+  set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
+  set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)
+
+  set(FLASHATTN_INCLUDE_DIR
+      "${FLASHATTN_INSTALL_DIR}/include"
+      CACHE PATH "flash-attn Directory" FORCE)
+  set(FLASHATTN_LIB_DIR
+      "${FLASHATTN_INSTALL_DIR}/lib"
+      CACHE PATH "flash-attn Library Directory" FORCE)
+
+  if(WIN32)
+    set(FLASHATTN_LIBRARIES
+        "${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
+        CACHE FILEPATH "flash-attn Library" FORCE)
+  else()
+    set(FLASHATTN_LIBRARIES
+        "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
+        CACHE FILEPATH "flash-attn Library" FORCE)
+  endif()
+
+  if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+     OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
+     OR WIN32)
+    set(USE_OMP OFF)
+  else()
+    set(USE_OMP ON)
+  endif()
+
+  if(WIN32)
+    set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
+    set(FLASHATTN_C_FLAGS_DEBUG
+        $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
+    set(FLASHATTN_C_FLAGS_RELEASE
+        $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
+    set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
+    set(FLASHATTN_CXX_FLAGS_RELEASE
+        $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
+    set(FLASHATTN_CXX_FLAGS_DEBUG
+        $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
+  else()
+    set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
+    set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
+    set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
+    set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
+    set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
+    set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
+  endif()
+
+  set(FA_NVCC_ARCH_BIN "")
+  foreach(arch ${NVCC_ARCH_BIN})
+    string(STRIP ${arch} arch)
+    if(arch STREQUAL "")
+      continue()
+    endif()
+
+    if(FA_NVCC_ARCH_BIN STREQUAL "")
+      set(FA_NVCC_ARCH_BIN "${arch}")
+    else()
+      set(FA_NVCC_ARCH_BIN "${FA_NVCC_ARCH_BIN}-${arch}")
+    endif()
+  endforeach()
+
+  ExternalProject_Add(
+    extern_flashattn
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    SOURCE_DIR ${SOURCE_DIR}
+    PREFIX ${FLASHATTN_PREFIX_DIR}
+    SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
+    UPDATE_COMMAND ""
+    PATCH_COMMAND ""
+    #BUILD_ALWAYS 1
+    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+               -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+               -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
+               -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
+               -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
+               -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
+               -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
+               -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER}
+               -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
+               -DWITH_GPU=${WITH_GPU}
+               -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
+               -DWITH_ROCM=${WITH_ROCM}
+               -DWITH_OMP=${USE_OMP}
+               -DBUILD_SHARED=ON
+               -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+               -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+               -DCMAKE_JOB_POOL_COMPILE:STRING=compile
+               -DCMAKE_JOB_POOLS:STRING=compile=4
+               -DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
+               ${EXTERNAL_OPTIONAL_ARGS}
+    CMAKE_CACHE_ARGS
+      -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
+      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
+      -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
+    BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
+endif()

 message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
 get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
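The rule above produces libflashattn.so, which Paddle loads at runtime through phi::dynload rather than linking against it directly. A minimal standalone sketch of that loading step — plain dlopen/dlsym on a POSIX system; this is not part of the commit, and the error handling is illustrative:

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Open the library built by extern_flashattn; RTLD_LAZY defers symbol
  // resolution until first use, matching typical plugin-style loading.
  void* handle = dlopen("libflashattn.so", RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Look up one of the entry points wrapped by paddle/phi/backends/dynload.
  void* fn = dlsym(handle, "flash_attn_fwd");
  std::printf("flash_attn_fwd: %s\n", fn != nullptr ? "found" : "missing");
  dlclose(handle);
  return 0;
}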
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
@@ -561,6 +561,12 @@ if(WITH_CUSPARSELT)
   list(APPEND third_party_deps extern_cusparselt)
 endif()
 
+if(WITH_ROCM)
+  include(external/flashattn)
+  list(APPEND third_party_deps extern_flashattn)
+  set(WITH_FLASHATTN ON)
+endif()
+
 if(WITH_GPU
    AND NOT WITH_ARM
    AND NOT WIN32
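With WITH_ROCM now pulling in extern_flashattn and defining PADDLE_WITH_FLASHATTN, C++ code can take the flash-attention path on both CUDA and DCU builds. A hedged sketch of the usual compile-time gate; the helper name is hypothetical, not Paddle API:

// Hypothetical helper; the macro comes from
// add_definitions(-DPADDLE_WITH_FLASHATTN) in cmake/external/flashattn.cmake.
inline bool FlashAttnCompiledIn() {
#ifdef PADDLE_WITH_FLASHATTN
  return true;   // flash-attention kernels were built (CUDA, or ROCm/DCU here)
#else
  return false;  // callers should fall back to a non-fused attention path
#endif
}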
15 changes: 15 additions & 0 deletions paddle/common/flags.cc
@@ -281,6 +281,21 @@ PHI_DEFINE_EXPORTED_int64(cudnn_exhaustive_search_times,
                           "Exhaustive search times for cuDNN convolution, "
                           "default is -1, not exhaustive search");
 
+#ifdef PADDLE_WITH_HIP
+/**
+ * MIOPEN related FLAG
+ * Name: FLAGS_batch_norm_use_miopen
+ * Since Version:
+ * Value Range:
+ * Example:
+ * Note: Use MIOpen batch norm instead of native
+ */
+PHI_DEFINE_EXPORTED_bool(batch_norm_use_miopen,
+                         false,
+                         "Whether use MIOpen batch norm or not, "
+                         "default is false, not use miopen bn");
+#endif
+
 /**
  * CUDNN related FLAG
  * Name: FLAGS_cudnn_batchnorm_spatial_persistent
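A flag defined with PHI_DEFINE_EXPORTED_bool is normally consumed by declaring it in the using translation unit and branching on FLAGS_<name>; it can typically also be toggled through the environment (e.g. export FLAGS_batch_norm_use_miopen=true). A sketch under those assumptions — the declare-macro spelling and the dispatch function are illustrative, not code from this commit:

#ifdef PADDLE_WITH_HIP
PHI_DECLARE_bool(batch_norm_use_miopen);  // mirrors the definition above

void DispatchBatchNorm() {
  if (FLAGS_batch_norm_use_miopen) {
    // take the MIOpen batch-norm path
  } else {
    // take the native HIP kernel path
  }
}
#endif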
4 changes: 3 additions & 1 deletion paddle/fluid/operators/matmul_op.cc
@@ -54,7 +54,8 @@ static phi::DDim ColumnMatrixFromVector(const phi::DDim &y_dim) {
   return common::make_ddim({y_dim[0], 1});
 }
 
-#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
+#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \
+    defined(PADDLE_WITH_HIP)
 template <typename T, typename DeviceContext>
 typename std::enable_if<std::is_integral<T>::value, void>::type
 ComputeMatmulImpl(const framework::ExecutionContext &context) {
@@ -959,6 +960,7 @@ REGISTER_OP_CPU_KERNEL(matmul_grad_grad,
 #if defined(PADDLE_WITH_HIP)
 REGISTER_OP_CUDA_KERNEL(
     matmul,
+    ops::MatMulKernel<phi::GPUContext, int8_t>,
     ops::MatMulKernel<phi::GPUContext, float>,
     ops::MatMulKernel<phi::GPUContext, double>,
     ops::MatMulKernel<phi::GPUContext, phi::dtype::float16>);
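The enable_if overload above is what routes the newly registered int8_t kernel through the integral-type implementation. A self-contained illustration of that dispatch pattern, with simplified names and no Paddle dependencies:

#include <cstdint>
#include <iostream>
#include <type_traits>

// One overload is selected for integral element types (e.g. int8_t), the
// other for floating-point types; exactly one enable_if substitution survives.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type ComputeImpl() {
  std::cout << "integral path (e.g. int8 matmul)\n";
}

template <typename T>
typename std::enable_if<!std::is_integral<T>::value, void>::type ComputeImpl() {
  std::cout << "floating-point path\n";
}

int main() {
  ComputeImpl<int8_t>();  // integral path
  ComputeImpl<float>();   // floating-point path
  return 0;
}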
9 changes: 9 additions & 0 deletions paddle/phi/backends/dynload/flashattn.h
@@ -43,6 +43,14 @@ extern void* flashattn_dso_handle;
 #define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
   DYNAMIC_LOAD_FLASHATTN_WRAP(__name)
 
+#ifdef PADDLE_WITH_HIP
+#define FLASHATTN_ROUTINE_EACH(__macro) \
+  __macro(flash_attn_fwd);              \
+  __macro(flash_attn_varlen_fwd);       \
+  __macro(flash_attn_bwd);              \
+  __macro(flash_attn_varlen_bwd);       \
+  __macro(flash_attn_error);
+#else
 #define FLASHATTN_ROUTINE_EACH(__macro) \
   __macro(flash_attn_fwd);              \
   __macro(flash_attn_varlen_fwd);       \
@@ -51,6 +59,7 @@ extern void* flashattn_dso_handle;
   __macro(flash_attn_fwd_with_bias_and_mask);   \
   __macro(flash_attn_bwd_with_bias_and_mask);   \
   __macro(flash_attn_error);
+#endif
 
 FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);

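FLASHATTN_ROUTINE_EACH is an X-macro: it applies a caller-supplied macro to every routine name, so the HIP-specific list above generates exactly the wrappers that the ROCm build of libflashattn exports. A simplified, self-contained analogue using stub definitions instead of dynamic-load wrappers:

#include <iostream>

#define ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);    \
  __macro(flash_attn_bwd);

// Stand-in for DYNAMIC_LOAD_FLASHATTN_WRAP: here it just defines a stub.
#define DEFINE_STUB(__name) \
  void __name() { std::cout << #__name << " called\n"; }

ROUTINE_EACH(DEFINE_STUB)  // expands to definitions of both functions

int main() {
  flash_attn_fwd();
  flash_attn_bwd();
  return 0;
}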
3 changes: 3 additions & 0 deletions paddle/phi/common/float16.h
@@ -1014,13 +1014,16 @@ struct is_pod<phi::dtype::float16> {
       is_standard_layout<phi::dtype::float16>::value;
 };
 
+#if !(defined(PADDLE_WITH_CUSTOM_KERNEL) && defined(PADDLE_WITH_HIP))
 template <>
 struct is_floating_point<phi::dtype::float16>
     : std::integral_constant<
           bool,
           std::is_same<
               phi::dtype::float16,
               typename std::remove_cv<phi::dtype::float16>::type>::value> {};
+#endif
+
 template <>
 struct is_signed<phi::dtype::float16> {
   static const bool value = true;
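The guarded specialization makes generic code treat phi::dtype::float16 as a floating-point type. A self-contained sketch of the same technique with a stand-in type — note that the standard formally forbids specializing std type traits for user-defined types, which is one reason a particular build configuration may need to opt out as above:

#include <type_traits>

struct my_float16 {  // stand-in for phi::dtype::float16
  unsigned short x;
};

namespace std {
template <>
struct is_floating_point<my_float16> : std::true_type {};
}  // namespace std

// Generic code keying on the trait now accepts the custom half type.
static_assert(std::is_floating_point<my_float16>::value,
              "my_float16 is treated as floating point");

int main() { return 0; }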
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -5828,12 +5828,14 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
                              const int32_t group_size,
                              MetaTensor* out,
                              MetaTensor* scale) {
+#ifndef PADDLE_WITH_HIP
   PADDLE_ENFORCE_EQ(
       ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
        (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
           "Currently, arch only support 70, 75, 80, 86, 89, 90."));
+#endif
 
   auto x_dims = x.dims();
   PADDLE_ENFORCE_EQ(
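A minimal sketch of the resulting behavior: the SM-architecture allowlist from the ENFORCE is applied only on CUDA builds, while HIP/DCU builds accept any arch value. The helper function here is hypothetical; the values mirror the check above:

#include <cstdio>

bool ArchSupportedForWeightQuantize(int arch) {
#ifndef PADDLE_WITH_HIP
  // CUDA path: same allowlist as the PADDLE_ENFORCE_EQ above.
  return arch == 70 || arch == 75 || arch == 80 || arch == 86 ||
         arch == 89 || arch == 90;
#else
  (void)arch;
  return true;  // HIP/DCU path: the check is skipped entirely
#endif
}

int main() {
  std::printf("arch 80 supported: %d\n", ArchSupportedForWeightQuantize(80));
  return 0;
}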
(Diffs for the remaining 21 of the 28 changed files are not shown.)
