[DCU] new features #63721

Merged · 27 commits · Jun 13, 2024

Commits
a679c73
[DCU] fix bugs and support some fused ops
yuguo-Jack Apr 3, 2024
0631c13
[DCU] fix a small bug
yuguo-Jack Apr 3, 2024
e925661
Update fused_dropout_act_bias.h
yuguo-Jack Apr 7, 2024
1aff5a0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 7, 2024
55aea8f
update fused_dropout_act_bias.h
yuguo-Jack Apr 9, 2024
6acd01c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 9, 2024
42aa9bf
fix depthwise conv grad op bug
yuguo-Jack Apr 9, 2024
490a0d3
fix hip graph test bugs
yuguo-Jack Apr 10, 2024
81207d0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 10, 2024
03e9c0a
update
yuguo-Jack Apr 10, 2024
ea704f3
fix hip graph dropout bug
yuguo-Jack Apr 11, 2024
d2d143f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 11, 2024
3e66860
code style
yuguo-Jack Apr 11, 2024
597e660
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 21, 2024
6a42f01
[DCU] new features
yuguo-Jack Apr 21, 2024
b0079f0
[DCU] support miopen BN
yuguo-Jack Apr 21, 2024
43a401e
fix miopen bn bugs
yuguo-Jack Apr 22, 2024
3d36c6f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Apr 24, 2024
c75c13c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack May 14, 2024
5326497
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack May 21, 2024
ed36550
[DCU] high performance LLM train and inference for DCU
yuguo-Jack Jun 11, 2024
c74703c
merge develop and solve conflicts
yuguo-Jack Jun 11, 2024
bf0a55b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuguo-Jack Jun 11, 2024
cf1e743
fix conflict files format
yuguo-Jack Jun 11, 2024
ae2ff95
fix redefinition of FastGeluFunctor
yuguo-Jack Jun 11, 2024
15e6938
fix small bugs
yuguo-Jack Jun 11, 2024
0e35e8f
fix a problem
yuguo-Jack Jun 12, 2024
240 changes: 152 additions & 88 deletions cmake/external/flashattn.cmake
@@ -14,108 +14,172 @@

include(ExternalProject)

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
"${FLASHATTN_INSTALL_DIR}/lib"
CACHE PATH "flash-attn Library Directory" FORCE)

if(WIN32)
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
else()
if(WITH_ROCM)
set(FA_REPOSITORY https://github.com/yuguo-Jack/flash-attention-hip.git)
set(FA_TAG "xdl")
set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_hip)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn_hip)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
"${FLASHATTN_INSTALL_DIR}/lib"
CACHE PATH "flash-attn Library Directory" FORCE)
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(WIN32)
set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(FLASHATTN_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -w -Wno-deprecated-builtins -Wno-deprecated -DNDEBUG -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -fPIC -O3 -std=c++17 -D__HIP_PLATFORM_HCC__=1 --offload-arch=gfx928 -D__gfx940__"
)
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

set(FA_NVCC_ARCH_BIN "")
foreach(arch ${NVCC_ARCH_BIN})
string(STRIP ${arch} arch)
if(arch STREQUAL "")
continue()
ExternalProject_Add(
extern_flashattn
GIT_REPOSITORY ${FA_REPOSITORY}
GIT_TAG ${FA_TAG}
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${FLASHATTN_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc
-DAMDGPU_TARGETS=gfx928
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
else()

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/flashattn)
set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
"${FLASHATTN_INSTALL_DIR}/lib"
CACHE PATH "flash-attn Library Directory" FORCE)

if(WIN32)
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
else()
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(FA_NVCC_ARCH_BIN STREQUAL "")
set(FA_NVCC_ARCH_BIN "${arch}")
if(WIN32)
set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FA_NVCC_ARCH_BIN "${FA_NVCC_ARCH_BIN}-${arch}")
set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
endforeach()

ExternalProject_Add(
extern_flashattn
${EXTERNAL_PROJECT_LOG_ARGS}
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${FLASHATTN_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
-DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})

set(FA_NVCC_ARCH_BIN "")
foreach(arch ${NVCC_ARCH_BIN})
string(STRIP ${arch} arch)
if(arch STREQUAL "")
continue()
endif()

if(FA_NVCC_ARCH_BIN STREQUAL "")
set(FA_NVCC_ARCH_BIN "${arch}")
else()
set(FA_NVCC_ARCH_BIN "${FA_NVCC_ARCH_BIN}-${arch}")
endif()
endforeach()

ExternalProject_Add(
extern_flashattn
${EXTERNAL_PROJECT_LOG_ARGS}
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${FLASHATTN_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
-DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=4
-DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})

endif()

message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
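Taken together, this file now has two paths: on WITH_ROCM builds it fetches the flash-attention-hip fork (tag xdl), compiles it with hipcc for gfx928, and installs libflashattn.so, while the non-ROCm branch keeps the original flashattn submodule and NVCC_ARCH_BIN handling. Both paths define PADDLE_WITH_FLASHATTN. A minimal sketch, assuming only those macro names, of how dependent code can tell which backend it was built against (illustrative only, not code from this PR):

// Illustrative sketch: branches on the defines set up by this CMake file.
// PADDLE_WITH_FLASHATTN is added above; PADDLE_WITH_HIP is Paddle's ROCm define.
#include <cstdio>

const char* FlashAttnBackendName() {
#if defined(PADDLE_WITH_FLASHATTN) && defined(PADDLE_WITH_HIP)
  return "flash-attention-hip (DCU/ROCm, gfx928)";
#elif defined(PADDLE_WITH_FLASHATTN)
  return "flash-attention (CUDA)";
#else
  return "flash-attention disabled";
#endif
}

int main() { std::printf("%s\n", FlashAttnBackendName()); }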
2 changes: 2 additions & 0 deletions cmake/hip.cmake
@@ -152,9 +152,11 @@ set(HIP_CLANG_FLAGS ${HIP_CXX_FLAGS})
list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc)
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx906) # Z100 (ZIFANG)
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx926) # K100 (KONGING)
list(APPEND HIP_HCC_FLAGS --offload-arch=gfx928) # K100_AI (KONGING_AI)
list(APPEND HIP_CLANG_FLAGS -fno-gpu-rdc)
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx906) # Z100 (ZIFANG)
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx926) # K100 (KONGING)
list(APPEND HIP_CLANG_FLAGS --offload-arch=gfx928) # K100_AI (KONGING_AI)

if(HIP_COMPILER STREQUAL clang)
set(hip_library_name amdhip64)
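The added --offload-arch=gfx928 entries extend the fat binaries to K100_AI-class DCUs in addition to gfx906 and gfx926. A minimal HIP sketch, assuming only the standard hipGetDeviceProperties API, for checking at runtime whether a device's architecture matches one of those offload targets (not part of this PR):

// Prints each device's gfx architecture string so it can be compared against
// the --offload-arch list above (gfx906 / gfx926 / gfx928).
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  int count = 0;
  if (hipGetDeviceCount(&count) != hipSuccess) return 1;
  for (int i = 0; i < count; ++i) {
    hipDeviceProp_t prop;
    if (hipGetDeviceProperties(&prop, i) != hipSuccess) continue;
    // gcnArchName typically looks like "gfx928:sramecc+:xnack-".
    std::printf("device %d: %s (%s)\n", i, prop.name, prop.gcnArchName);
  }
  return 0;
}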
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
@@ -561,6 +561,12 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt)
endif()

if(WITH_ROCM)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
endif()

if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
4 changes: 3 additions & 1 deletion paddle/fluid/operators/matmul_op.cc
@@ -54,7 +54,8 @@ static phi::DDim ColumnMatrixFromVector(const phi::DDim &y_dim) {
return common::make_ddim({y_dim[0], 1});
}

#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) || \
defined(PADDLE_WITH_HIP)
template <typename T, typename DeviceContext>
typename std::enable_if<std::is_integral<T>::value, void>::type
ComputeMatmulImpl(const framework::ExecutionContext &context) {
@@ -959,6 +960,7 @@ REGISTER_OP_CPU_KERNEL(matmul_grad_grad,
#if defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL(
matmul,
ops::MatMulKernel<phi::GPUContext, int8_t>,
ops::MatMulKernel<phi::GPUContext, float>,
ops::MatMulKernel<phi::GPUContext, double>,
ops::MatMulKernel<phi::GPUContext, phi::dtype::float16>);
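Widening the guard to include PADDLE_WITH_HIP and registering MatMulKernel<phi::GPUContext, int8_t> makes integral matmul available on DCU. As a reference for the arithmetic involved, a naive CPU sketch of an int8 matmul with int32 accumulation follows; it only illustrates the numerics, not the Paddle kernel:

#include <cstdint>
#include <cstdio>
#include <vector>

// Naive reference: C[m][n] = sum_k A[m][k] * B[k][n]; int8 inputs are
// accumulated in int32 to avoid overflow. Row-major layout, no blocking.
std::vector<int32_t> MatMulInt8(const std::vector<int8_t>& a,
                                const std::vector<int8_t>& b,
                                int m, int k, int n) {
  std::vector<int32_t> c(static_cast<size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p) {
      const int32_t av = a[i * k + p];
      for (int j = 0; j < n; ++j)
        c[i * n + j] += av * static_cast<int32_t>(b[p * n + j]);
    }
  return c;
}

int main() {
  // [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
  auto c = MatMulInt8({1, 2, 3, 4}, {5, 6, 7, 8}, 2, 2, 2);
  std::printf("%d %d %d %d\n", c[0], c[1], c[2], c[3]);
}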
9 changes: 9 additions & 0 deletions paddle/phi/backends/dynload/flashattn.h
@@ -43,6 +43,14 @@ extern void* flashattn_dso_handle;
#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_WRAP(__name)

#ifdef PADDLE_WITH_HIP
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_varlen_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_varlen_bwd); \
__macro(flash_attn_error);
#else
#define FLASHATTN_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_varlen_fwd); \
@@ -51,6 +59,7 @@ extern void* flashattn_dso_handle;
__macro(flash_attn_fwd_with_bias_and_mask); \
__macro(flash_attn_bwd_with_bias_and_mask); \
__macro(flash_attn_error);
#endif

FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);

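On the HIP path the routine list shrinks to the five symbols exported by flash-attention-hip; DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP resolves each of them from libflashattn.so at runtime rather than linking against it. A minimal sketch of that dlopen/dlsym pattern, using only the library and symbol names visible in this diff (the real function signatures are deliberately not assumed, so the symbol stays an opaque pointer):

// Illustrative POSIX dynamic-loading sketch, not Paddle's wrapper macro.
#include <dlfcn.h>
#include <cstdio>

int main() {
  void* handle = dlopen("libflashattn.so", RTLD_LAZY | RTLD_LOCAL);
  if (!handle) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  void* sym = dlsym(handle, "flash_attn_fwd");  // one of the routines listed above
  std::printf("flash_attn_fwd %s\n", sym ? "resolved" : "not found");
  dlclose(handle);
  return 0;
}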
3 changes: 3 additions & 0 deletions paddle/phi/common/float16.h
@@ -1014,13 +1014,16 @@ struct is_pod<phi::dtype::float16> {
is_standard_layout<phi::dtype::float16>::value;
};

#if !(defined(PADDLE_WITH_CUSTOM_KERNEL) && defined(PADDLE_WITH_HIP))
template <>
struct is_floating_point<phi::dtype::float16>
: std::integral_constant<
bool,
std::is_same<
phi::dtype::float16,
typename std::remove_cv<phi::dtype::float16>::type>::value> {};
#endif

template <>
struct is_signed<phi::dtype::float16> {
static const bool value = true;
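The new guard drops the is_floating_point specialization from custom-kernel HIP builds while keeping it elsewhere, where it lets generic code dispatch phi::dtype::float16 down floating-point paths. A small sketch of the dispatch effect such a specialization has; MyHalf and everything around it are hypothetical and only mirror the pattern in this header:

#include <cstdio>
#include <type_traits>

struct MyHalf { unsigned short bits; };  // hypothetical fp16-like storage type

// Mirrors float16.h: specialize the trait so is_floating_point-based dispatch
// treats the half type as a floating-point type. Shown purely for illustration.
namespace std {
template <>
struct is_floating_point<MyHalf> : std::true_type {};
}  // namespace std

template <typename T>
const char* Classify() {
  return std::is_floating_point<T>::value ? "floating point" : "not floating point";
}

int main() {
  std::printf("MyHalf: %s\n", Classify<MyHalf>());  // floating point
  std::printf("int:    %s\n", Classify<int>());     // not floating point
}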
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -407,6 +407,13 @@ void FlashAttnInferMeta(const MetaTensor& q,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
#ifdef PADDLE_WITH_HIP
auto out_dims = q.dims();
out_dims[3] = v.dims()[3];
out->set_dims(out_dims);
out->set_dtype(q.dtype());
out->set_layout(q.layout());
#else
auto out_dims = q.dims();
PADDLE_ENFORCE_EQ(out_dims.size(),
4,
@@ -435,6 +442,7 @@ void FlashAttnInferMeta(const MetaTensor& q,
seed_offset->set_dtype(phi::DataType::INT64);
seed_offset->set_dims({2});
}
#endif
}
void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
MetaTensor* out,
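On the HIP path FlashAttnInferMeta only sets the output metadata: q's dims with the last entry replaced by v's head dimension, plus q's dtype and layout; the stricter rank and shape checks remain on the CUDA path. A hedged sketch of that shape rule in plain C++ (hypothetical helper, not the InferMeta code; dims assumed to follow [batch, seqlen, num_heads, head_dim]):

#include <array>
#include <cstdint>
#include <cstdio>

// out dims = q dims with the trailing head_dim replaced by v's head_dim,
// matching the HIP branch above (out_dims[3] = v.dims()[3]).
std::array<int64_t, 4> FlashAttnOutDims(const std::array<int64_t, 4>& q_dims,
                                        const std::array<int64_t, 4>& v_dims) {
  std::array<int64_t, 4> out = q_dims;
  out[3] = v_dims[3];
  return out;
}

int main() {
  const auto out = FlashAttnOutDims({2, 1024, 16, 128}, {2, 1024, 16, 64});
  std::printf("[%lld, %lld, %lld, %lld]\n", (long long)out[0], (long long)out[1],
              (long long)out[2], (long long)out[3]);
}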
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/unary.cc
@@ -5760,12 +5760,13 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
const int32_t group_size,
MetaTensor* out,
MetaTensor* scale) {
#ifndef PADDLE_WITH_HIP
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 86) || (arch == 70) || (arch == 75)),
true,
phi::errors::InvalidArgument(
"Currently, arch only support 70, 75, 80, 86."));

#endif
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2 changes: 0 additions & 2 deletions paddle/phi/kernels/CMakeLists.txt
@@ -221,8 +221,6 @@ if(WITH_ROCM)
REMOVE_ITEM
kernel_cu
"gpudnn/mha_cudnn_frontend.cu"
"fusion/gpu/blha_get_max_len.cu"
"fusion/gpu/block_multi_head_attention_kernel.cu"
"fusion/gpu/fused_bn_add_activation_grad_kernel.cu"
"fusion/gpu/fused_bn_add_activation_kernel.cu"
"fusion/gpu/fusion_transpose_flatten_concat_kernel.cu")