Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
fd28881
[Metax_change_ut]
duqimeng Jul 23, 2025
a9d2aa7
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 24, 2025
1695f36
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 31, 2025
b931d38
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 1, 2025
bef21bf
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 8, 2025
f4e5004
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 13, 2025
55422eb
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 18, 2025
815a63a
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 19, 2025
1739a15
fix sum&collect_fpn_proposals op register
StareAtYou Aug 19, 2025
af0bae5
fix sum&collect_fpn_proposals op register
metax666 Aug 19, 2025
be61f06
modify profile
jxwangmetax Aug 20, 2025
0fc2dd1
modify profile
metax666 Aug 20, 2025
1ad95c5
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 20, 2025
f12b3e4
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 21, 2025
789c9fc
[Metax] fix paddle bug replace 'MoeGradDispatchKernel' to 'MoeGateDis…
StareAtYou Aug 21, 2025
a0116fb
[Metax] fix paddle bug
metax666 Aug 21, 2025
a2da5e0
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 22, 2025
f9e6d2c
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
StareAtYou Aug 22, 2025
4b4f562
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Aug 22, 2025
662e22e
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
3e8d6ce
Merge branch 'metax666:develop' into develop
StareAtYou Aug 25, 2025
9dae9b7
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
metax666 Aug 25, 2025
47fef62
blas handle support
jxwangmetax Aug 25, 2025
266c0df
blas handle support
metax666 Aug 25, 2025
a0b340b
[Metax] register some kernels & update CMakeLists
StareAtYou Aug 25, 2025
aa9bd35
Merge branch 'metax666:develop' into develop
StareAtYou Aug 26, 2025
8c6ac05
[Metax] register some kernels & update CMakeLists
metax666 Aug 26, 2025
9510f7d
Merge branch 'metax666:develop' into develop
duqimeng Aug 26, 2025
fa7cc1a
[Metax] fix metax unittest fail
StareAtYou Aug 26, 2025
a907545
[Metax] fix metax unittest fail
metax666 Aug 26, 2025
7a6312e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
StareAtYou Aug 26, 2025
90bb94e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
metax666 Aug 27, 2025
9f130fe
[Metax] fix rmsprop kernel register and add meshgrid & meshgrid_grad …
StareAtYou Aug 27, 2025
ca38fb5
Merge branch 'metax666:develop' into develop
StareAtYou Aug 27, 2025
f0cc1e0
add test
zhang-chenyi Aug 27, 2025
8e8b732
add test
zhang-chenyi Aug 27, 2025
8d7efbd
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 27, 2025
28c992b
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
d3470bb
[test] chang the logic of workspace_host in cholesky_kernel_register
zhang-chenyi Aug 27, 2025
db17ebf
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
83bc87f
[Metax] fix compile fail
StareAtYou Aug 27, 2025
f1e8d0c
Revert "[Metax] fix compile fail"
StareAtYou Aug 27, 2025
a13daa8
[Metax] fix compile fail by 'conv_transpose_grad_kernel_impl.h'
StareAtYou Aug 27, 2025
95a179b
[Metax] fix bug & add some kernel register
metax666 Aug 28, 2025
4576ef4
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
ca51a1e
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
7789e9b
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
afd0863
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
6da0f0d
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
e1e07ba
[Metax] change_patch
duqimeng Aug 28, 2025
046637c
[Metax] change_patch
metax666 Aug 28, 2025
c27b492
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 28, 2025
05ecd9d
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
b1bf7e8
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
f90d585
Merge branch 'metax666:develop' into develop
StareAtYou Aug 28, 2025
874d9b6
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 28, 2025
0ca02b9
[feature] add unique_consecutive kernel
zhang-chenyi Aug 28, 2025
40d8f21
[metax-feature] add kernel for test_math_op_patch_var_base
metax666 Aug 28, 2025
3e9b526
[metax] add some kernel
duqimeng Aug 28, 2025
8911576
[metax] add some kernel
duqimeng Aug 28, 2025
8471597
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
0758887
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
61be33d
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
2fe962e
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
531fedb
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
c0dcfff
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
bd65451
[feature] add add unique_consecutive kernel.cu
zhang-chenyi Aug 29, 2025
0def63d
[fix] fix some test case due to missing op register
zhang-chenyi Aug 29, 2025
e503c9e
[fix] fix some fail text
zhang-chenyi Aug 29, 2025
9844878
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
70b86e7
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
1e90757
add and fix some kernels
1184319564 Aug 30, 2025
f93307d
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
c4b0eb9
[Metax] fix conflict
StareAtYou Sep 1, 2025
06dda18
[Metax] fix conflict
StareAtYou Sep 1, 2025
dae6ce8
[Metax] adapt to paddle-cpu-20250901 & resolve the issue of 'test_ele…
StareAtYou Sep 1, 2025
b4a5c62
[Metax] update repeat_interleave kernel & ignore max op test
StareAtYou Sep 2, 2025
7cf4405
Merge branch 'metax666:develop' into develop
StareAtYou Sep 2, 2025
0015f2e
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
metax666 Sep 2, 2025
fc2c0f5
Merge branch 'metax666:develop' into develop
duqimeng Sep 2, 2025
829c3b6
Merge dev
duqimeng Sep 2, 2025
3104a9c
【metax】add and fix some kernels
metax666 Sep 2, 2025
175cca6
[metax]fix lu eigvalshsqueeze rnn kernel
metax666 Sep 2, 2025
c7db810
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
f5813ed
[metax] chang patch fix copy
duqimeng Sep 2, 2025
6f0b705
[metax] chang patch fix copy
duqimeng Sep 2, 2025
8f47f0e
[metax] chang patch fix copy
metax666 Sep 2, 2025
b420f97
[Metax] update metax_gpu unit test
StareAtYou Sep 2, 2025
c08533e
[Metax] update metax_gpu unit test
metax666 Sep 2, 2025
414715f
[Metax] fix test CMakeList.txt
StareAtYou Sep 2, 2025
aa6b5bf
[Metax] fix test CMakeList.txt
metax666 Sep 2, 2025
0bfc6e7
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
cb93f6a
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
2e99f62
[metax]change_patch
duqimeng Sep 9, 2025
026551a
[metax]change_patch
duqimeng Sep 9, 2025
b09babb
Merge branch 'metax666:develop' into develop
duqimeng Sep 9, 2025
31594f8
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
4fb467c
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
5dc60a3
Merge branch 'metax666:develop' into develop
duqimeng Sep 11, 2025
e4fd192
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
471b184
[Metax] fix cufft and fix some blas kernel apply
duqimeng Sep 15, 2025
a0d237c
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
4c86266
[metax] fix bug
duqimeng Sep 15, 2025
a8b4696
[Metax] add github action
duqimeng Sep 16, 2025
8dff471
[metax]chaneg build
duqimeng Sep 16, 2025
ee4eefd
[metax]chaneg build
duqimeng Sep 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions .github/workflows/metax_work.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: padlle metax gpu test

on:
workflow_dispatch:
pull_request:
types: [opened, synchronize]
branches: [develop, release/**]
paths:
- "**"
- "!backends/**"
- "backends/metax_gpu/**"

permissions: read-all

defaults:
run:
shell: bash

jobs:
metax-gpu-test:
runs-on: paddle-metax-runner-set
steps:
- name: Checkout repository
run: |
git config --global user.name "GitHub Actions"
git config --global user.email "actions@github.com"

if [ "${{ github.event_name }}" == "pull_request" ]; then
BRANCH_NAME=${{ github.head_ref }}
else
BRANCH_NAME=${{ github.ref_name }}
fi

git clone \
--reference-if-able /home/runner/PaddleCustomDevice \
--depth=1 \
--shallow-submodules \
--jobs=8 \
--branch $BRANCH_NAME \
--recurse-submodules \
https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git .


- name: compile
run: |
cd backends/metax_gpu
bash build.sh

- name: run test
run: |
cd backends/metax_gpu/tests
bash run_test.sh
2 changes: 2 additions & 0 deletions backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ include(cblas)
include(flashattn)
include(cutlass)
include(dgc)
include(warpctc)
include(warprnnt)

set(PLUGIN_VERSION ${PADDLE_VERSION})

Expand Down
4 changes: 2 additions & 2 deletions backends/metax_gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ set -e
pip uninstall paddlepaddle -y


export http_proxy=http://10.2.192.21:1080 https_proxy=http://10.2.192.21:1080
# export http_proxy=http://10.2.192.21:1080 https_proxy=http://10.2.192.21:1080
pip install safetensors==0.6.2 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple some-package
# install paddle
python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
Expand Down Expand Up @@ -50,7 +50,7 @@ fi
echo "make_maca"
cd build
cmake_maca .. -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON
make_maca -j8
make_maca -j60

echo "install whl"
pip install dist/paddle_metax_gpu*.whl --force-reinstall
Expand Down
1 change: 1 addition & 0 deletions backends/metax_gpu/change_patch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core
cd ../../Paddle/
git apply --verbose ../backends/metax_gpu/patch/paddle.patch
cd -
cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/
149 changes: 149 additions & 0 deletions backends/metax_gpu/cmake/warpctc.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

include(ExternalProject)

if(WITH_ROCM)
add_definitions(-DWARPCTC_WITH_HIP)
endif()

set(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
set(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
# in case of low internet speed set(WARPCTC_REPOSITORY
# https://gitee.com/tianjianhe/warp-ctc.git)
set(WARPCTC_TAG bdc2b4550453e0ef2d3b5190f9c6103a84eff184)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/warpctc)
set(WARPCTC_PATCH_COMMAND "")
set(WARPCTC_CCBIN_OPTION "")
if(WIN32)
set(WARPCTC_PATCH_CUDA_COMMAND
git checkout -- . && git checkout ${WARPCTC_TAG} && git apply
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.cuda.patch)
else()
set(WARPCTC_PATCH_CUDA_COMMAND
git checkout -- . && git checkout ${WARPCTC_TAG} && patch -Nd
${SOURCE_DIR} <
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.cuda.patch)
endif()

if(NOT WIN32 AND WITH_GPU)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION}
VERSION_GREATER 12.0)
file(TO_NATIVE_PATH
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.patch native_src)
set(WARPCTC_PATCH_COMMAND git checkout -- . && git checkout ${WARPCTC_TAG}
&& patch -Nd ${SOURCE_DIR} < ${native_src} &&)
set(WARPCTC_CCBIN_OPTION -DCCBIN_COMPILER=${CCBIN_COMPILER})
endif()
endif()

if(WITH_ROCM)
set(WARPCTC_PATHCH_ROCM_COMMAND
patch -p1 <
${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.rocm.patch && patch
-p1 < ${PADDLE_SOURCE_DIR}/patches/warpctc/devicetypes.cuh.patch && patch
-p1 < ${PADDLE_SOURCE_DIR}/patches/warpctc/hip.cmake.patch)
endif()

set(WARPCTC_INCLUDE_DIR
"${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
# Used in unit test test_WarpCTCLayer
set(WARPCTC_LIB_DIR
"${WARPCTC_INSTALL_DIR}/lib"
CACHE PATH "Warp-ctc Library Directory" FORCE)

if(WIN32)
set(WARPCTC_LIBRARIES
"${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
else()
set(WARPCTC_LIBRARIES
"${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(WIN32)
set(WARPCTC_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPCTC_C_FLAGS_DEBUG $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(WARPCTC_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPCTC_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPCTC_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPCTC_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(WARPCTC_C_FLAGS ${CMAKE_C_FLAGS})
set(WARPCTC_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(WARPCTC_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(WARPCTC_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

ExternalProject_Add(
extern_warpctc
${EXTERNAL_PROJECT_LOG_ARGS}
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${WARPCTC_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND
COMMAND ${WARPCTC_PATCH_COMMAND}
COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
COMMAND ${WARPCTC_PATHCH_ROCM_COMMAND}
# BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${WARPCTC_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${WARPCTC_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${WARPCTC_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${WARPCTC_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${WARPCTC_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${WARPCTC_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA}
-DWITH_TORCH=OFF
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
-DBUILD_SHARED=ON
-DBUILD_TESTS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
${EXTERNAL_OPTIONAL_ARGS}
${WARPCTC_CCBIN_OPTION}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
BUILD_BYPRODUCTS ${WARPCTC_LIBRARIES})

message(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY)
include_directories(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its
# headers.

add_library(warpctc INTERFACE)
add_dependencies(warpctc extern_warpctc)
142 changes: 142 additions & 0 deletions backends/metax_gpu/cmake/warprnnt.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

include(ExternalProject)

if(WITH_ROCM)
add_definitions(-DWARPRNNT_WITH_HIP)
endif()

set(WARPRNNT_PREFIX_DIR ${THIRD_PARTY_PATH}/warprnnt)
set(WARPRNNT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warprnnt)
set(WARPRNNT_TAG 7ea6bfe748779c245a0fcaa5dd9383826273eff2)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/warprnnt)
set(WARPRNNT_PATCH_COMMAND "")
set(WARPRNNT_CCBIN_OPTION "")
if(WIN32)
set(WARPCTC_PATCH_CUDA_COMMAND
${CMAKE_COMMAND} -E copy_if_different
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.cuda.patch
"<SOURCE_DIR>/")
else()
set(WARPCTC_PATCH_CUDA_COMMAND
git checkout -- . && git checkout ${WARPRNNT_TAG} && patch -Nd
${SOURCE_DIR} <
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.cuda.patch)
endif()
if(WITH_ROCM)
set(WARPRNNT_PATCH_ROCM_COMMAND
patch -p1 <
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.rocm.patch)
endif()
if(NOT WIN32 AND WITH_GPU)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION}
VERSION_GREATER 12.0)
file(TO_NATIVE_PATH
${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.patch native_src)
set(WARPRNNT_PATCH_COMMAND
git checkout -- . && git checkout ${WARPRNNT_TAG} && patch -Nd
${SOURCE_DIR} < ${native_src})
set(WARPRNNT_CCBIN_OPTION -DCCBIN_COMPILER=${CCBIN_COMPILER})
endif()
endif()

set(WARPRNNT_INCLUDE_DIR
"${WARPRNNT_INSTALL_DIR}/include"
CACHE PATH "Warp-rnnt Directory" FORCE)
# Used in unit test test_WarpCTCLayer
set(WARPRNNT_LIB_DIR
"${WARPRNNT_INSTALL_DIR}/lib"
CACHE PATH "Warp-rnnt Library Directory" FORCE)

if(WIN32)
set(WARPRNNT_LIBRARIES
"${WARPRNNT_INSTALL_DIR}/bin/warprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-rnnt Library" FORCE)
else()
set(WARPRNNT_LIBRARIES
"${WARPRNNT_INSTALL_DIR}/lib/libwarprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-rnnt Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(WIN32)
set(WARPRNNT_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPRNNT_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(WARPRNNT_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(WARPRNNT_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(WARPRNNT_C_FLAGS ${CMAKE_C_FLAGS})
set(WARPRNNT_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(WARPRNNT_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(WARPRNNT_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(WARPRNNT_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(WARPRNNT_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()
ExternalProject_Add(
extern_warprnnt
${EXTERNAL_PROJECT_LOG_ARGS}
SOURCE_DIR ${SOURCE_DIR}
PREFIX ${WARPRNNT_PREFIX_DIR}
UPDATE_COMMAND ""
PATCH_COMMAND
COMMAND ${WARPCTC_PATCH_CUDA_COMMAND}
COMMAND ${WARPRNNT_PATCH_ROCM_COMMAND}
# BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${WARPRNNT_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${WARPRNNT_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${WARPRNNT_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${WARPRNNT_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${WARPRNNT_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${WARPRNNT_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPRNNT_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA}
-DBUILD_SHARED=ON
-DBUILD_TESTS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
${WARPCTC_CCBIN_OPTION}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${WARPRNNT_INSTALL_DIR}
BUILD_BYPRODUCTS ${WARPRNNT_LIBRARIES})

message(STATUS "warp-rnnt library: ${WARPRNNT_LIBRARIES}")
get_filename_component(WARPRNNT_LIBRARY_PATH ${WARPRNNT_LIBRARIES} DIRECTORY)
include_directories(${WARPRNNT_INCLUDE_DIR}) # For warprnnt code to include its
# headers.

add_library(warprnnt INTERFACE)
# set_property(TARGET warprnnt PROPERTY IMPORTED_LOCATION ${WARPRNNT_LIBRARIES})
add_dependencies(warprnnt extern_warprnnt)
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/warpctc_grad_kernel.h"

PD_REGISTER_PLUGIN_KERNEL(warpctc_grad,
PD_CUSTOM_KERNEL_REGISTER(warpctc_grad,
metax_gpu,
ALL_LAYOUT,
phi::WarpctcGradKernel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/warpctc_kernel.h"

PD_REGISTER_PLUGIN_KERNEL(
PD_CUSTOM_KERNEL_REGISTER(
warpctc, metax_gpu, ALL_LAYOUT, phi::WarpctcKernel, float, double) {}
Loading
Loading