From e7c01a12535eb277383c1cb4a41b05c3d04bf056 Mon Sep 17 00:00:00 2001 From: 3manifold <22544721+3manifold@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:28:47 +0100 Subject: [PATCH] CANN Backend support --- CMakeLists.txt | 107 ++++ README.md | 6 +- cli/translator.cc | 8 +- docker/build_all.sh | 0 docker/cann/Dockerfile_cann | 78 +++ docker/cann/run_container_cann.sh | 15 + docs/hardware_support.md | 7 + examples/cann/CMakeLists.txt | 10 + examples/cann/README.md | 45 ++ examples/cann/build_run.sh | 19 + examples/cann/main.cc | 48 ++ include/ctranslate2/devices.h | 6 +- include/ctranslate2/ops/gemm.h | 7 + include/ctranslate2/ops/matmul.h | 4 + include/ctranslate2/ops/mul.h | 21 +- include/ctranslate2/ops/transpose.h | 24 +- include/ctranslate2/primitives.h | 21 +- include/ctranslate2/replica_pool.h | 2 +- include/ctranslate2/storage_view.h | 6 +- include/ctranslate2/utils.h | 1 + python/cpp/encoder.cc | 2 +- python/cpp/generator.cc | 2 +- python/cpp/module.cc | 7 +- python/cpp/storage_view.cc | 2 +- python/ctranslate2/__init__.py | 1 + src/cann/allocator.cc | 44 ++ src/cann/cann_inc.h | 8 + src/cann/primitives.cc | 960 ++++++++++++++++++++++++++++ src/cann/utils.cc | 156 +++++ src/cann/utils.h | 127 ++++ src/cpu/primitives.cc | 17 +- src/cuda/primitives.cu | 17 +- src/decoding_utils.cc | 3 +- src/device_dispatch.h | 14 +- src/devices.cc | 61 +- src/dispatch.h | 38 +- src/layers/decoder.cc | 4 +- src/models/model.cc | 4 + src/ops/alibi_add_npu.cc | 28 + src/ops/bias_add_npu.cc | 33 + src/ops/concat_split_slide_npu.cc | 181 ++++++ src/ops/conv1d_npu.cc | 26 + src/ops/dequantize_npu.cc | 63 ++ src/ops/gather_npu.cc | 65 ++ src/ops/gemm.cc | 38 ++ src/ops/gumbel_max_npu.cc | 23 + src/ops/layer_norm_npu.cc | 80 +++ src/ops/matmul.cc | 211 +++--- src/ops/mean_npu.cc | 30 + src/ops/mul.cc | 73 ++- src/ops/multinomial_npu.cc | 22 + src/ops/quantize_npu.cc | 27 + src/ops/rms_norm_npu.cc | 24 + src/ops/rotary_npu.cc | 26 + src/ops/softmax_npu.cc | 98 +++ 
src/ops/tile_npu.cc | 25 + src/ops/topk_npu.cc | 79 +++ src/ops/topp_mask_npu.cc | 28 + src/ops/transpose.cc | 39 +- src/storage_view.cc | 45 +- src/thread_pool.cc | 3 +- src/types.cc | 29 + src/utils.cc | 12 + tests/CMakeLists.txt | 5 + tests/benchmark_ops.cc | 2 +- tests/benchmark_utils.h | 4 + tests/layers_test.cc | 34 + tests/ops_test.cc | 826 +++++++++++++++++++++++- tests/primitives_test.cc | 219 ++++++- tests/storage_view_test.cc | 62 ++ tests/test.cc | 18 +- tests/test_utils.h | 33 + tests/translator_test.cc | 22 + 73 files changed, 4282 insertions(+), 153 deletions(-) mode change 100755 => 100644 docker/build_all.sh create mode 100644 docker/cann/Dockerfile_cann create mode 100644 docker/cann/run_container_cann.sh create mode 100644 examples/cann/CMakeLists.txt create mode 100644 examples/cann/README.md create mode 100644 examples/cann/build_run.sh create mode 100644 examples/cann/main.cc create mode 100644 src/cann/allocator.cc create mode 100644 src/cann/cann_inc.h create mode 100644 src/cann/primitives.cc create mode 100644 src/cann/utils.cc create mode 100644 src/cann/utils.h create mode 100644 src/ops/alibi_add_npu.cc create mode 100644 src/ops/bias_add_npu.cc create mode 100644 src/ops/concat_split_slide_npu.cc create mode 100644 src/ops/conv1d_npu.cc create mode 100644 src/ops/dequantize_npu.cc create mode 100644 src/ops/gather_npu.cc create mode 100644 src/ops/gumbel_max_npu.cc create mode 100644 src/ops/layer_norm_npu.cc create mode 100644 src/ops/mean_npu.cc create mode 100644 src/ops/multinomial_npu.cc create mode 100644 src/ops/quantize_npu.cc create mode 100644 src/ops/rms_norm_npu.cc create mode 100644 src/ops/rotary_npu.cc create mode 100644 src/ops/softmax_npu.cc create mode 100644 src/ops/tile_npu.cc create mode 100644 src/ops/topk_npu.cc create mode 100644 src/ops/topp_mask_npu.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 1089106cc..4c72a56bf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(WITH_DNNL 
"Compile with DNNL backend" OFF) option(WITH_ACCELERATE "Compile with Accelerate backend" OFF) option(WITH_OPENBLAS "Compile with OpenBLAS backend" OFF) option(WITH_RUY "Compile with Ruy backend" OFF) +option(WITH_CANN "Compile with CANN backend" OFF) option(WITH_CUDA "Compile with CUDA backend" OFF) option(WITH_CUDNN "Compile with cuDNN backend" OFF) option(CUDA_DYNAMIC_LOADING "Dynamically load CUDA libraries at runtime" OFF) @@ -21,6 +22,12 @@ option(BUILD_CLI "Compile the clients" ON) option(BUILD_TESTS "Compile the tests" OFF) option(BUILD_SHARED_LIBS "Build shared libraries" ON) +if(WITH_CUDA OR WITH_CUDNN) + if(WITH_CANN) + message( FATAL_ERROR "CANN backend cannot be combined with CUDA or CUDNN!" ) + endif () +endif () + if(ENABLE_PROFILING) message(STATUS "Enable profiling support") add_definitions(-DCT2_ENABLE_PROFILING) @@ -525,6 +532,105 @@ if (WITH_CUDA) ) elseif(WITH_CUDNN) message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON") +elseif(WITH_CANN) + add_definitions(-DCT2_WITH_CANN) + + message(STATUS "ASCEND_TOOLKIT_HOME: $ENV{ASCEND_TOOLKIT_HOME}") + message(STATUS "LD_LIBRARY_PATH: $ENV{LD_LIBRARY_PATH}") + message(STATUS "PYTHONPATH: $ENV{PYTHONPATH}") + message(STATUS "ASCEND_AICPU_PATH: $ENV{ASCEND_AICPU_PATH}") + message(STATUS "ASCEND_OPP_PATH: $ENV{ASCEND_OPP_PATH}") + message(STATUS "TOOLCHAIN_HOME: $ENV{TOOLCHAIN_HOME}") + message(STATUS "ASCEND_HOME_PATH: $ENV{ASCEND_HOME_PATH}") + message(STATUS "PATH: $ENV{PATH}") + + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) + else() + set(ASCEND_DIR /usr/local/Ascend) + endif() + + message(STATUS "ASCEND_DIR: ${ASCEND_DIR}") + + set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) + set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) + set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) + set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) + set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) + set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) + 
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) + + set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR}) + set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR}) + set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) + set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include) + set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64) + set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64) + set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}) + + set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so) + set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so) + set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so) + INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR}) + + ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) + + ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib}) + + ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib}) + + set(extern_ascend ascend_ge ascend_graph atlas_acl CACHE INTERNAL "acllib runtime libs") + + set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) + + set(ascend_hccl_lib ${ASCEND_CL_DIR}/libhccl.so) + set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so) + set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so) + set(FWKACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include) + set(ACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/include) + + message(STATUS "FWKACLLIB_INC_DIR ${FWKACLLIB_INC_DIR}") + message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}") + INCLUDE_DIRECTORIES(${FWKACLLIB_INC_DIR}) + INCLUDE_DIRECTORIES(${ACLLIB_INC_DIR}) + + ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascendcl PROPERTY 
IMPORTED_LOCATION ${ascendcl_lib}) + + ADD_LIBRARY(ascend_hccl SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ascend_hccl PROPERTY IMPORTED_LOCATION ${ascend_hccl_lib}) + + ADD_LIBRARY(acl_op_compiler SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib}) + + set(extern_ascend_cl ascendcl acl_op_compiler CACHE INTERNAL "acltoolkit libs") + + list(APPEND SOURCES + src/cann/allocator.cc + src/cann/primitives.cc + src/cann/utils.cc + src/ops/topk_npu.cc + src/ops/dequantize_npu.cc + src/ops/gumbel_max_npu.cc + src/ops/topp_mask_npu.cc + src/ops/multinomial_npu.cc + src/ops/gather_npu.cc + src/ops/conv1d_npu.cc + src/ops/concat_split_slide_npu.cc + src/ops/alibi_add_npu.cc + src/ops/softmax_npu.cc + src/ops/tile_npu.cc + src/ops/rms_norm_npu.cc + src/ops/layer_norm_npu.cc + src/ops/rotary_npu.cc + src/ops/bias_add_npu.cc + src/ops/mean_npu.cc + src/ops/quantize_npu.cc) + add_library(${PROJECT_NAME} ${SOURCES}) + list(APPEND LIBRARIES ${extern_ascend} ${extern_ascend_cl}) else() add_library(${PROJECT_NAME} ${SOURCES}) endif() @@ -540,6 +646,7 @@ set_property(TARGET ${PROJECT_NAME} APPEND PROPERTY ) list(APPEND LIBRARIES ${CMAKE_DL_LIBS}) + target_link_libraries(${PROJECT_NAME} PRIVATE ${LIBRARIES}) target_include_directories(${PROJECT_NAME} BEFORE PUBLIC $ $ diff --git a/README.md b/README.md index af763e0bf..7b88437c2 100644 --- a/README.md +++ b/README.md @@ -25,12 +25,12 @@ The project is production-oriented and comes with [backward compatibility guaran ## Key features -* **Fast and efficient execution on CPU and GPU**
The execution [is significantly faster and requires less resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc. +* **Fast and efficient execution on CPU, GPU and NPU**
The execution [is significantly faster and requires less resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc. * **Quantization and reduced precision**
The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), and 8-bit integers (INT8). * **Multiple CPU architectures support**
The project supports x86-64 and AArch64/ARM64 processors and integrates multiple backends that are optimized for these platforms: [Intel MKL](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html), [oneDNN](https://github.com/oneapi-src/oneDNN), [OpenBLAS](https://www.openblas.net/), [Ruy](https://github.com/google/ruy), and [Apple Accelerate](https://developer.apple.com/documentation/accelerate). * **Automatic CPU detection and code dispatch**
One binary can include multiple backends (e.g. Intel MKL and oneDNN) and instruction set architectures (e.g. AVX, AVX2) that are automatically selected at runtime based on the CPU information. -* **Parallel and asynchronous execution**
Multiple batches can be processed in parallel and asynchronously using multiple GPUs or CPU cores. -* **Dynamic memory usage**
The memory usage changes dynamically depending on the request size while still meeting performance requirements thanks to caching allocators on both CPU and GPU. +* **Parallel and asynchronous execution**
Multiple batches can be processed in parallel and asynchronously using multiple GPUs, NPUs or CPU cores. +* **Dynamic memory usage**
The memory usage changes dynamically depending on the request size while still meeting performance requirements thanks to caching allocators on CPU, GPU and NPU. * **Lightweight on disk**
Quantization can make the models 4 times smaller on disk with minimal accuracy loss. * **Simple integration**
The project has few dependencies and exposes simple APIs in [Python](https://opennmt.net/CTranslate2/python/overview.html) and C++ to cover most integration needs. * **Configurable and interactive decoding**
[Advanced decoding features](https://opennmt.net/CTranslate2/decoding.html) allow autocompleting a partial sequence and returning alternatives at a specific location in the sequence. diff --git a/cli/translator.cc b/cli/translator.cc index b17458862..434132dbc 100644 --- a/cli/translator.cc +++ b/cli/translator.cc @@ -30,7 +30,7 @@ int main(int argc, char* argv[]) { cxxopts::value()->default_value("1")) ("intra_threads", "Number of computation threads (set to 0 to use the default value).", cxxopts::value()->default_value("0")) - ("device", "Device to use (can be cpu, cuda, auto).", + ("device", "Device to use (can be cpu, cuda, cann, auto).", cxxopts::value()->default_value("cpu")) ("device_index", "Comma-separated list of device IDs to use.", cxxopts::value>()->default_value("0")) @@ -44,6 +44,8 @@ int main(int argc, char* argv[]) { cxxopts::value()->default_value("default")) ("cuda_compute_type", "Computation type on CUDA devices (overrides compute_type)", cxxopts::value()) + ("cann_compute_type", "Computation type on CANN devices (overrides compute_type)", + cxxopts::value()) ("cpu_compute_type", "Computation type on CPU devices (overrides compute_type)", cxxopts::value()) ; @@ -139,6 +141,10 @@ int main(int argc, char* argv[]) { if (args.count("cuda_compute_type")) compute_type = ctranslate2::str_to_compute_type(args["cuda_compute_type"].as()); break; + case ctranslate2::Device::CANN: + if (args.count("cann_compute_type")) + compute_type = ctranslate2::str_to_compute_type(args["cann_compute_type"].as()); + break; }; ctranslate2::ReplicaPoolConfig pool_config; diff --git a/docker/build_all.sh b/docker/build_all.sh old mode 100755 new mode 100644 diff --git a/docker/cann/Dockerfile_cann b/docker/cann/Dockerfile_cann new file mode 100644 index 000000000..581be7882 --- /dev/null +++ b/docker/cann/Dockerfile_cann @@ -0,0 +1,78 @@ +# Extend/build an image for CANN support +# Ascend-cann-toolkit_.run is expected to exist in /ascend_install_files + +# preferably arm64 
+FROM ubuntu:20.04 + +RUN DEBIAN_FRONTEND="noninteractive" apt update && \ + apt install --no-install-recommends net-tools -y && \ + apt install --no-install-recommends libsqlite3-dev -y && \ + apt install --no-install-recommends zlib1g -y && \ + apt install --no-install-recommends openssl -y + +RUN DEBIAN_FRONTEND="noninteractive" apt update && \ + apt install --no-install-recommends ca-certificates -y && \ + apt install --no-install-recommends bc wget -y && \ + apt install --no-install-recommends curl gdb cmake gcc make g++ pkg-config unzip -y && \ + apt install --no-install-recommends libblas3 liblapack3 gfortran vim -y && \ + apt install --no-install-recommends liblapack-dev libblas-dev libhdf5-dev libffi-dev -y && \ + apt install --no-install-recommends libssl-dev zlib1g-dev xz-utils cython3 python3-h5py -y && \ + apt install --no-install-recommends libopenblas-dev libgmpxx4ldbl liblzma-dev -y && \ + apt install --no-install-recommends pciutils -y + + +RUN DEBIAN_FRONTEND="noninteractive" apt update && \ + apt-get install -y --no-install-recommends \ + python3-dev \ + python3-pip \ + wget + +RUN python3 -m pip --no-cache-dir install numpy && \ + python3 -m pip --no-cache-dir install decorator && \ + python3 -m pip --no-cache-dir install sympy && \ + python3 -m pip --no-cache-dir install cffi && \ + python3 -m pip --no-cache-dir install pyyaml && \ + python3 -m pip --no-cache-dir install pathlib2 && \ + python3 -m pip --no-cache-dir install protobuf && \ + python3 -m pip --no-cache-dir install scipy + +RUN python3 -m pip --no-cache-dir install psutil && \ + python3 -m pip --no-cache-dir install requests absl-py + +RUN python3 -m pip --no-cache-dir install attrs + +# cleanup actions +RUN rm -rf /root/.cache/pip +RUN DEBIAN_FRONTEND="noninteractive" apt clean && rm -rf /var/lib/apt/lists/* +RUN DEBIAN_FRONTEND="noninteractive" apt autoremove && apt autoclean + +# Install Ascend toolkit +COPY ascend_install_files ascend_install_files +RUN chmod +x 
ascend_install_files/Ascend-cann-toolkit_7.0.RC1.alpha001_linux-aarch64.run && \ + ascend_install_files/Ascend-cann-toolkit_7.0.RC1.alpha001_linux-aarch64.run --install && \ + rm -f ascend_install_files/Ascend-cann-toolkit_7.0.RC1.alpha001_linux-aarch64.run + +# Add usergroup & user +RUN groupadd HwHiAiUser && useradd -g HwHiAiUser -m -d /home/HwHiAiUser HwHiAiUser + +# This is copied from /usr/local/Ascend/ascend-toolkit/set_env.sh of the respective ascend-toolkit version +ENV LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/driver/lib64/common:/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:$LD_LIBRARY_PATH +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:$PYTHONPATH +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:$PATH +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit +ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME} + +# ENV LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/hdf5/serial:$LD_LIBRARY_PATH +# ENV HCCL_CONNECT_TIMEOUT=7200 +# ENV HCCL_WHITELIST_DISABLE=1 +# ENV HCCL_SECURITY_MODE=1 + +ENV ASCEND_GLOBAL_LOG_LEVEL=3 + +# Set env vars again in case of interactive ssh connection (ascend-toolkit assumed to be already installed) +RUN cp /usr/local/Ascend/ascend-toolkit/set_env.sh /etc/profile.d/ +RUN chmod 644 /etc/profile.d/set_env.sh diff --git a/docker/cann/run_container_cann.sh b/docker/cann/run_container_cann.sh new file mode 100644 index 000000000..9a8df3470 --- /dev/null +++ b/docker/cann/run_container_cann.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# build image that will host CANN environment +cd ../../ +docker build -t ctranslate2-aarch64 -f 
docker/cann/Dockerfile_cann --platform linux/arm64 . + +# run the respective container +docker run \ +-d --cap-add sys_ptrace \ +--pids-limit 409600 \ +--privileged --shm-size=128G \ +-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/dcmi:/usr/local/dcmi \ +--name ctranslate2-aarch64 diff --git a/docs/hardware_support.md b/docs/hardware_support.md index 88506b547..782acf754 100644 --- a/docs/hardware_support.md +++ b/docs/hardware_support.md @@ -20,3 +20,10 @@ See the [environment variables](environment_variables.md) `CT2_USE_MKL` and `CT2 * NVIDIA GPUs with a Compute Capability greater or equal to 3.5 The driver requirement depends on the CUDA version. See the [CUDA Compatibility guide](https://docs.nvidia.com/deploy/cuda-compatibility/index.html) for more information. + +## NPU + +* AArch64/ARM64 processors +* Ascend NPU AI Processor greater or equal to 910A + +`CANN` version greater or equal to `7.0.RC1.alpha001` (depends on NPU model). See [CANN documentation](https://support.huawei.com/enterprise/en/ascend-computing/cann-pid-251168373) for more information. 
diff --git a/examples/cann/CMakeLists.txt b/examples/cann/CMakeLists.txt new file mode 100644 index 000000000..2c5207da7 --- /dev/null +++ b/examples/cann/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.7) +project(cann) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) +find_package(Threads) +add_executable(cann_run main.cc) +target_link_libraries(cann_run PRIVATE + ${CMAKE_THREAD_LIBS_INIT} + ctranslate2 + ) diff --git a/examples/cann/README.md b/examples/cann/README.md new file mode 100644 index 000000000..164a5ed97 --- /dev/null +++ b/examples/cann/README.md @@ -0,0 +1,45 @@ +# CANN example query +This example demonstrates a translation query employing `CANN` using the English-German Transformer model trained with OpenNMT-py as in [CTranslate2 documentation](https://opennmt.net/CTranslate2/quickstart.html). + +## Environment setup +- Create environment:`docker/cann/Dockerfile_cann` +- Run the container: `docker/cann/run_container_cann.sh` + +## Download model +```bash +wget https://s3.amazonaws.com/opennmt-models/transformer-ende-wmt-pyOnmt.tar.gz +tar xf transformer-ende-wmt-pyOnmt.tar.gz +``` + +## Build executable +Run `examples/cann/build_run.sh` + +### Expected output + +``` +current path: "" +input data path: "" +[] [ctranslate2] [thread 49835] [info] CPU: ARM (NEON=true) +[] [ctranslate2] [thread 49835] [info] - Selected ISA: NEON +[] [ctranslate2] [thread 49835] [info] - Use Intel MKL: false +[] [ctranslate2] [thread 49835] [info] - SGEMM backend: OpenBLAS (packed: false) +[] [ctranslate2] [thread 49835] [info] - GEMM_S16 backend: none (packed: false) +[] [ctranslate2] [thread 49835] [info] - GEMM_S8 backend: Ruy (packed: false, u8s8 preferred: false) +[] [ctranslate2] [thread 49835] [info] NPU: +[] [ctranslate2] [thread 49835] [info] - Number of NPU cores: 8 +[] [ctranslate2] [thread 49835] [info] - aclrtRunMode: ACL_HOST +[] [ctranslate2] [thread 49835] [info] Loaded model on device cann:0 +[] [ctranslate2] [thread 49835] 
[info] - Binary version: 6 +[] [ctranslate2] [thread 49835] [info] - Model specification revision: 7 +[] [ctranslate2] [thread 49835] [info] - Selected compute type: float32 +input data: +▁H ello ▁world ! +Start: Warmup examples +output: +▁Hallo ▁Welt ! +input data: +▁H ello ▁world ! +Start: Query examples +output: +▁Hallo ▁Welt ! +``` diff --git a/examples/cann/build_run.sh b/examples/cann/build_run.sh new file mode 100644 index 000000000..cde9e0b41 --- /dev/null +++ b/examples/cann/build_run.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# execute from project root + +# first build ct2lib +rm -rf build-release/ +mkdir build-release && cd build-release || exit + +cmake -DWITH_CANN=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DWITH_MKL=OFF -DOPENMP_RUNTIME=COMP -DCMAKE_PREFIX_PATH="/opt/OpenBLAS" -DWITH_OPENBLAS=ON -DWITH_RUY=ON .. + +make -j"$(nproc)" + +rm CMakeCache.txt + +# then build cann_run +cmake -DCMAKE_BUILD_TYPE=Release ../examples/cann/ + +make -j"$(nproc)" +# ./cann_run diff --git a/examples/cann/main.cc b/examples/cann/main.cc new file mode 100644 index 000000000..db9f4e840 --- /dev/null +++ b/examples/cann/main.cc @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include + +void execute_translation(ctranslate2::Translator &translator, const std::vector>& batch, const std::string& msg) { + std::cout << "input data: " << std::endl; + for (const auto &input: batch) { + for (const auto &word: input) { + std::cout << word << ' '; + } + std::cout << "\n"; + } + + std::cout << "Start: " << msg << " examples\n"; + // const auto start{std::chrono::steady_clock::now()}; + const std::vector results = translator.translate_batch(batch); + // const auto end{std::chrono::steady_clock::now()}; + // std::cout << "End: " << msg << " examples. 
time: "<< std::chrono::duration_cast(end - start).count() << "ms\n"; + + std::cout << "output: " << std::endl; + for (const auto &token: results[0].output()) + std::cout << token << ' '; + std::cout << std::endl; +} + +int main(int, char* argv[]) { + std::filesystem::path cwd = std::filesystem::current_path(); + std::cout << "current path: " << cwd << std::endl; + const std::string input_data_path = argv[1]; + std::cout << "input data path: " << std::filesystem::absolute(input_data_path) << std::endl; + + if(!std::filesystem::exists(input_data_path)) { + std::cout << input_data_path << " does not exist" << std::endl; + } + + ctranslate2::set_log_level(ctranslate2::LogLevel::Info); + + const auto device = ctranslate2::str_to_device("auto"); + ctranslate2::Translator translator(input_data_path, device); + const std::vector > batch = {{"▁H", "ello", "▁world", "!"}}; + execute_translation(translator, batch, "Warmup"); + execute_translation(translator, batch, "Query"); + + return 0; +} + diff --git a/include/ctranslate2/devices.h b/include/ctranslate2/devices.h index 2691efc3a..a62b8c66d 100644 --- a/include/ctranslate2/devices.h +++ b/include/ctranslate2/devices.h @@ -7,7 +7,8 @@ namespace ctranslate2 { enum class Device { CPU, - CUDA + CUDA, + CANN }; Device str_to_device(const std::string& device); @@ -19,6 +20,9 @@ namespace ctranslate2 { int get_device_index(Device device); void set_device_index(Device device, int index); + void initialize_device(); + void finalize_device(); + void synchronize_device(Device device, int index); void synchronize_stream(Device device); diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index c309063d6..1d7ebfbf5 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -2,6 +2,9 @@ #include "activation.h" #include "op.h" +#ifdef CT2_WITH_CANN +#include +#endif namespace ctranslate2 { namespace ops { @@ -48,6 +51,10 @@ namespace ctranslate2 { bool _a_is_packed; bool _b_is_packed; 
const ActivationType* _activation_type; +#ifdef CT2_WITH_CANN + mutable std::shared_ptr _alpha_sv; + mutable std::shared_ptr _beta_sv; +#endif template void compute(const StorageView& a, diff --git a/include/ctranslate2/ops/matmul.h b/include/ctranslate2/ops/matmul.h index 6b84d48e1..6765be893 100644 --- a/include/ctranslate2/ops/matmul.h +++ b/include/ctranslate2/ops/matmul.h @@ -17,6 +17,10 @@ namespace ctranslate2 { template void compute(const StorageView& a, const StorageView& b, StorageView& c) const; + template + void handleCann(const StorageView &a, const StorageView &b, StorageView &c) const; + template + void handleNonCann(const StorageView &a, const StorageView &b, StorageView &c, dim_t m, dim_t n, const dim_t k, const dim_t a_batch_size) const; }; } diff --git a/include/ctranslate2/ops/mul.h b/include/ctranslate2/ops/mul.h index a09a9e00f..d9bbca38a 100644 --- a/include/ctranslate2/ops/mul.h +++ b/include/ctranslate2/ops/mul.h @@ -14,11 +14,28 @@ namespace ctranslate2 { void compute(const StorageView& a, const StorageView& b, StorageView& c) const { c.resize_as(a); if (b.is_scalar()) { - primitives::mul(b.data()[0], a.data(), c.data(), c.size()); + const auto scalar = b.data()[0]; + if constexpr (D == Device::CANN) { + handleCannScalar(scalar, a, c); + } + else { + primitives::mul(scalar, a.data(), c.data(), c.size()); + } } else { - primitives::mul(a.data(), b.data(), c.data(), c.size()); + if constexpr (D == Device::CANN) { + handleCann(a, b, c); + } + else { + primitives::mul(a.data(), b.data(), c.data(), c.size()); + } } } + + template + void handleCann(const StorageView& a, const StorageView& b, StorageView& c) const; + + template + void handleCannScalar(T scalar, const StorageView& a, StorageView& c) const; }; } diff --git a/include/ctranslate2/ops/transpose.h b/include/ctranslate2/ops/transpose.h index 519eed392..747c1f422 100644 --- a/include/ctranslate2/ops/transpose.h +++ b/include/ctranslate2/ops/transpose.h @@ -18,18 +18,36 @@ namespace 
ctranslate2 { void compute(const StorageView& x, const std::vector& perm, StorageView& y) const { if (x.rank() == 2) { y.resize({x.dim(1), x.dim(0)}); - primitives::transpose_2d(x.data(), x.shape().data(), y.data()); + if constexpr (D == Device::CANN) { + handleCann(x, perm, y); + } + else { + primitives::transpose_2d(x.data(), x.shape().data(), y.data()); + } } else if (x.rank() == 3) { y.resize({x.dim(perm[0]), x.dim(perm[1]), x.dim(perm[2])}); - primitives::transpose_3d(x.data(), x.shape().data(), perm.data(), y.data()); + if constexpr (D == Device::CANN) { + handleCann(x, perm, y); + } + else { + primitives::transpose_3d(x.data(), x.shape().data(), perm.data(), y.data()); + } } else if (x.rank() == 4) { y.resize({x.dim(perm[0]), x.dim(perm[1]), x.dim(perm[2]), x.dim(perm[3])}); - primitives::transpose_4d(x.data(), x.shape().data(), perm.data(), y.data()); + if constexpr (D == Device::CANN) { + handleCann(x, perm, y); + } + else { + primitives::transpose_4d(x.data(), x.shape().data(), perm.data(), y.data()); + } } else { throw std::invalid_argument("Transpose: rank " + std::to_string(x.rank()) + " is not supported, supported ranks are: 2, 3, 4"); } } + + template + void handleCann(const StorageView& x, const std::vector& perm, StorageView& y) const; }; } diff --git a/include/ctranslate2/primitives.h b/include/ctranslate2/primitives.h index bed80c8ff..8ee5a8ea2 100644 --- a/include/ctranslate2/primitives.h +++ b/include/ctranslate2/primitives.h @@ -16,9 +16,11 @@ namespace ctranslate2 { template static void fill(T* x, T a, dim_t size); template + static void zero(T* x, dim_t size, bool synchronous = true); + template static void strided_fill(T* x, T a, dim_t inc_x, dim_t size); template - static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices); + static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices, dim_t size = -1); template static void copy(const T* x, T* y, dim_t size); @@ -58,11 +60,11 @@ namespace ctranslate2 { 
} template - static void add_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size); + static void add_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size, bool synchronous = true); template - static void add_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) { - add_batch_broadcast(x, y, y, x_size, y_size); + static void add_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size, bool synchronous = true) { + add_batch_broadcast(x, y, y, x_size, y_size, synchronous); } template @@ -220,6 +222,17 @@ namespace ctranslate2 { float beta, Out* c, dim_t ldc, dim_t stridec, dim_t batch_size); + + template + static void gemm_alpha_beta_in_device(bool a_is_packed, bool b_is_packed, + bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha, + const In* a, dim_t lda, + const In* b, dim_t ldb, + const float* beta, + Out* c, dim_t ldc, + const Out* a_shift_compensation = nullptr); }; template diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index efc9824d1..36af504af 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -345,7 +345,7 @@ namespace ctranslate2 { } void idle() override { - // When no new jobs are immediately available, we synchronize the CUDA stream + // When no new jobs are immediately available, we synchronize the CUDA/CANN stream // so that the CudaAsyncAllocator can release some memory. 
synchronize_stream(_device); } diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h index 8834ef651..2fd715769 100644 --- a/include/ctranslate2/storage_view.h +++ b/include/ctranslate2/storage_view.h @@ -139,6 +139,8 @@ namespace ctranslate2 { return _size; } + dim_t size_in_bytes() const; + dim_t item_size() const; bool is_scalar() const { @@ -186,6 +188,8 @@ namespace ctranslate2 { T* index(std::initializer_list indices); template const T* index(std::initializer_list indices) const; + template + const T* index(const std::vector& indices) const; template T& at(dim_t index) { @@ -229,7 +233,7 @@ namespace ctranslate2 { template StorageView& fill(T value); - StorageView& zero(); + StorageView& zero(bool synchronous = true); StorageView& copy_from(const StorageView& other, bool synchronous = false); diff --git a/include/ctranslate2/utils.h b/include/ctranslate2/utils.h index c8e7ef78b..a49197421 100644 --- a/include/ctranslate2/utils.h +++ b/include/ctranslate2/utils.h @@ -11,6 +11,7 @@ namespace ctranslate2 { void log_system_config(); int get_gpu_count(); + int get_npu_count(); void set_num_threads(size_t num_threads); bool ends_with(const std::string& str, const std::string& suffix); diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index ea8b1a430..e3dae906b 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -86,7 +86,7 @@ namespace ctranslate2 { Arguments: model_path: Path to the CTranslate2 model directory. - device: Device to use (possible values are: cpu, cuda, auto). + device: Device to use (possible values are: cpu, cuda, cann, auto). device_index: Device IDs where to place this encoder on. 
compute_type: Model computation type or a dictionary mapping a device name to the computation type (possible values are: default, auto, int8, int8_float32, diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 981c6da68..946426986 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -143,7 +143,7 @@ namespace ctranslate2 { Arguments: model_path: Path to the CTranslate2 model directory. - device: Device to use (possible values are: cpu, cuda, auto). + device: Device to use (possible values are: cpu, cuda, cann, auto). device_index: Device IDs where to place this generator on. compute_type: Model computation type or a dictionary mapping a device name to the computation type (possible values are: default, auto, int8, int8_float32, diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 4a9e47561..2bc85e537 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -56,6 +56,9 @@ PYBIND11_MODULE(_ext, m) m.def("get_cuda_device_count", &ctranslate2::get_gpu_count, "Returns the number of visible GPU devices."); + m.def("get_cann_device_count", &ctranslate2::get_npu_count, + "Returns the number of visible NPU devices."); + m.def("get_supported_compute_types", &get_supported_compute_types, py::arg("device"), py::arg("device_index")=0, @@ -63,7 +66,7 @@ PYBIND11_MODULE(_ext, m) Returns the set of supported compute types on a device. Arguments: - device: Device name (cpu or cuda). + device: Device name (cpu or cuda or cann). device_index: Device index. 
Example: @@ -71,6 +74,8 @@ PYBIND11_MODULE(_ext, m) {'int16', 'float32', 'int8', 'int8_float32'} >>> ctranslate2.get_supported_compute_types("cuda") {'float32', 'int8_float16', 'float16', 'int8', 'int8_float32'} + >>> ctranslate2.get_supported_compute_types("cann") + {'int8', 'float32', 'float16', 'int8_float16', 'int8_float32'} )pbdoc"); m.def("set_random_seed", &ctranslate2::set_random_seed, py::arg("seed"), diff --git a/python/cpp/storage_view.cc b/python/cpp/storage_view.cc index ce5e95b3b..e34b585fe 100644 --- a/python/cpp/storage_view.cc +++ b/python/cpp/storage_view.cc @@ -165,7 +165,7 @@ namespace ctranslate2 { [](const StorageView& view) { return device_to_str(view.device()); }, - "Device where the storage is allocated (\"cpu\" or \"cuda\").") + "Device where the storage is allocated (\"cpu\" or \"cuda\" or \"cann\").") .def_property_readonly("__array_interface__", [](const StorageView& view) { if (view.device() == Device::CUDA) diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index 9c0efac2a..3fdeef8e8 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -35,6 +35,7 @@ Translator, contains_model, get_cuda_device_count, + get_cann_device_count, get_supported_compute_types, set_random_seed, ) diff --git a/src/cann/allocator.cc b/src/cann/allocator.cc new file mode 100644 index 000000000..f644708c0 --- /dev/null +++ b/src/cann/allocator.cc @@ -0,0 +1,44 @@ +#include "ctranslate2/allocator.h" +#include "./utils.h" + +namespace ctranslate2 { + namespace cann { + + class CannAllocator : public Allocator { + public: + void* allocate(size_t size, int device_index) override { + int prev_device_index = -1; + if (device_index >= 0) { + ACL_CALL(aclrtGetDevice(&prev_device_index)); + ACL_CALL(aclrtSetDevice(device_index)); + } + + void* ptr = nullptr; + ACL_CALL(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + if (prev_device_index >= 0) { + ACL_CALL(aclrtSetDevice(prev_device_index)); + } + return ptr; 
+ } + + void free(void* ptr, int device_index) override { + int prev_device_index = -1; + if (device_index >= 0) { + ACL_CALL(aclrtGetDevice(&prev_device_index)); + ACL_CALL(aclrtSetDevice(device_index)); + } + ACL_CALL(aclrtFree(ptr)); + + if (prev_device_index >= 0) { + ACL_CALL(aclrtSetDevice(prev_device_index)); + } + } + }; + } + + template<> + Allocator& get_allocator() { + static cann::CannAllocator allocator; + return allocator; + } +} diff --git a/src/cann/cann_inc.h b/src/cann/cann_inc.h new file mode 100644 index 000000000..799794dc1 --- /dev/null +++ b/src/cann/cann_inc.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include +#include +#include +#include + diff --git a/src/cann/primitives.cc b/src/cann/primitives.cc new file mode 100644 index 000000000..a7282e07c --- /dev/null +++ b/src/cann/primitives.cc @@ -0,0 +1,960 @@ +#include +#include "ctranslate2/primitives.h" +#include "utils.h" +#include "type_dispatch.h" +#include "ctranslate2/storage_view.h" + +namespace ctranslate2 { + + template<> + template + T primitives::at(const T* x, dim_t index) { + T val = T(); + cross_device_primitives::copy(x + index, &val, 1); + return val; + } + + template<> + template + void primitives::fill(T* x, T a, dim_t size) { + ctranslate2::cann::CannPreparation prepare; + + const aclDataType aclType = cann::getACLType(); + const ctranslate2::Shape x_shape = {size}; + + cann_prepare_inputdesc(prepare, aclType, x_shape.size(), x_shape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, x_shape.size(), x_shape.data(), ACL_FORMAT_ND); + + ACL_CALL(aclopSetAttrFloat(prepare._opAttr, "value", static_cast(a))); + + const dim_t size_of_x = size*sizeof(T); + cann_prepare_inputbuffer(prepare, x, size_of_x); + cann_prepare_outputbuffer(prepare, x, size_of_x); + + ACL_CALL(aclopCompileAndExecute("Fills", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + 
prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template<> + template + void primitives::zero(T* x, dim_t size, bool synchronous) { + ctranslate2::cann::CannPreparation prepare; + + const aclDataType aclType = cann::getACLType(); + const ctranslate2::Shape x_shape = {size}; + + cann_prepare_inputdesc(prepare, aclType, x_shape.size(), x_shape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, x_shape.size(), x_shape.data(), ACL_FORMAT_ND); + + const dim_t size_of_x = size*sizeof(T); + cann_prepare_inputbuffer(prepare, x, size_of_x); + cann_prepare_outputbuffer(prepare, x, size_of_x); + + ACL_CALL(aclopCompileAndExecute("ZerosLike", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + if (synchronous) { + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + } + + template<> + template + void primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices, dim_t size) { + if (size <= 0) { + THROW_RUNTIME_ERROR("Input 'size' of 'indexed_fill' primitive should be positive"); + } + + // Using a unique pointer to a bool array instead of a 'std::vector' because in the case of the latter, it's + // not possible to get a raw pointer to the underlying data. + auto mask = std::make_unique(size); + std::vector indices_cpu(num_indices); + + // Moving the 'indices' from the NPU to the CPU in order to iterate over them. 
+ cross_device_primitives::copy(indices, indices_cpu.data(), num_indices); + for (dim_t i = 0; i < num_indices; ++i) { + mask[indices_cpu[i]] = true; + } + + ctranslate2::cann::CannPreparation prepare; + const aclDataType aclType = cann::getACLType(); + const Shape input_and_mask_shape = {size}, value_shape = {1}; + + cann_prepare_inputdesc(prepare, aclType, input_and_mask_shape.size(), input_and_mask_shape.data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, ACL_BOOL, input_and_mask_shape.size(), input_and_mask_shape.data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, aclType, value_shape.size(), value_shape.data(), ACL_FORMAT_ND); + + cann_prepare_outputdesc(prepare, aclType, input_and_mask_shape.size(), input_and_mask_shape.data(), ACL_FORMAT_ND); + + const dim_t size_of_input_in_bytes = sizeof(T)*size; + cann_prepare_inputbuffer(prepare, const_cast(x), size_of_input_in_bytes); + + // Temporary way to allocate memory for the 'mask' input + void *mask_dev_ptr = nullptr; + const dim_t size_of_mask_in_bytes = sizeof(bool)*size; + ACL_CALL(aclrtMalloc(&mask_dev_ptr, size_of_mask_in_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CALL(aclrtMemcpyAsync(mask_dev_ptr, size_of_mask_in_bytes, mask.get(), size_of_mask_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE, cann::get_aclrt_stream())); + cann_prepare_inputbuffer(prepare, mask_dev_ptr, size_of_mask_in_bytes); + + // Temporary way to allocate memory for the 'value' input + void *value_dev_ptr = nullptr; + ACL_CALL(aclrtMalloc(&value_dev_ptr, sizeof(T), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CALL(aclrtMemcpyAsync(value_dev_ptr, sizeof(T), const_cast(&a), sizeof(T), ACL_MEMCPY_HOST_TO_DEVICE, cann::get_aclrt_stream())); + cann_prepare_inputbuffer(prepare, value_dev_ptr, sizeof(T)); + cann_prepare_outputbuffer(prepare, x, size_of_input_in_bytes); + + ACL_CALL(aclopCompileAndExecute("MaskedFill", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + 
prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + + // Temporary way to free the allocated memory for inputs + ACL_CALL(aclrtFree(mask_dev_ptr)); + ACL_CALL(aclrtFree(value_dev_ptr)); + } + + template<> + template + void primitives::copy(const T* x, T* y, dim_t size) { + const auto size_in_bytes = size * sizeof (T); + ACL_CALL(aclrtMemcpy(y, size_in_bytes, x, size_in_bytes, + ACL_MEMCPY_DEVICE_TO_DEVICE)); + } + + template<> + template + void primitives::convert(const U* x, V* y, dim_t size) { + cann::CannPreparation prepare; + // Assume the shape as if the tensor was one-dimensional + const Shape shape_1d = {size}; + const auto in_type = cann::getACLType(); + const auto out_type = cann::getACLType(); + + ACL_CALL(aclopSetAttrDataType(prepare._opAttr, "dst_type", out_type)); + aclFormat format = ACL_FORMAT_ND; + + cann_prepare_inputdesc(prepare, in_type, shape_1d.size(), shape_1d.data(), format); + cann_prepare_outputdesc(prepare, out_type, shape_1d.size(), shape_1d.data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(x), size*sizeof(U)); + cann_prepare_outputbuffer(prepare, y, size*sizeof(V)); + + ACL_CALL(aclopCompileAndExecute("Cast", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template void primitives::convert(const float*, float16_t*, dim_t); + template void primitives::convert(const float16_t*, float*, dim_t); + template<> + template<> + void primitives::convert(const float*, bfloat16_t*, dim_t) { + THROW_RUNTIME_ERROR("Unsupported ACL type: float to bfloat16_t"); + } + template<> + 
template<> + void primitives::convert(const bfloat16_t*, float*, dim_t) { + THROW_RUNTIME_ERROR("Unsupported ACL type: bfloat16_t to float"); + } + template<> + template<> + void primitives::convert(const float16_t* x, bfloat16_t* y, dim_t size) { + THROW_RUNTIME_ERROR("Unsupported ACL type: float16_t to bfloat16_t"); + } + template<> + template<> + void primitives::convert(const bfloat16_t* x, float16_t* y, dim_t size) { + THROW_RUNTIME_ERROR("Unsupported ACL type: bfloat16_t to float16_t"); + } + + template<> + template + T primitives::sum(const T* array, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + dim_t primitives::max_element(const T* array, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + T primitives::max(const T* array, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::add(T a, const T* x, T* y, dim_t size) { + const aclDataType aclType = cann::getACLType(); + + static std::unordered_set supportedTypes{ACL_INT64, ACL_INT32, ACL_FLOAT, ACL_FLOAT16}; + if(supportedTypes.find(aclType) == supportedTypes.end()) + THROW_RUNTIME_ERROR("Unsupported ACL type for Add-scalar: " + std::to_string(aclType)); + + Shape arrayShape = {size}; + + ctranslate2::cann::CannPreparation prepare; + cann_prepare_inputdesc(prepare, aclType, arrayShape.size(), arrayShape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, arrayShape.size(), arrayShape.data(), ACL_FORMAT_ND); + + cann_prepare_inputbuffer(prepare, const_cast(x), sizeof(T)*size); + cann_prepare_outputbuffer(prepare, y, sizeof(T)*size); + + // 'value' must be a float according to the documentation + ACL_CALL(aclopSetAttrFloat(prepare._opAttr, "value", static_cast(a))); + + ACL_CALL(aclopCompileAndExecute("Adds", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + 
prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template<> + template + void primitives::add(const T* a, const T* b, T* c, dim_t size) { + ctranslate2::cann::CannPreparation prepare; + + const aclDataType aclType = cann::getACLType(); + Shape arrayShape = {size}; + + cann_prepare_inputdesc(prepare, aclType, arrayShape.size(), arrayShape.data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, aclType, arrayShape.size(), arrayShape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, arrayShape.size(), arrayShape.data(), ACL_FORMAT_ND); + + cann_prepare_inputbuffer(prepare, const_cast(a), sizeof(T)*size); + cann_prepare_inputbuffer(prepare, const_cast(b), sizeof(T)*size); + cann_prepare_outputbuffer(prepare, c, sizeof(T)*size); + + ACL_CALL(aclopCompileAndExecute("Add", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template<> + template + void primitives::add_batch_broadcast(const T* a, const T* b, T* c, + dim_t a_size, dim_t b_size, bool synchronous) { + const aclDataType aclType = cann::getACLType(); + + // According to the documentation, 'bias' should have length that is equal to the last dimension of 'value' + Shape bias_shape = {a_size}, value_shape = {b_size/a_size, a_size}; + + ctranslate2::cann::CannPreparation prepare; + + cann_prepare_inputdesc(prepare, aclType, value_shape.size(), value_shape.data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, aclType, bias_shape.size(), bias_shape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, 
value_shape.size(), value_shape.data(), ACL_FORMAT_ND); + + cann_prepare_inputbuffer(prepare, const_cast(b), sizeof(T)*b_size); + cann_prepare_inputbuffer(prepare, const_cast(a), sizeof(T)*a_size); + cann_prepare_outputbuffer(prepare, c, sizeof(T)*b_size); + + // We skipped the 'data_format' optional attribute, because it defaults to "NHWC" + + ACL_CALL(aclopCompileAndExecute("BiasAdd", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + if (synchronous){ + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + } + + template<> + template + void primitives::add_depth_broadcast(const T* a, const T* b, T* c, + dim_t a_size, dim_t b_size) { + cann::CannPreparation prepare; + const aclDataType aclType = cann::getACLType(); + const aclFormat format = ACL_FORMAT_ND; + + // 'bias' should have length that is equal to the first dimension of 'value'. + // 'value' can have any shape, but we can safely assume that it's a 2-D vector where the 1st dimension matches + // the length of 'bias'. + Shape bias_shape = {a_size}, value_shape = {a_size, b_size/a_size}; + + cann_prepare_inputdesc(prepare, aclType, value_shape.size(), value_shape.data(), format); + cann_prepare_inputdesc(prepare, aclType, bias_shape.size(), bias_shape.data(), format); + cann_prepare_outputdesc(prepare, aclType, value_shape.size(), value_shape.data(), format); + + // Instruct the "Bias" operator to match the 1st axis of the 'values' vector with the 'bias' vector. 
+ ACL_CALL(aclopSetAttrInt(prepare._opAttr, "axis", 0)); + + const dim_t value_size = b_size*sizeof(T); + cann_prepare_inputbuffer(prepare, const_cast(b), value_size); + cann_prepare_inputbuffer(prepare, const_cast(a), a_size*sizeof(T)); + cann_prepare_outputbuffer(prepare, c, value_size); + + ACL_CALL(aclopCompileAndExecute("Bias", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template<> + template + void primitives::sub(const T* a, const T* b, T* c, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::min(T a, const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::min(const T* a, const T* b, T* c, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::max(T a, const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::max(const T* a, const T* b, T* c, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::mul(T a, const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("CANN case is handled in StorageView level"); + } + + template<> + template + void primitives::mul(const T* a, const T* b, T* c, dim_t size) { + THROW_RUNTIME_ERROR("CANN case is handled in StorageView level"); + } + + template<> + template + void primitives::mul_batch_broadcast(const T* a, const T* b, T* c, + dim_t a_size, dim_t b_size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::relu(const T* x, T* y, dim_t size) { + 
cann::CannPreparation prepare; + + // Assume the shape as if the tensor was one-dimensional + const ctranslate2::Shape shape_1d = {size}; + const aclDataType aclType = cann::getACLType(); + + if(aclType == ACL_BF16) { + THROW_RUNTIME_ERROR("Unsupported ACL type: " + std::to_string(aclType)); + } + + aclFormat format = ACL_FORMAT_ND; + const dim_t size_in_bytes = size*sizeof(T); + + cann_prepare_inputdesc(prepare, aclType, shape_1d.size(), shape_1d.data(), format); + cann_prepare_outputdesc(prepare, aclType, shape_1d.size(), shape_1d.data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(x), size_in_bytes); + cann_prepare_outputbuffer(prepare, y, size_in_bytes); + + ACL_CALL(aclopCompileAndExecute("Relu", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template void primitives::relu(const float*, float*, dim_t); + template void primitives::relu(const float16_t*, float16_t*, dim_t); + template void primitives::relu(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template + void primitives::gelu(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::gelu(const float*, float*, dim_t); + template void primitives::gelu(const float16_t*, float16_t*, dim_t); + template void primitives::gelu(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template + void primitives::gelu_tanh(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::gelu_tanh(const float*, float*, dim_t); + template void primitives::gelu_tanh(const float16_t*, float16_t*, dim_t); + template void primitives::gelu_tanh(const bfloat16_t*, bfloat16_t*, dim_t); + + 
template<> + template + void primitives::gelu_sigmoid(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::gelu_sigmoid(const float*, float*, dim_t); + template void primitives::gelu_sigmoid(const float16_t*, float16_t*, dim_t); + template void primitives::gelu_sigmoid(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template + void primitives::swish(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::swish(const float*, float*, dim_t); + template void primitives::swish(const float16_t*, float16_t*, dim_t); + template void primitives::swish(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template + void primitives::penalize_previous_tokens(T* scores, + const T* previous_scores, + const int32_t* previous_ids, + T penalty, + dim_t batch_size, + dim_t length, + dim_t vocabulary_size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + /** + * Broadcasts vector 'x' of length 'x_length' across 'num_rows' rows and produces a vector 'y' of shape {num_rows, x_length}. + * This is achieved through the "BroadcastTo" CANN operator with: + * Inputs: + * x: A tensor. + * shape: A 1D tensor of type int32, for the shape of the desired output. + * Outputs: + * y: A tensor of shape 'shape' and type same as 'x'. 
+ */ + void broadcast_to(int32_t* x, int32_t num_rows, int32_t x_length, int32_t* y) { + cann::CannPreparation prepare; + const Shape shape_of_x = {x_length}; + const Shape shape_of_y = {num_rows, x_length}; + const Shape shape_of_shape = {static_cast(shape_of_y.size())}; + cann_prepare_inputdesc(prepare, ACL_INT32, shape_of_x.size(), shape_of_x.data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, ACL_INT32, shape_of_shape.size(), shape_of_shape.data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, ACL_INT32, shape_of_y.size(), shape_of_y.data(), ACL_FORMAT_ND); + + std::vector shape = {num_rows, x_length}; + dim_t shape_size_in_bytes = shape.size()*sizeof(int32_t); + dim_t x_size_in_bytes = x_length*sizeof(int32_t); + void *x_dev_ptr = nullptr, *shape_dev_ptr = nullptr; + // Temporary way to allocate memory on the NPU for inputs 'x' and 'shape' + ACL_CALL(aclrtMalloc(&x_dev_ptr, x_size_in_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CALL(aclrtMemcpyAsync(x_dev_ptr, x_size_in_bytes, x, x_size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE, cann::get_aclrt_stream())); + ACL_CALL(aclrtMalloc(&shape_dev_ptr, shape_size_in_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CALL(aclrtMemcpyAsync(shape_dev_ptr, shape_size_in_bytes, shape.data(), shape_size_in_bytes, + ACL_MEMCPY_HOST_TO_DEVICE, cann::get_aclrt_stream())); + + cann_prepare_inputbuffer(prepare, x_dev_ptr, x_size_in_bytes); + cann_prepare_inputbuffer(prepare, shape_dev_ptr, shape_size_in_bytes); + cann_prepare_outputbuffer(prepare, y, num_rows*x_size_in_bytes); + + ACL_CALL(aclopCompileAndExecute("BroadcastTo", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + + // Temporary way to free allocated memory for 'x' and 'shape'. 
+ ACL_CALL(aclrtFree(x_dev_ptr)); + ACL_CALL(aclrtFree(shape_dev_ptr)); + } + + int32_t* prepare_vector_to_broadcast(std::vector &x, const dim_t size, const int32_t max_value) { + x.resize(size); + for (int32_t j = 0; j < size; ++j) { + x[j] = std::min(max_value, j+1); + } + return x.data(); + } + + template<> + void primitives::prepare_length_mask(const int32_t* lengths, + dim_t batch_size, + dim_t num_heads, + dim_t num_queries, + bool mask_future, + bool multi_query, + int32_t* mask) { + std::vector lengths_cpu(batch_size); + cross_device_primitives::copy(lengths, lengths_cpu.data(), batch_size); + std::vector x; + const uint64_t num_heads_and_queries = num_heads*num_queries; + for (dim_t b = 0; b < batch_size; ++b) { // Iterate over the 1st dimension (batches) of the output 'mask' + const int32_t length = lengths_cpu[b]; + int32_t* batch_mask = mask + b * num_heads_and_queries; + if (mask_future) { + if (multi_query) { // Shape of output 'mask' is: {batch_size, num_queries, num_heads} + int32_t *row_ptr; + for (dim_t i = 0; i < num_queries; ++i) { // Iterate over the 2nd dimension of the output 'mask' + row_ptr = batch_mask + i*num_heads; + // Fill each row of the current batch with value: min(length, i+1) + primitives::fill(row_ptr, std::min(length, int32_t(i+1)), num_heads); + } + } else { // Shape of output 'mask' is: {batch_size, num_heads, num_queries} + // Create a 1-D vector: {1, 2, ..., min(length, num_queries-1), min(length, num_queries)}. + prepare_vector_to_broadcast(x, num_queries, length); + // Broadcast the vector across all the rows of the current batch. 
+ broadcast_to(x.data(), num_heads, num_queries, batch_mask); + } + } else { + primitives::fill(batch_mask, length, num_heads_and_queries); + } + } + } + + template<> + template + void primitives::transpose_2d(const T* a, const dim_t* dims, T* b) { + THROW_RUNTIME_ERROR("CANN case is handled in StorageView level"); + } + + template<> + template + void primitives::transpose_3d(const T* a, + const dim_t* dims, + const dim_t* perm, + T* b) { + THROW_RUNTIME_ERROR("CANN case is handled in StorageView level"); + } + + template<> + template + void primitives::transpose_4d(const T* a, + const dim_t* dims, + const dim_t* perm, + T* b) { + THROW_RUNTIME_ERROR("CANN case is handled in StorageView level"); + } + + template + void run_gemm_alpha_beta_in_device(bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha_dev_ptr, const In* a, + const float* beta_dev_ptr, const In* b, + Out* c) { + aclFormat format = ACL_FORMAT_ND; + aclDataType aclIn = ctranslate2::cann::getACLType(), aclOut = ctranslate2::cann::getACLType(); + + ctranslate2::cann::CannPreparation prepare; + + ACL_CALL(aclopSetAttrBool(prepare._opAttr, "transpose_a", transpose_a)); + ACL_CALL(aclopSetAttrBool(prepare._opAttr, "transpose_b", transpose_b)); + + ctranslate2::Shape a_shape, b_shape; + // The "GEMM" CANN operator expects different shapes for the 'a' and 'b' input vectors, based on whether they are + // transpose or not. 
+ if (transpose_a) { + a_shape = {k, m}; + } else { + a_shape = {m, k}; + } + if (transpose_b) { + b_shape = {n, k}; + } else { + b_shape = {k, n}; + } + + const ctranslate2::Shape c_shape = {m, n}; + // 'alpha' and 'beta' should be 1-D vectors of size 1, according to CANN documentation + static const ctranslate2::Shape alpha_beta_shape = {1}; + + cann_prepare_inputdesc(prepare, aclIn, a_shape.size(), a_shape.data(), format); + cann_prepare_inputdesc(prepare, aclIn, b_shape.size(), b_shape.data(), format); + cann_prepare_inputdesc(prepare, aclOut, c_shape.size(), c_shape.data(), format); + cann_prepare_inputdesc(prepare, ACL_FLOAT, alpha_beta_shape.size(), alpha_beta_shape.data(), format); + cann_prepare_inputdesc(prepare, ACL_FLOAT, alpha_beta_shape.size(), alpha_beta_shape.data(), format); + cann_prepare_outputdesc(prepare, aclOut, c_shape.size(), c_shape.data(), format); + + const dim_t c_size_in_bytes = m*n*sizeof(Out); + cann_prepare_inputbuffer(prepare, const_cast(a), m*k*sizeof(In)); + cann_prepare_inputbuffer(prepare, const_cast(b), k*n*sizeof(In)); + cann_prepare_inputbuffer(prepare, c, c_size_in_bytes); + cann_prepare_inputbuffer(prepare, const_cast(alpha_dev_ptr), sizeof(float)); + cann_prepare_inputbuffer(prepare, const_cast(beta_dev_ptr), sizeof(float)); + cann_prepare_outputbuffer(prepare, c, c_size_in_bytes); + + ACL_CALL(aclopCompileAndExecute("GEMM", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + // Synchronizing the stream takes place in the StorageView level + } + + template<> + template<> + void primitives::gemm_alpha_beta_in_device(bool, bool, + bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha, + const float* a, dim_t, + const float* b, dim_t, + const float* beta, + float* c, 
dim_t, + const float*) { + run_gemm_alpha_beta_in_device(transpose_a, transpose_b, m, n, k, alpha, a, beta, b, c); + } + + template<> + template<> + void primitives::gemm_alpha_beta_in_device(bool, bool, + bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha, + const float16_t* a, dim_t, + const float16_t* b, dim_t, + const float* beta, + float16_t* c, dim_t, + const float16_t*) { + run_gemm_alpha_beta_in_device(transpose_a, transpose_b, m, n, k, alpha, a, beta, b, c); + } + + template<> + template<> + void primitives::gemm_alpha_beta_in_device(bool, bool, + bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha, + const bfloat16_t* a, dim_t, + const bfloat16_t* b, dim_t, + const float* beta, + bfloat16_t* c, dim_t, + const bfloat16_t*) { + THROW_RUNTIME_ERROR("FP16 GEMM is not supported by CANN"); + } + + template<> + template<> + void primitives::gemm_alpha_beta_in_device(bool, bool, + bool transpose_a, bool transpose_b, + dim_t m, dim_t n, dim_t k, + const float* alpha, + const int8_t* a, dim_t, + const int8_t* b, dim_t, + const float* beta, + int32_t* c, dim_t, + const int32_t*) { + run_gemm_alpha_beta_in_device(transpose_a, transpose_b, m, n, k, alpha, a, beta, b, c); + } + + template<> + template + float primitives::logsumexp(const T* x, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template float primitives::logsumexp(const float*, dim_t); + template float primitives::logsumexp(const float16_t*, dim_t); + template float primitives::logsumexp(const bfloat16_t*, dim_t); + + template<> + template + void primitives::sin(const T* x, T* y, dim_t size) { + const auto aclType = cann::getACLType(); + + aclFormat format = ACL_FORMAT_ND; + cann::CannPreparation prepare; + const ctranslate2::Shape x_y_shape = {size}; + const dim_t size_in_bytes = size * sizeof(T); + + cann_prepare_inputdesc(prepare, aclType, x_y_shape.size(), x_y_shape.data(), format); + 
cann_prepare_outputdesc(prepare, aclType, x_y_shape.size(), x_y_shape.data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(x), size_in_bytes); + cann_prepare_outputbuffer(prepare, y, size_in_bytes); + + ACL_CALL(aclopCompileAndExecute("Sin", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template void primitives::sin(const float*, float*, dim_t); + template void primitives::sin(const float16_t*, float16_t*, dim_t); + template<> + template<> + void primitives::sin(const bfloat16_t*, bfloat16_t*, dim_t) { + THROW_RUNTIME_ERROR("Unsupported ACL type: bfloat16_t"); + } + + template<> + template + void primitives::cos(const T* x, T* y, dim_t size) { + const auto aclType = cann::getACLType(); + aclFormat format = ACL_FORMAT_ND; + cann::CannPreparation prepare; + + const ctranslate2::Shape shape_1d = {size}; + cann_prepare_inputdesc(prepare, aclType, shape_1d.size(), shape_1d.data(), format); + cann_prepare_outputdesc(prepare, aclType, shape_1d.size(), shape_1d.data(), format); + + const auto in_out_size_in_bytes = size*sizeof(T); + cann_prepare_inputbuffer(prepare, const_cast(x), in_out_size_in_bytes); + cann_prepare_outputbuffer(prepare, y, in_out_size_in_bytes); + + ACL_CALL(aclopCompileAndExecute("Cos", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template void primitives::cos(const float*, float*, dim_t); + template void primitives::cos(const float16_t*, 
float16_t*, dim_t); + template<> + template<> + void primitives::cos(const bfloat16_t*, bfloat16_t*, dim_t) { + THROW_RUNTIME_ERROR("Unsupported ACL type: bfloat16_t"); + } + + + template<> + template + void primitives::tanh(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::tanh(const float*, float*, dim_t); + template void primitives::tanh(const float16_t*, float16_t*, dim_t); + template void primitives::tanh(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template<> + void primitives::exp(const float* x, float* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + template + void primitives::log(const T* x, T* y, dim_t size) { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void primitives::log(const float*, float*, dim_t); + template void primitives::log(const float16_t*, float16_t*, dim_t); + template void primitives::log(const bfloat16_t*, bfloat16_t*, dim_t); + + template<> + template + void cross_device_primitives::copy(const T* x, T* y, dim_t size) { + const auto size_in_bytes = size * sizeof (T); + ACL_CALL(aclrtMemcpy(y, size_in_bytes, x, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + } + + template<> + template + void cross_device_primitives::copy(const T* x, T* y, dim_t size) { + const auto size_in_bytes = size * sizeof (T); + ACL_CALL(aclrtMemcpy(y, size_in_bytes, x, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + } + +#define DECLARE_IMPL(T) \ + template T \ + primitives::at(const T* x, dim_t index); \ + template void \ + primitives::fill(T* x, T a, dim_t size); \ + template void \ + primitives::zero(T* x, dim_t size, bool synchronous); \ + template void \ + primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size); \ + template void \ + primitives::indexed_fill(T*, T, const int32_t*, dim_t, dim_t); \ + template void \ + primitives::copy(const T* x, T* y, dim_t size); \ + template T \ + primitives::sum(const T* array, dim_t 
size); \ + template dim_t \ + primitives::max_element(const T* array, dim_t size); \ + template T \ + primitives::max(const T* array, dim_t size); \ + template void \ + primitives::add(T a, const T* x, T* y, dim_t size); \ + template void \ + primitives::add(const T* a, const T* b, T* c, dim_t size); \ + template void \ + primitives::add_batch_broadcast(const T* a, const T* b, \ + T* c, dim_t a_size, \ + dim_t b_size, \ + bool synchronous); \ + template void \ + primitives::add_depth_broadcast(const T* a, const T* b, \ + T* c, dim_t a_size, dim_t b_size); \ + template void \ + primitives::sub(const T* a, const T* b, T* c, dim_t size); \ + template void \ + primitives::min(T a, const T* x, T* y, dim_t size); \ + template void \ + primitives::min(const T* a, const T* b, T* c, dim_t size); \ + template void \ + primitives::max(T a, const T* x, T* y, dim_t size); \ + template void \ + primitives::max(const T* a, const T* b, T* c, dim_t size); \ + template void \ + primitives::mul(T a, const T* x, T* y, dim_t size); \ + template void \ + primitives::mul(const T* a, const T* b, T* c, dim_t size); \ + template void \ + primitives::mul_batch_broadcast(const T* a, const T* b, \ + T* c, dim_t a_size, dim_t b_size); \ + template void \ + primitives::penalize_previous_tokens(T*, \ + const T*, \ + const int32_t*, \ + T, \ + dim_t, \ + dim_t, \ + dim_t); \ + template void \ + primitives::transpose_2d(const T* a, \ + const dim_t* dims, \ + T* b); \ + template void \ + primitives::transpose_3d(const T* a, \ + const dim_t* dims, \ + const dim_t* perm, \ + T* b); \ + template void \ + primitives::transpose_4d(const T* a, \ + const dim_t* dims, \ + const dim_t* perm, \ + T* b); \ + template void \ + cross_device_primitives::copy(const T*, T*, dim_t); \ + template void \ + cross_device_primitives::copy(const T*, T*, dim_t); + DECLARE_ALL_TYPES(DECLARE_IMPL) +} diff --git a/src/cann/utils.cc b/src/cann/utils.cc new file mode 100644 index 000000000..6ac8929e3 --- /dev/null +++ 
b/src/cann/utils.cc @@ -0,0 +1,156 @@ +#include "./utils.h" +#include "ctranslate2/devices.h" +#include "ctranslate2/types.h" +#include + +namespace ctranslate2 { + namespace cann { + CannPreparation::CannPreparation() { + _opAttr = aclopCreateAttr(); + if(_opAttr == nullptr) + THROW_RUNTIME_ERROR("aclopCreateAttr out of memory"); + } + + CannPreparation::~CannPreparation() { + for (auto desc : _inputDesc) + aclDestroyTensorDesc(desc); + + for (auto desc : _outputDesc) + aclDestroyTensorDesc(desc); + + try { + for (auto buf : _inputBuffers) + ACL_CALL(aclDestroyDataBuffer(buf)); + + for (auto buf : _outputBuffers) + ACL_CALL(aclDestroyDataBuffer(buf)); + } + catch (const std::exception& e) { + // Log that CannPreparation deallocation failed and swallow the exception + spdlog::error(e.what()); + } + aclopDestroyAttr(_opAttr); + } + + template + aclDataType getACLType() { + return ACL_DT_UNDEFINED; + } + + #define GET_ACL_TYPE(ctranslate2_type, cann_type) \ + template <> \ + aclDataType getACLType() { \ + return cann_type; \ + } + + GET_ACL_TYPE(int8_t, ACL_INT8); + GET_ACL_TYPE(int16_t, ACL_INT16); + GET_ACL_TYPE(int32_t, ACL_INT32); + GET_ACL_TYPE(int64_t, ACL_INT64); + GET_ACL_TYPE(uint8_t, ACL_UINT8); + GET_ACL_TYPE(uint16_t, ACL_UINT16); + GET_ACL_TYPE(uint32_t, ACL_UINT32); + GET_ACL_TYPE(uint64_t, ACL_UINT64); + GET_ACL_TYPE(float, ACL_FLOAT); + GET_ACL_TYPE(float16_t, ACL_FLOAT16); + GET_ACL_TYPE(bfloat16_t, ACL_BF16); + GET_ACL_TYPE(double, ACL_DOUBLE); + GET_ACL_TYPE(bool, ACL_BOOL); + } +} + +namespace ctranslate2 { + namespace cann { + void AclDeviceEnabler::acl_initialize() { + static std::once_flag initialize_flag; + std::call_once(initialize_flag, [](){ + spdlog::debug("aclInit"); + ACL_CALL(aclInit(nullptr)); + }); + } + + struct AclInitializer { + AclInitializer() { + AclDeviceEnabler::acl_initialize(); + } + }; + // Initializes AscendCL. It can be called only once per execution. + // aclInit must be called before the use of AscendCL APIs. 
+ const static AclInitializer aclInitializer; + + void AclDeviceEnabler::acl_finalize() { + if(!_finalize_enabled) + return; + + static std::once_flag finalize_flag; + std::call_once(finalize_flag, [](){ + try { + // Make sure all streams are destroyed before AscendCL deinitializing + AclrtStreamHandler::destroy_steams(); + spdlog::debug("aclFinalize"); + ACL_CALL(aclFinalize()); + } + catch (const std::exception& e) { + // acl_finalize is called in ReplicaPool dtor + // Log that deinitialization failed and swallow the exception + spdlog::error(e.what()); + } + }); + } + + void AclDeviceEnabler::set_allow_acl_finalize(const bool enable) { + _finalize_enabled = enable; + } + + void AclrtStreamHandler::store(const int32_t device, const aclrtStream stream) { + const std::lock_guard lock(_mutex); + _streams.emplace_back(device, stream); + } + + void AclrtStreamHandler::destroy_steams() { + const std::lock_guard lock(_mutex); + for(const auto& [device, stream] : _streams) { + ScopedDeviceSetter scoped_device_setter(Device::CANN, device); + // Synchronize stream to ensure that all tasks in the stream have completed before destroying it + ACL_CALL(aclrtSynchronizeStream(stream)); + ACL_CALL(aclrtDestroyStream(stream)); + } + } + + class AclrtStream { + public: + AclrtStream() { + ACL_CALL(aclrtGetDevice(&_device)); + ACL_CALL(aclrtCreateStream(&_stream)); + AclrtStreamHandler::store(_device, _stream); + } + + // Place the stream destruction responsibility to AclrtStreamHandler to ensure + // streams are destroyed just before AscendCL deinitialization + + aclrtStream get() const { + return _stream; + } + + private: + int32_t _device; + aclrtStream _stream; + }; + + // We create one aclrt handle per host thread. The handle is destroyed when the thread exits. 
+ aclrtStream get_aclrt_stream() { + static thread_local AclrtStream aclrt_stream; + return aclrt_stream.get(); + } + + uint32_t get_npu_count() { + uint32_t npu_count = 0; + ACL_CALL(aclrtGetDeviceCount(&npu_count)); + return npu_count; + } + + bool has_npu() { + return get_npu_count() > 0; + } + } +} diff --git a/src/cann/utils.h b/src/cann/utils.h new file mode 100644 index 000000000..cdcd9c496 --- /dev/null +++ b/src/cann/utils.h @@ -0,0 +1,127 @@ +#pragma once + +#include "./cann_inc.h" +#include "ctranslate2/utils.h" +#include + +namespace ctranslate2 { + namespace cann { +#define ACL_CALL(ans) \ + { \ + aclError code = (ans); \ + if (code != ACL_SUCCESS) \ + THROW_RUNTIME_ERROR("CANN failed with error " + std::to_string(code)); \ + } + } +} + +namespace ctranslate2 { + namespace cann { + struct CannPreparation { + CannPreparation(); + ~CannPreparation(); + + std::vector _inputBuffers; + std::vector _outputBuffers; + std::vector _inputDesc; + std::vector _outputDesc; + aclopAttr* _opAttr; + }; + + template + inline void cann_prepare_inputdesc(CannPreparation& prepare, Args... args) { + auto _rPtr = aclCreateTensorDesc(args...); + if (_rPtr == nullptr) + THROW_RUNTIME_ERROR("aclCreateTensorDesc run failed"); + else + prepare._inputDesc.emplace_back(_rPtr); + } + + template + inline void cann_prepare_outputdesc(CannPreparation& prepare, Args... args) { + auto _rPtr = aclCreateTensorDesc(args...); + if (_rPtr == nullptr) + THROW_RUNTIME_ERROR("aclCreateTensorDesc run failed"); + else + prepare._outputDesc.emplace_back(_rPtr); + } + + template + inline void cann_prepare_inputbuffer(CannPreparation& prepare, Args... args) { + auto _rPtr = aclCreateDataBuffer(args...); + if (_rPtr == nullptr) + THROW_RUNTIME_ERROR("aclCreateDataBuffer run failed"); + else + prepare._inputBuffers.emplace_back(_rPtr); + } + + template + inline void cann_prepare_outputbuffer(CannPreparation& prepare, Args... 
args) { + auto _rPtr = aclCreateDataBuffer(args...); + if (_rPtr == nullptr) + THROW_RUNTIME_ERROR("aclCreateDataBuffer run failed"); + else + prepare._outputBuffers.emplace_back(_rPtr); + } + + template + inline void cann_const_inputdesc(CannPreparation& prepare, size_t index, Args... args) { + auto _rPtr = aclSetTensorConst(prepare._inputDesc[index], args...); + if (_rPtr != ACL_SUCCESS) + THROW_RUNTIME_ERROR("aclSetTensorConst run failed"); + } + + inline void cann_prepare_inputdescname(CannPreparation& prepare, size_t index, const char* name) { + aclSetTensorDescName(prepare._inputDesc[index], name); + } + + inline void cann_prepare_outputdescname(CannPreparation& prepare, size_t index, const char* name) { + aclSetTensorDescName(prepare._outputDesc[index], name); + } + + inline void cann_tensor_placement(CannPreparation& prepare, size_t index, aclMemType memType) { + auto _rPtr = aclSetTensorPlaceMent(prepare._inputDesc[index], memType); + if (_rPtr != ACL_SUCCESS) + THROW_RUNTIME_ERROR("aclSetTensorPlaceMent run failed"); + } + + template + aclDataType getACLType(); + } +} + +namespace ctranslate2 { + namespace cann { + class AclDeviceEnabler { + public: + static void acl_initialize(); + static void acl_finalize(); + static void set_allow_acl_finalize(bool enable); + + AclDeviceEnabler() = delete; + private: + static inline bool _finalize_enabled = true; // False value only during testing + }; + } +} + +namespace ctranslate2 { + namespace cann { + aclrtStream get_aclrt_stream(); + uint32_t get_npu_count(); + bool has_npu(); + } +} + +namespace ctranslate2 { + namespace cann { + class AclrtStreamHandler { + public: + static void store(int32_t device, aclrtStream stream); + static void destroy_steams(); + private: + inline static std::mutex _mutex; + inline static std::vector> _streams; + }; + } +} diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc index 49bee5b7b..827567b4c 100644 --- a/src/cpu/primitives.cc +++ b/src/cpu/primitives.cc @@ -47,6 +47,12 @@ 
namespace ctranslate2 { std::fill(x, x + size, a); } + template<> + template + void primitives::zero(T* x, dim_t size, bool) { + std::fill(x, x + size, 0); + } + template<> template void primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size) { @@ -57,7 +63,7 @@ namespace ctranslate2 { template<> template - void primitives::indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices) { + void primitives::indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices, dim_t) { for (dim_t i = 0; i < num_indices; ++i) x[indices[i]] = a; } @@ -148,7 +154,7 @@ namespace ctranslate2 { template<> template void primitives::add_batch_broadcast(const T* a, const T* b, T* c, - dim_t a_size, dim_t b_size) { + dim_t a_size, dim_t b_size, bool) { const dim_t iter_size = b_size / a_size; cpu::parallel_for(0, iter_size, 1, [&](dim_t begin, dim_t end) { for (dim_t i = begin; i < end; ++i) { @@ -1079,16 +1085,17 @@ namespace ctranslate2 { } } - #define DECLARE_IMPL(T) \ template T \ primitives::at(const T* x, dim_t index); \ template void \ primitives::fill(T* x, T a, dim_t size); \ template void \ + primitives::zero(T* x, dim_t size, bool); \ + template void \ primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size); \ template void \ - primitives::indexed_fill(T*, T, const int32_t*, dim_t); \ + primitives::indexed_fill(T*, T, const int32_t*, dim_t, dim_t); \ template void \ primitives::copy(const T* x, T* y, dim_t size); \ template T \ @@ -1101,7 +1108,7 @@ namespace ctranslate2 { primitives::add(T a, const T* x, T* y, dim_t size); \ template void \ primitives::add_batch_broadcast(const T* a, const T* b, T* c, \ - dim_t a_size, dim_t b_size); \ + dim_t a_size, dim_t b_size, bool); \ template void \ primitives::add_depth_broadcast(const T* a, const T* b, T* c, \ dim_t a_size, dim_t b_size); \ diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu index 149e10dbb..637148e88 100644 --- a/src/cuda/primitives.cu +++ b/src/cuda/primitives.cu @@ -23,6 +23,12 @@ 
namespace ctranslate2 { THRUST_CALL(thrust::fill, x, x + size, a); } + template<> + template + void primitives::zero(T* x, dim_t size, bool) { + THRUST_CALL(thrust::fill, x, x + size, T(0)); + } + template<> template void primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size) { @@ -34,7 +40,7 @@ namespace ctranslate2 { template<> template - void primitives::indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices) { + void primitives::indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices, dim_t) { auto element_it = thrust::device_pointer_cast(cuda::device_cast(x)); auto index_it = thrust::device_pointer_cast(indices); auto it = thrust::make_permutation_iterator(element_it, index_it); @@ -125,7 +131,7 @@ namespace ctranslate2 { template<> template void primitives::add_batch_broadcast(const T* a, const T* b, T* c, - dim_t a_size, dim_t b_size) { + dim_t a_size, dim_t b_size, bool) { cuda::binary_transform(a, b, c, b_size, cuda::plus>(), cuda::repeat_vec(a_size)); @@ -716,9 +722,11 @@ namespace ctranslate2 { template void \ primitives::fill(T* x, T a, dim_t size); \ template void \ + primitives::zero(T* x, dim_t size, bool); \ + template void \ primitives::strided_fill(T* x, T a, dim_t inc_x, dim_t size); \ template void \ - primitives::indexed_fill(T*, T, const int32_t*, dim_t); \ + primitives::indexed_fill(T*, T, const int32_t*, dim_t, dim_t); \ template void \ primitives::copy(const T* x, T* y, dim_t size); \ template T \ @@ -733,7 +741,8 @@ namespace ctranslate2 { primitives::add(const T* a, const T* b, T* c, dim_t size); \ template void \ primitives::add_batch_broadcast(const T* a, const T* b, \ - T* c, dim_t a_size, dim_t b_size); \ + T* c, dim_t a_size, dim_t b_size, \ + bool); \ template void \ primitives::add_depth_broadcast(const T* a, const T* b, \ T* c, dim_t a_size, dim_t b_size); \ diff --git a/src/decoding_utils.cc b/src/decoding_utils.cc index fed4670d3..ccd89cb98 100644 --- a/src/decoding_utils.cc +++ b/src/decoding_utils.cc 
@@ -29,7 +29,8 @@ namespace ctranslate2 { primitives::indexed_fill(_logits.data(), static_cast(_disable_value), flat_indices.data(), - num_indices)); + num_indices, + _logits.size())); _flat_indices.clear(); } diff --git a/src/device_dispatch.h b/src/device_dispatch.h index 3106c7d55..48a1b7b66 100644 --- a/src/device_dispatch.h +++ b/src/device_dispatch.h @@ -18,16 +18,26 @@ } #define SINGLE_ARG(...) __VA_ARGS__ -#ifndef CT2_WITH_CUDA + +#ifdef CT2_WITH_CUDA +# define DEVICE_DISPATCH(DEVICE, STMTS) \ + switch (DEVICE) { \ + UNSUPPORTED_DEVICE_CASE(Device::CANN) \ + DEVICE_CASE(Device::CUDA, SINGLE_ARG(STMTS)) \ + DEVICE_CASE(Device::CPU, SINGLE_ARG(STMTS)) \ + } +#elif CT2_WITH_CANN # define DEVICE_DISPATCH(DEVICE, STMTS) \ switch (DEVICE) { \ UNSUPPORTED_DEVICE_CASE(Device::CUDA) \ + DEVICE_CASE(Device::CANN, SINGLE_ARG(STMTS)) \ DEVICE_CASE(Device::CPU, SINGLE_ARG(STMTS)) \ } #else # define DEVICE_DISPATCH(DEVICE, STMTS) \ switch (DEVICE) { \ - DEVICE_CASE(Device::CUDA, SINGLE_ARG(STMTS)) \ + UNSUPPORTED_DEVICE_CASE(Device::CUDA) \ + UNSUPPORTED_DEVICE_CASE(Device::CANN) \ DEVICE_CASE(Device::CPU, SINGLE_ARG(STMTS)) \ } #endif diff --git a/src/devices.cc b/src/devices.cc index 3822cc3c3..bc16d49ea 100644 --- a/src/devices.cc +++ b/src/devices.cc @@ -2,6 +2,8 @@ #ifdef CT2_WITH_CUDA # include "cuda/utils.h" +#elif CT2_WITH_CANN +# include "cann/utils.h" #endif #include "device_dispatch.h" @@ -9,6 +11,12 @@ namespace ctranslate2 { Device str_to_device(const std::string& device) { + if (device == "cann" || device == "CANN") +#ifdef CT2_WITH_CANN + return Device::CANN; +#else + throw std::invalid_argument("This CTranslate2 package was not compiled with CANN support"); +#endif if (device == "cuda" || device == "CUDA") #ifdef CT2_WITH_CUDA return Device::CUDA; @@ -20,6 +28,8 @@ namespace ctranslate2 { if (device == "auto" || device == "AUTO") #ifdef CT2_WITH_CUDA return cuda::has_gpu() ? Device::CUDA : Device::CPU; +#elif CT2_WITH_CANN + return cann::has_npu() ? 
Device::CANN : Device::CPU; #else return Device::CPU; #endif @@ -32,6 +42,8 @@ namespace ctranslate2 { return "cuda"; case Device::CPU: return "cpu"; + case Device::CANN: + return "cann"; } return ""; } @@ -50,6 +62,12 @@ namespace ctranslate2 { #endif case Device::CPU: return 1; + case Device::CANN: +#ifdef CT2_WITH_CANN + return cann::get_npu_count(); +#else + return 0; +#endif } return 0; } @@ -66,11 +84,23 @@ namespace ctranslate2 { template<> void set_device_index(int index) { - if (index != 0) - throw std::invalid_argument("Invalid CPU device index: " + std::to_string(index)); + if (index != 0) { + throw std::invalid_argument("Invalid CPU device index: " + std::to_string(index)); + } + } +#ifdef CT2_WITH_CANN + template<> + int get_device_index() { + int index = 0; + ACL_CALL(aclrtGetDevice(&index)); + return index; } -#ifdef CT2_WITH_CUDA + template<> + void set_device_index(int index) { + ACL_CALL(aclrtSetDevice(index)); + } +#elif CT2_WITH_CUDA template<> int get_device_index() { int index = 0; @@ -100,17 +130,42 @@ namespace ctranslate2 { const ScopedDeviceSetter scoped_device_setter(device, index); cudaDeviceSynchronize(); } +#elif CT2_WITH_CANN + if (device == Device::CANN) { + const ScopedDeviceSetter scoped_device_setter(device, index); + ACL_CALL(aclrtSynchronizeDevice()); + } #else (void)device; (void)index; #endif } + void initialize_device() { +#ifdef CT2_WITH_CANN + // Initializes AscendCL. It can be called only once per execution. + // aclInit must be called before the use of AscendCL APIs. + cann::AclDeviceEnabler::acl_initialize(); +#endif + } + + void finalize_device() { +#ifdef CT2_WITH_CANN + // This API needs to be called explicitly to deinitialize AscendCL + // after all NPU tasks have completed and before the app process exits. 
+ cann::AclDeviceEnabler::acl_finalize(); +#endif + } + void synchronize_stream(Device device) { #ifdef CT2_WITH_CUDA if (device == Device::CUDA) { cudaStreamSynchronize(cuda::get_cuda_stream()); } +#elif CT2_WITH_CANN + if (device == Device::CANN) { + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } #else (void)device; #endif diff --git a/src/dispatch.h b/src/dispatch.h index 3eed748c3..17ad406a8 100644 --- a/src/dispatch.h +++ b/src/dispatch.h @@ -11,17 +11,7 @@ default: \ throw std::invalid_argument(NAME " only supports float types"); \ - -#ifndef CT2_WITH_CUDA - -# define DEVICE_AND_FLOAT_DISPATCH(NAME, DEVICE, TYPE, STMTS) \ - switch (TYPE) { \ - TYPE_CASE(float, DEVICE_DISPATCH(DEVICE, (STMTS))) \ - NON_FLOAT_CASE(NAME) \ - } - -#else - +#ifdef CT2_WITH_CUDA # define DEVICE_AND_FLOAT_DISPATCH(NAME, DEVICE, TYPE, STMTS) \ switch (TYPE) { \ TYPE_CASE(float, DEVICE_DISPATCH(DEVICE, (STMTS))) \ @@ -39,5 +29,29 @@ }) \ NON_FLOAT_CASE(NAME) \ } - +#elif CT2_WITH_CANN +# define DEVICE_AND_FLOAT_DISPATCH(NAME, DEVICE, TYPE, STMTS) \ + switch (TYPE) { \ + TYPE_CASE(float, DEVICE_DISPATCH(DEVICE, (STMTS))) \ + TYPE_CASE(float16_t, { \ + if (DEVICE != Device::CANN) \ + throw std::invalid_argument("FP16 " NAME " is only supported on NPU"); \ + constexpr Device D = Device::CANN; \ + (STMTS); \ + }) \ + TYPE_CASE(bfloat16_t, { \ + if (DEVICE != Device::CANN) \ + throw std::invalid_argument("BF16 " NAME " is only supported on NPU"); \ + constexpr Device D = Device::CANN; \ + (STMTS); \ + }) \ + NON_FLOAT_CASE(NAME) \ + } +#else +# define DEVICE_AND_FLOAT_DISPATCH(NAME, DEVICE, TYPE, STMTS) \ + switch (TYPE) { \ + TYPE_CASE(float, DEVICE_DISPATCH(DEVICE, (STMTS))) \ + NON_FLOAT_CASE(NAME) \ + } #endif + diff --git a/src/layers/decoder.cc b/src/layers/decoder.cc index 046c581f5..cf0d34e13 100644 --- a/src/layers/decoder.cc +++ b/src/layers/decoder.cc @@ -113,9 +113,7 @@ namespace ctranslate2 { extra_bias = std::make_unique(Shape{new_output_size}, output_type(), 
_device); DEVICE_AND_TYPE_DISPATCH( _device, output_type(), - primitives::fill(extra_bias->data(), - T(0), - new_output_size - padding_size)); + primitives::zero(extra_bias->data(), new_output_size - padding_size)); DEVICE_AND_TYPE_DISPATCH( _device, output_type(), primitives::fill(extra_bias->data() + new_output_size - padding_size, diff --git a/src/models/model.cc b/src/models/model.cc index 0672494ff..b045567dd 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -422,6 +422,10 @@ namespace ctranslate2 { } { +#if CT2_WITH_CANN + // set_device_index has to always be called at least once before get_device_index + set_device_index(device, device_index); +#endif // Check that the device and device index are valid. ScopedDeviceSetter(device, device_index); } diff --git a/src/ops/alibi_add_npu.cc b/src/ops/alibi_add_npu.cc new file mode 100644 index 000000000..ff5133fa1 --- /dev/null +++ b/src/ops/alibi_add_npu.cc @@ -0,0 +1,28 @@ +#include "ctranslate2/ops/alibi_add.h" + +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void AlibiAdd::compute(const StorageView& input, + const StorageView& alibi, + const dim_t alibi_offset, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + AlibiAdd::compute(const StorageView& input, \ + const StorageView& alibi, \ + const dim_t alibi_offset, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/bias_add_npu.cc b/src/ops/bias_add_npu.cc new file mode 100644 index 000000000..e9cb74daa --- /dev/null +++ b/src/ops/bias_add_npu.cc @@ -0,0 +1,33 @@ +#include "ctranslate2/ops/bias_add.h" + +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void BiasAdd::compute(const StorageView& value, + const StorageView& bias, + StorageView& output) const { + primitives::add_batch_broadcast(bias.data(), + 
value.data(), + output.data(), + bias.size(), + value.size(), + _activation_type == nullptr); // if no activation, then synchronize stream here + if (_activation_type) + get_activation_op(*_activation_type)(output, output); + } + +#define DECLARE_IMPL(T) \ + template void \ + BiasAdd::compute(const StorageView& value, \ + const StorageView& bias, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/concat_split_slide_npu.cc b/src/ops/concat_split_slide_npu.cc new file mode 100644 index 000000000..0631795e6 --- /dev/null +++ b/src/ops/concat_split_slide_npu.cc @@ -0,0 +1,181 @@ +#include "ctranslate2/ops/concat.h" +#include "ctranslate2/ops/split.h" +#include "ctranslate2/ops/slide.h" +#include "type_dispatch.h" +#include "../cann/utils.h" + +namespace ctranslate2 { + namespace ops { + + template + void Concat::compute(const std::vector& inputs, + StorageView& output) const { + // Tensors' descriptors have to be set in the order mentioned in the documentation. + // For operators whose input is a list (such as concat) it is needed to set the name of each input tensor in the + // same order of the tensor set. 
+ + // prepare types + using axis_type = decltype(_axis); + static_assert(std::is_same_v || std::is_same_v); + const auto axis_acl_type = cann::getACLType(); + const auto in_out_acl_type = cann::getACLType(); + + aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + // input axis + constexpr char const* axis_label = "concat_dim"; + cann_prepare_inputdesc(prepare, axis_acl_type, 0, nullptr, format); // handle axis as scalar + cann_const_inputdesc(prepare, 0, const_cast(&_axis), sizeof(axis_type)); // axis has to be set to const + // axis_label is the first in the list of the descriptor names + constexpr short axis_label_index = 0; + cann_prepare_inputdescname(prepare, axis_label_index, axis_label); + cann_prepare_inputbuffer(prepare, const_cast(&_axis), sizeof(axis_type)); + + // input tensors + static const std::string desc_prefix = "x"; + for(size_t i=0; ishape().size(), inputs[i]->shape().data(), format); + const auto descriptor_label = desc_prefix + std::to_string(i); + cann_prepare_inputdescname(prepare, i + 1, descriptor_label.c_str()); // first element is already populated by axis_label + cann_prepare_inputbuffer(prepare, const_cast(inputs[i]->data()), inputs[i]->size_in_bytes()); + } + + // output + cann_prepare_outputdesc(prepare, in_out_acl_type, output.shape().size(), output.shape().data(), format); + cann_prepare_outputbuffer(prepare, output.data(), output.size_in_bytes()); + + // attribute is optional in Concat + // ACL_CALL(aclopSetAttrInt(prepare.opAttr_, "N", inputs.size())); + + ACL_CALL(aclopCompileAndExecute("Concat", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + /** + * Creates the input descriptors and buffers that the CANN "Split" 
operator needs to run correctly. + */ + template + void prepare_split_inputs(const StorageView& input, + const int32_t axis, + ctranslate2::cann::CannPreparation& prepare) { + const aclFormat format = ACL_FORMAT_ND; + + // input: split_dim. The CANN documentation for the "Split" operator specifies that 'split_dim' should be passed + // after 'x', but in reality it should be first. 'split_dim' is a scalar, so according to the documentation we + // need to specify its number of dimensions as 0. + cann_prepare_inputdesc(prepare, ACL_INT32, 0, nullptr, format); + + // input: x + cann_prepare_inputdesc(prepare, cann::getACLType(), input.shape().size(), input.shape().data(), format); + + auto split_dim_sv = StorageView(axis, Device::CANN); + cann_prepare_inputbuffer(prepare, split_dim_sv.data(), sizeof(int32_t)); + cann_prepare_inputbuffer(prepare, const_cast(input.data()), input.size_in_bytes()); + } + + /** + * Creates the input descriptors and buffers that the CANN "SplitV" operator needs to run correctly. + */ + template + void prepare_splitv_inputs(const StorageView& input, + const int32_t axis, + const std::vector& size_splits, + ctranslate2::cann::CannPreparation& prepare) { + static_assert(std::is_same_v); + const aclFormat format = ACL_FORMAT_ND; + + // input: x + cann_prepare_inputdesc(prepare, cann::getACLType(), input.shape().size(), input.shape().data(), format); + + // input: size_splits + const Shape size_splits_shape = {static_cast(size_splits.size())}; + cann_prepare_inputdesc(prepare, ACL_INT64, size_splits_shape.size(), size_splits_shape.data(), format); + + // input: split_dim. This is a scalar, so according to the documentation we need to specify its number of + // dimensions as 0. 
+ cann_prepare_inputdesc(prepare, ACL_INT32, 0, nullptr, format); + + cann_prepare_inputbuffer(prepare, const_cast(input.data()), input.size_in_bytes()); + cann_prepare_inputbuffer(prepare, const_cast(size_splits.data()), size_splits.size()*sizeof(dim_t)); + auto split_dim_sv = StorageView(axis, Device::CANN); + cann_prepare_inputbuffer(prepare, split_dim_sv.data(), sizeof(int32_t)); + } + + template + void Split::compute(const StorageView& input, + std::vector& outputs) const { + ctranslate2::cann::CannPreparation prepare; + const int32_t axis = _axis < 0 ? input.rank() + _axis : _axis; + std::string op_name; + + if (_split.empty()) { + op_name = "Split"; + prepare_split_inputs(input, axis, prepare); + } else { + op_name = "SplitV"; + prepare_splitv_inputs(input, axis, _split, prepare); + } + + ACL_CALL(aclopSetAttrInt(prepare._opAttr, "num_split", outputs.size())); + + // output: y + const std::string desc_prefix = "y"; + std::string descriptor_label; + const aclFormat format = ACL_FORMAT_ND; + const aclDataType aclType = cann::getACLType(); + for(size_t i=0; ishape().size(), outputs[i]->shape().data(), format); + descriptor_label = desc_prefix + std::to_string(i); + cann_prepare_outputdescname(prepare, i, descriptor_label.c_str()); + cann_prepare_outputbuffer(prepare, outputs[i]->data(), outputs[i]->size_in_bytes()); + } + + ACL_CALL(aclopCompileAndExecute(op_name.c_str(), + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + + template + void Slide::compute(const StorageView& input, StorageView& output, const dim_t& index) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Concat::compute(const std::vector& inputs, \ 
+ StorageView& output) const; \ + template void \ + Split::compute(const StorageView& input, \ + std::vector& outputs) const; \ + template void \ + Slide::compute(const StorageView& input, \ + StorageView& output, \ + const dim_t& index) const; + + DECLARE_ALL_TYPES(DECLARE_IMPL) + + } +} diff --git a/src/ops/conv1d_npu.cc b/src/ops/conv1d_npu.cc new file mode 100644 index 000000000..c1cb3ec0e --- /dev/null +++ b/src/ops/conv1d_npu.cc @@ -0,0 +1,26 @@ +#include "ctranslate2/ops/conv1d.h" + +namespace ctranslate2 { + namespace ops { + + template + void Conv1D::compute(const StorageView& input, + const StorageView& weight, + const StorageView* bias, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Conv1D::compute(const StorageView& input, \ + const StorageView& weight, \ + const StorageView* bias, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/dequantize_npu.cc b/src/ops/dequantize_npu.cc new file mode 100644 index 000000000..5e5012911 --- /dev/null +++ b/src/ops/dequantize_npu.cc @@ -0,0 +1,63 @@ +#include "ctranslate2/ops/dequantize.h" + +namespace ctranslate2 { + namespace ops { + + template + void Dequantize::dequantize(const StorageView& input, + const StorageView& scale, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void + Dequantize::dequantize(const StorageView&, + const StorageView&, + StorageView&) const; + template void + Dequantize::dequantize(const StorageView&, + const StorageView&, + StorageView&) const; + template void + Dequantize::dequantize(const StorageView&, + const StorageView&, + StorageView&) const; + + template + void Dequantize::dequantize_gemm_output(const StorageView& c, + const StorageView& a_scale, + const StorageView& b_scale, + const bool transpose_a, + const bool transpose_b, + const StorageView* bias, + 
StorageView& y) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template void + Dequantize::dequantize_gemm_output(const StorageView&, + const StorageView&, + const StorageView&, + const bool, + const bool, + const StorageView*, + StorageView&) const; + template void + Dequantize::dequantize_gemm_output(const StorageView&, + const StorageView&, + const StorageView&, + const bool, + const bool, + const StorageView*, + StorageView&) const; + template void + Dequantize::dequantize_gemm_output(const StorageView&, + const StorageView&, + const StorageView&, + const bool, + const bool, + const StorageView*, + StorageView&) const; + + } +} diff --git a/src/ops/gather_npu.cc b/src/ops/gather_npu.cc new file mode 100644 index 000000000..746f13fe2 --- /dev/null +++ b/src/ops/gather_npu.cc @@ -0,0 +1,65 @@ +#include "ctranslate2/ops/gather.h" +#include "type_dispatch.h" +#include "../cann/utils.h" + +namespace ctranslate2 { + namespace ops { + + template + void Gather::compute(const StorageView& data, + const StorageView& input, + const dim_t axis, + const dim_t batch_dims, + StorageView& output) const { + // CANN expects int32_t indices according to documentation + using indiceType = int32_t; + const indiceType* indices = input.data(); + const T* src = data.data(); + T* dst = output.data(); + + if (axis == batch_dims) { + const aclDataType aclType = cann::getACLType(); + + ctranslate2::cann::CannPreparation prepare; + + cann_prepare_inputdesc(prepare, aclType, data.shape().size(), data.shape().data(), ACL_FORMAT_ND); + cann_prepare_inputdesc(prepare, ACL_INT32, input.shape().size(), input.shape().data(), ACL_FORMAT_ND); + cann_prepare_outputdesc(prepare, aclType, output.shape().size(), output.shape().data(), ACL_FORMAT_ND); + + cann_prepare_inputbuffer(prepare, const_cast(src), data.size()*sizeof(T)); + cann_prepare_inputbuffer(prepare, const_cast(indices), input.size()*sizeof(indiceType)); + cann_prepare_outputbuffer(prepare, dst, output.size()*sizeof(T)); 
+ + ACL_CALL(aclopSetAttrBool(prepare._opAttr, "validate_indices", true)); + ACL_CALL(aclopSetAttrInt(prepare._opAttr, "batch_dims", static_cast(batch_dims))); + + ACL_CALL(aclopCompileAndExecute("Gather", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } else { + throw std::invalid_argument("Gather only supports indexing the first non batch dimension"); + } + } + +#define DECLARE_IMPL(T) \ + template void \ + Gather::compute(const StorageView& data, \ + const StorageView& input, \ + const dim_t axis, \ + const dim_t batch_dims, \ + StorageView& output) const; + + DECLARE_ALL_TYPES(DECLARE_IMPL) + + } +} diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index e6ff87f9d..587e71ad2 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -2,7 +2,11 @@ #include "ctranslate2/ops/bias_add.h" +#ifdef CT2_WITH_CANN +#include "../cann/utils.h" +#endif #include "dispatch.h" +#include namespace ctranslate2 { namespace ops { @@ -15,7 +19,15 @@ namespace ctranslate2 { bias_add_op(x, *bias, x); } else if (activation_type) { get_activation_op(*activation_type)(x, x); +#ifdef CT2_WITH_CANN + } else if (x.device() == Device::CANN) { + // We rely on BiasAdd and activation operators to synchronize the stream in the general case, else we + // synchronize the stream manually. + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); } +#else + } +#endif } @@ -82,11 +94,37 @@ namespace ctranslate2 { const dim_t ldc = n; { + if (_trans_a) { + // In this case, the shape of vector 'c' might be computed incorrectly. + // See relevant upstream issue: https://github.com/OpenNMT/CTranslate2/issues/1583 + spdlog::warn("GEMM: Input vector 'a' is in transpose form. 
" + "The shape of vector 'c' might be computed incorrectly."); + } Shape output_shape(a.shape()); output_shape[output_shape.size() - 1] = n; c.resize(std::move(output_shape)); } +#ifdef CT2_WITH_CANN + if constexpr (D == Device::CANN) { + if(!_alpha_sv && !_beta_sv) { + // Avoid repeated allocation of NPU memory for 'alpha' and 'beta' across GEMM operator calls. + // Allocate NPU memory only once per Gemm object. + _alpha_sv = std::make_shared(_alpha, Device::CANN); + _beta_sv = std::make_shared(_beta, Device::CANN); + } + primitives::gemm_alpha_beta_in_device(_a_is_packed, _b_is_packed, + _trans_a, _trans_b, + m, n, k, + _alpha_sv->data(), + a.data(), lda, + b.data(), ldb, + _beta_sv->data(), + c.data(), ldc, + a_shift_compensation ? a_shift_compensation->data() : nullptr); + return; + } +#endif primitives::gemm(_a_is_packed, _b_is_packed, _trans_a, _trans_b, m, n, k, diff --git a/src/ops/gumbel_max_npu.cc b/src/ops/gumbel_max_npu.cc new file mode 100644 index 000000000..14c01ab84 --- /dev/null +++ b/src/ops/gumbel_max_npu.cc @@ -0,0 +1,23 @@ +#include "ctranslate2/ops/gumbel_max.h" + +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void GumbelMax::add_gumbel_noise(const StorageView& x, StorageView& y) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + GumbelMax::add_gumbel_noise(const StorageView& x, \ + StorageView& y) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/layer_norm_npu.cc b/src/ops/layer_norm_npu.cc new file mode 100644 index 000000000..584ab9470 --- /dev/null +++ b/src/ops/layer_norm_npu.cc @@ -0,0 +1,80 @@ +#include "ctranslate2/ops/layer_norm.h" +#include "../cann/utils.h" + +namespace ctranslate2 { + namespace ops { + + template + void LayerNorm::compute(const StorageView* beta, + const StorageView* gamma, + const StorageView& input, + const dim_t axis, + const dim_t, + const dim_t, + 
const dim_t, + StorageView& output) const { + // This case is not implemented on CUDA, so we do not support it as well + if (axis != input.rank() - 1 || !beta || !gamma) + throw std::invalid_argument("Generalized LayerNorm is currently not implemented on CANN"); + + aclFormat format = ACL_FORMAT_ND; + const aclDataType aclType = cann::getACLType(); + + ctranslate2::cann::CannPreparation prepare; + + // LayerNorm in CANN also provides 'mean' and 'variance' as outputs, which we do not need. + // But also we cannot instruct the operator to avoid calculating them. So we must declare them. + const dim_t mean_variance_length = input.size()/input.shape()[axis]; + StorageView mean({mean_variance_length}, DataType::FLOAT32, D); + StorageView variance({mean_variance_length}, DataType::FLOAT32, D); + + cann_prepare_inputdesc(prepare, aclType, input.shape().size(), input.shape().data(), format); + cann_prepare_inputdesc(prepare, aclType, gamma->shape().size(), gamma->shape().data(), format); + cann_prepare_inputdesc(prepare, aclType, beta->shape().size(), beta->shape().data(), format); + cann_prepare_outputdesc(prepare, aclType, output.shape().size(), output.shape().data(), format); + cann_prepare_outputdesc(prepare, ACL_FLOAT, mean.shape().size(), mean.shape().data(), format); + cann_prepare_outputdesc(prepare, ACL_FLOAT, variance.shape().size(), variance.shape().data(), format); + + ACL_CALL(aclopSetAttrInt(prepare._opAttr, "begin_norm_axis", axis)); + ACL_CALL(aclopSetAttrInt(prepare._opAttr, "begin_params_axis", axis)); + ACL_CALL(aclopSetAttrFloat(prepare._opAttr, "epsilon", _epsilon)); + + cann_prepare_inputbuffer(prepare, const_cast(input.data()), input.size()*sizeof(T)); + cann_prepare_inputbuffer(prepare, const_cast(gamma->data()), gamma->size()*sizeof(T)); + cann_prepare_inputbuffer(prepare, const_cast(beta->data()), beta->size()*sizeof(T)); + cann_prepare_outputbuffer(prepare, output.data(), output.size()*sizeof(T)); + cann_prepare_outputbuffer(prepare, mean.data(), 
mean.size()*sizeof(float)); + cann_prepare_outputbuffer(prepare, variance.data(), variance.size()*sizeof(float)); + + ACL_CALL(aclopCompileAndExecute("LayerNorm", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + +#define DECLARE_IMPL(T) \ + template void \ + LayerNorm::compute(const StorageView* beta, \ + const StorageView* gamma, \ + const StorageView& input, \ + const dim_t axis, \ + const dim_t outer_size, \ + const dim_t axis_size, \ + const dim_t inner_size, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/matmul.cc b/src/ops/matmul.cc index 6d9a3fe04..5f155670a 100644 --- a/src/ops/matmul.cc +++ b/src/ops/matmul.cc @@ -1,87 +1,142 @@ #include "ctranslate2/ops/matmul.h" - +#ifdef CT2_WITH_CANN +#include "../cann/utils.h" +#include "ctranslate2/ops/mul.h" +#endif #include "dispatch.h" namespace ctranslate2 { - namespace ops { + namespace ops { - MatMul::MatMul(bool trans_a, bool trans_b, float alpha) - : _trans_a(trans_a) - , _trans_b(trans_b) - , _alpha(alpha) { - } + MatMul::MatMul(bool trans_a, bool trans_b, float alpha) + : _trans_a(trans_a) + , _trans_b(trans_b) + , _alpha(alpha) { + } - void MatMul::operator()(const StorageView& a, const StorageView& b, StorageView& c) const { - PROFILE("MatMul"); - DEVICE_AND_FLOAT_DISPATCH("MatMul", a.device(), a.dtype(), (compute(a, b, c))); - } + void MatMul::operator()(const StorageView& a, const StorageView& b, StorageView& c) const { + PROFILE("MatMul"); + DEVICE_AND_FLOAT_DISPATCH("MatMul", a.device(), a.dtype(), (compute(a, b, c))); + } - template - void MatMul::compute(const StorageView& a, const StorageView& b, StorageView& c) const 
{ - dim_t m, k_a; - if (_trans_a) { - m = a.dim(-1); - k_a = a.dim(-2); - } else { - m = a.dim(-2); - k_a = a.dim(-1); - } - - dim_t k_b, n; - if (_trans_b) { - n = b.dim(-2); - k_b = b.dim(-1); - } else { - n = b.dim(-1); - k_b = b.dim(-2); - } - - if (k_a != k_b) - throw std::invalid_argument("MatMul: k dimension of inputs a and b should match"); - - const dim_t k = k_a; - const dim_t a_batch_size = a.size() / (m * k); - const dim_t b_batch_size = b.size() / (k * n); - - if (a_batch_size != b_batch_size) - throw std::invalid_argument("MatMul: batch dimension of inputs a and b should match"); - - { - Shape output_shape(a.shape()); - output_shape[output_shape.size() - 1] = n; - output_shape[output_shape.size() - 2] = m; - c.resize(std::move(output_shape)); - } - - const dim_t batch_size = a_batch_size; - const dim_t lda = _trans_a ? m : k; - const dim_t ldb = _trans_b ? k : n; - const dim_t ldc = n; - const float beta = 0; - - if (batch_size > 1) { - const dim_t stridea = m * k; - const dim_t strideb = k * n; - const dim_t stridec = m * n; - primitives::gemm_batch_strided(_trans_a, _trans_b, - m, n, k, - _alpha, - a.data(), lda, stridea, - b.data(), ldb, strideb, - beta, - c.data(), ldc, stridec, - batch_size); - } else { - primitives::gemm(/*a_is_packed=*/false, /*b_is_packed=*/false, - _trans_a, _trans_b, - m, n, k, - _alpha, - a.data(), lda, - b.data(), ldb, - beta, - c.data(), ldc); - } - } + template + void MatMul::handleNonCann(const StorageView &a, const StorageView &b, StorageView &c, dim_t m, dim_t n, + const dim_t k, + const dim_t a_batch_size) const { + const dim_t batch_size = a_batch_size; + const dim_t lda = _trans_a ? m : k; + const dim_t ldb = _trans_b ? 
k : n; + const dim_t ldc = n; + const float beta = 0; + + if (batch_size > 1) { + const dim_t stridea = m * k; + const dim_t strideb = k * n; + const dim_t stridec = m * n; + primitives::gemm_batch_strided(_trans_a, _trans_b, + m, n, k, + _alpha, + a.data(), lda, stridea, + b.data(), ldb, strideb, + beta, + c.data(), ldc, stridec, + batch_size); + } else { + primitives::gemm(/*a_is_packed=*/false, /*b_is_packed=*/false, + _trans_a, _trans_b, + m, n, k, + _alpha, + a.data(), lda, + b.data(), ldb, + beta, + c.data(), ldc); + } + } + + template + void MatMul::handleCann(const StorageView &a, const StorageView &b, StorageView &c) const { +#ifdef CT2_WITH_CANN + const auto aclType = cann::getACLType(); + aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + ACL_CALL(aclopSetAttrBool(prepare._opAttr, "adj_x1", _trans_a)); + ACL_CALL(aclopSetAttrBool(prepare._opAttr, "adj_x2", _trans_b)); + + cann_prepare_inputdesc(prepare, aclType, a.shape().size(), a.shape().data(), format); + cann_prepare_inputdesc(prepare, aclType, b.shape().size(), b.shape().data(), format); + cann_prepare_outputdesc(prepare, aclType, c.shape().size(), c.shape().data(), format); - } + cann_prepare_inputbuffer(prepare, const_cast(a.data()), a.size_in_bytes()); + cann_prepare_inputbuffer(prepare, const_cast(b.data()), b.size_in_bytes()); + cann_prepare_outputbuffer(prepare, c.data(), c.size_in_bytes()); + + ACL_CALL(aclopCompileAndExecute("BatchMatMul", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + ctranslate2::cann::get_aclrt_stream())); + if (_alpha != 1) { + // The Mul operator will synchronize the stream. 
+ ops::Mul()(c, StorageView(_alpha), c); + } else { + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } +#endif + } + + template + void MatMul::compute(const StorageView& a, const StorageView& b, StorageView& c) const { + dim_t m, k_a; + if (_trans_a) { + m = a.dim(-1); + k_a = a.dim(-2); + } else { + m = a.dim(-2); + k_a = a.dim(-1); + } + + dim_t k_b, n; + if (_trans_b) { + n = b.dim(-2); + k_b = b.dim(-1); + } else { + n = b.dim(-1); + k_b = b.dim(-2); + } + + if (k_a != k_b) + throw std::invalid_argument("MatMul: k dimension of inputs a and b should match"); + + const dim_t k = k_a; + const dim_t a_batch_size = a.size() / (m * k); + const dim_t b_batch_size = b.size() / (k * n); + + if (a_batch_size != b_batch_size) + throw std::invalid_argument("MatMul: batch dimension of inputs a and b should match"); + + { + Shape output_shape(a.shape()); + output_shape[output_shape.size() - 1] = n; + output_shape[output_shape.size() - 2] = m; + c.resize(std::move(output_shape)); + } + + if constexpr (D == Device::CANN) { + // Employ BatchMatMul directly instead of Gemm since it is significantly faster in CANN + handleCann(a, b, c); + } + else { + handleNonCann(a, b, c, m, n, k, a_batch_size); + } + } + } } diff --git a/src/ops/mean_npu.cc b/src/ops/mean_npu.cc new file mode 100644 index 000000000..f67bbc8dd --- /dev/null +++ b/src/ops/mean_npu.cc @@ -0,0 +1,30 @@ +#include "ctranslate2/ops/mean.h" + +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void Mean::compute(const StorageView& input, + const dim_t outer_size, + const dim_t axis_size, + const dim_t inner_size, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Mean::compute(const StorageView& input, \ + const dim_t outer_size, \ + const dim_t axis_size, \ + const dim_t inner_size, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + 
DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/mul.cc b/src/ops/mul.cc index efb256967..490a79958 100644 --- a/src/ops/mul.cc +++ b/src/ops/mul.cc @@ -1,5 +1,7 @@ #include "ctranslate2/ops/mul.h" - +#ifdef CT2_WITH_CANN +#include "../cann/utils.h" +#endif #include "dispatch.h" namespace ctranslate2 { @@ -10,5 +12,74 @@ namespace ctranslate2 { DEVICE_AND_TYPE_DISPATCH(a.device(), a.dtype(), (compute(a, b, c))); } + template + void Mul::handleCann(const StorageView& a, const StorageView& b, StorageView& c) const { +#ifdef CT2_WITH_CANN + if (a.shape() != b.shape() || a.dtype() != b.dtype()) + throw std::invalid_argument("Mul: a and b have incompatible shapes or types"); + + const auto aclType = cann::getACLType(); + aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + cann_prepare_inputdesc(prepare, aclType, a.shape().size(), a.shape().data(), format); + cann_prepare_inputdesc(prepare, aclType, b.shape().size(), b.shape().data(), format); + cann_prepare_outputdesc(prepare, aclType, c.shape().size(), c.shape().data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(a.data()), a.size_in_bytes()); + cann_prepare_inputbuffer(prepare, const_cast(b.data()), b.size_in_bytes()); + cann_prepare_outputbuffer(prepare, c.data(), c.size_in_bytes()); + + ACL_CALL(aclopCompileAndExecute("Mul", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); +#endif + } + + template + void Mul::handleCannScalar(const T scalar, const StorageView& a, StorageView& c) const { +#ifdef CT2_WITH_CANN + if (a.shape() != c.shape() || a.dtype() != c.dtype()) + throw std::invalid_argument("Muls: a and c have incompatible shapes or types"); + + const auto aclType = cann::getACLType(); + 
aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + // Note: CANN documentation on "scalar" value is ambiguous + ACL_CALL(aclopSetAttrFloat(prepare._opAttr, "value", static_cast(scalar))); + cann_prepare_inputdesc(prepare, aclType, a.shape().size(), a.shape().data(), format); + cann_prepare_outputdesc(prepare, aclType, c.shape().size(), c.shape().data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(a.data()), a.size_in_bytes()); + cann_prepare_outputbuffer(prepare, c.data(), c.size_in_bytes()); + + ACL_CALL(aclopCompileAndExecute("Muls", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); +#endif + } } } diff --git a/src/ops/multinomial_npu.cc b/src/ops/multinomial_npu.cc new file mode 100644 index 000000000..85f6b5474 --- /dev/null +++ b/src/ops/multinomial_npu.cc @@ -0,0 +1,22 @@ +#include "ctranslate2/ops/multinomial.h" + +namespace ctranslate2 { + namespace ops { + + + template + void Multinomial::compute(const StorageView& input, StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Multinomial::compute(const StorageView& input, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/quantize_npu.cc b/src/ops/quantize_npu.cc new file mode 100644 index 000000000..5c720689c --- /dev/null +++ b/src/ops/quantize_npu.cc @@ -0,0 +1,27 @@ +#include "ctranslate2/ops/quantize.h" + +namespace ctranslate2 { + namespace ops { + + template + void Quantize::quantize(const StorageView& input, + StorageView& output, + StorageView& scale) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + 
template void + Quantize::quantize(const StorageView&, + StorageView&, + StorageView&) const; + template void + Quantize::quantize(const StorageView&, + StorageView&, + StorageView&) const; + template void + Quantize::quantize(const StorageView&, + StorageView&, + StorageView&) const; + + } +} diff --git a/src/ops/rms_norm_npu.cc b/src/ops/rms_norm_npu.cc new file mode 100644 index 000000000..f8d788006 --- /dev/null +++ b/src/ops/rms_norm_npu.cc @@ -0,0 +1,24 @@ +#include "ctranslate2/ops/rms_norm.h" + +namespace ctranslate2 { + namespace ops { + + template + void RMSNorm::compute(const StorageView& gamma, + const StorageView& input, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + + } + +#define DECLARE_IMPL(T) \ + template void RMSNorm::compute(const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/rotary_npu.cc b/src/ops/rotary_npu.cc new file mode 100644 index 000000000..2cc02eed0 --- /dev/null +++ b/src/ops/rotary_npu.cc @@ -0,0 +1,26 @@ +#include "ctranslate2/ops/rotary.h" + +namespace ctranslate2 { + namespace ops { + + template + void Rotary::compute(const StorageView& input, + const StorageView& sin, + const StorageView& cos, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Rotary::compute(const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/softmax_npu.cc b/src/ops/softmax_npu.cc new file mode 100644 index 000000000..6ecae7ff4 --- /dev/null +++ b/src/ops/softmax_npu.cc @@ -0,0 +1,98 @@ +#include "ctranslate2/ops/softmax.h" +#include "../cann/utils.h" + +namespace ctranslate2 { + namespace ops { + + template + void run_softmax(const StorageView& input, + StorageView& output, + 
bool log){ + const aclDataType aclType = cann::getACLType(); + if(aclType == ACL_BF16) + THROW_RUNTIME_ERROR("Unsupported ACL type: " + std::to_string(aclType)); + + ctranslate2::cann::CannPreparation prepare; + const aclFormat format = ACL_FORMAT_ND; + cann_prepare_inputdesc(prepare, aclType, input.shape().size(), input.shape().data(), format); + cann_prepare_outputdesc(prepare, aclType, output.shape().size(), output.shape().data(), format); + + cann_prepare_inputbuffer(prepare, const_cast(input.data()), input.size()*sizeof(T)); + cann_prepare_outputbuffer(prepare, output.data(), output.size()*sizeof(T)); + + std::string op_type = log ? "LogSoftmaxV2" : "SoftmaxV2"; + ACL_CALL(aclopCompileAndExecute(op_type.c_str(), + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + } + + template + void SoftMax::compute(const StorageView& input, + const StorageView* lengths, + StorageView& output) const { + if (!lengths) { + run_softmax(input, output, _log); + } else { + // todo reduce number of operator calls for this case in the future + dim_t batch_size = input.size() / input.dim(-1); + + std::vector lengths_vector = lengths->to_vector(); + int32_t current_length; + + // View 'input' and 'output' as 2D vectors with 'batch_size' number of rows. + auto input_2D = StorageView({batch_size, input.dim(-1)}, const_cast(input.data()), D); + auto output_2D = StorageView({batch_size, input.dim(-1)}, const_cast(output.data()), D); + + // keeps track of the indices of the current slice of the input/output StorageView + // {0, 0} corresponds to the 1st slice, {1, 0} corresponds to the 2nd slice etc... 
+ std::vector indices(2, 0); + + for (size_t i=0; i(input_2D.index(indices)), D), + output_slice = StorageView({current_length}, const_cast(output_2D.index(indices)), D); + + run_softmax(input_slice, output_slice, _log); + + if (input_2D.dim(-1) == current_length) { + continue; + } + + // point at the end of the current slice + indices[1] = current_length; + + // fill the end of the current slice with zeros + StorageView({input_2D.dim(-1) - current_length}, const_cast(output_2D.index(indices)), D).zero(false); + + // point back at the start of the current slice for the next iteration + indices[1] = 0; + } + } + + // Synchronize stream only once in the end, not after operator call + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + +#define DECLARE_IMPL(T) \ + template void \ + SoftMax::compute(const StorageView& input, \ + const StorageView* lengths, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} + diff --git a/src/ops/tile_npu.cc b/src/ops/tile_npu.cc new file mode 100644 index 000000000..a2e97fbc2 --- /dev/null +++ b/src/ops/tile_npu.cc @@ -0,0 +1,25 @@ +#include "ctranslate2/ops/tile.h" +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void Tile::compute(const StorageView& input, + const dim_t, + const dim_t inner_size, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void \ + Tile::compute(const StorageView& input, \ + const dim_t outer_size, \ + const dim_t inner_size, \ + StorageView& output) const; + + DECLARE_ALL_TYPES(DECLARE_IMPL) + + } +} diff --git a/src/ops/topk_npu.cc b/src/ops/topk_npu.cc new file mode 100644 index 000000000..f06eef9c8 --- /dev/null +++ b/src/ops/topk_npu.cc @@ -0,0 +1,79 @@ +#include "ctranslate2/ops/topk.h" +#include "../cann/utils.h" + +namespace ctranslate2 { + namespace ops { + + template + void TopK::compute(const StorageView& x, 
+ StorageView& values, + StorageView& indices) const { + static_assert(std::is_same_v); // indices have to be int32_t + // derive types + const auto in_out_acl_type = cann::getACLType(); + using k_type = int32_t; + const auto k_acl_type = cann::getACLType(); + const auto index_acl_type = k_acl_type; + aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + // x + cann_prepare_inputdesc(prepare, in_out_acl_type, x.shape().size(), x.shape().data(), format); + // k + cann_prepare_inputdesc(prepare, k_acl_type, 0, nullptr, format); // handle k as scalar + auto tmp_k = static_cast(_k); + cann_tensor_placement(prepare, 1, ACL_MEMTYPE_HOST); + //values + cann_prepare_outputdesc(prepare, in_out_acl_type, values.shape().size(), values.shape().data(), format); + // indices + cann_prepare_outputdesc(prepare, index_acl_type, indices.shape().size(), indices.shape().data(), format); + + // x + cann_prepare_inputbuffer(prepare, const_cast(x.data()), x.size_in_bytes()); + // k + cann_prepare_inputbuffer(prepare, const_cast(&tmp_k), sizeof(k_type)); + //values + cann_prepare_outputbuffer(prepare, const_cast(values.data()), values.size_in_bytes()); + // indices + cann_prepare_outputbuffer(prepare, const_cast(indices.data()), indices.size_in_bytes()); + + auto op_type= "TopKV2"; + // // TopK implementation + // const int16_t kMaxTopkSize = std::numeric_limits::max(), kMaxK = 8, kMinK = 0; + // if(x.size() > kMaxTopkSize && tmp_k > kMinK && tmp_k < kMaxK) { + // op_type = "TopK"; + // // "dim" usage is ambiguous. According to paddle: + // // axis is always equal to -1 + // // if (axis < 0) + // // axis += x.dims().size(); + // ACL_CALL(aclopSetAttrInt(prepare.opAttr_, "dim", x.rank()-1)); // axis == -1 always in TopK ctor! 
+ // } + + ACL_CALL(aclopCompileAndExecute(op_type, + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); + } + +#define DECLARE_IMPL(T) \ + template void \ + TopK::compute(const StorageView& x, \ + StorageView& values, \ + StorageView& indices) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/topp_mask_npu.cc b/src/ops/topp_mask_npu.cc new file mode 100644 index 000000000..714b5459d --- /dev/null +++ b/src/ops/topp_mask_npu.cc @@ -0,0 +1,28 @@ +#include "ctranslate2/ops/topp_mask.h" + +namespace ctranslate2 { + namespace ops { + + template + void TopPMask::compute(const StorageView& input, + const StorageView& probs, + StorageView& output) const { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + + template<> + dim_t TopPMask::max_num_classes() { + THROW_RUNTIME_ERROR("not implemented in CANN"); + } + +#define DECLARE_IMPL(T) \ + template void TopPMask::compute(const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} diff --git a/src/ops/transpose.cc b/src/ops/transpose.cc index 48bc031ab..ae9375005 100644 --- a/src/ops/transpose.cc +++ b/src/ops/transpose.cc @@ -1,10 +1,47 @@ #include "ctranslate2/ops/transpose.h" - +#ifdef CT2_WITH_CANN +#include "../cann/utils.h" +#endif #include "dispatch.h" namespace ctranslate2 { namespace ops { + // CANN can handle transpose using StorageView directly without the need of a primitive definition + template + void Transpose::handleCann(const StorageView& x, const std::vector& perm, StorageView& y) const { +#ifdef CT2_WITH_CANN + const auto a_b_type = cann::getACLType(); + const 
auto perm_type = cann::getACLType(); + aclFormat format = ACL_FORMAT_ND; + + cann::CannPreparation prepare; + + Shape::value_type perm_size = perm.size(); + Shape perm_shape = {perm_size}; + cann_prepare_inputdesc(prepare, a_b_type, x.shape().size(), x.shape().data(), format); + cann_prepare_inputdesc(prepare, perm_type, perm_shape.size(), perm_shape.data(), format); + cann_prepare_outputdesc(prepare, a_b_type, y.shape().size(), y.shape().data(), format); + cann_prepare_inputbuffer(prepare, const_cast(x.data()), x.size_in_bytes()); + cann_prepare_inputbuffer(prepare, const_cast(perm.data()), perm.size()*sizeof(dim_t)); + cann_prepare_outputbuffer(prepare, y.data(), y.size_in_bytes()); + + ACL_CALL(aclopCompileAndExecute("Transpose", + prepare._inputDesc.size(), + prepare._inputDesc.data(), + prepare._inputBuffers.data(), + prepare._outputDesc.size(), + prepare._outputDesc.data(), + prepare._outputBuffers.data(), + prepare._opAttr, + ACL_ENGINE_SYS, + ACL_COMPILE_SYS, + NULL, + cann::get_aclrt_stream())); + ACL_CALL(aclrtSynchronizeStream(cann::get_aclrt_stream())); +#endif + } + Transpose::Transpose(const std::vector& perm) : _perm(perm) { } diff --git a/src/storage_view.cc b/src/storage_view.cc index 0cbdf25c2..a0dc14b15 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -159,6 +159,7 @@ namespace ctranslate2 { } StorageView& StorageView::reserve(dim_t size) { + // if new size is smaller than the allocated size, do not shrink and leave _allocated_size as is if (size <= _allocated_size) return *this; release(); @@ -174,6 +175,10 @@ namespace ctranslate2 { return _allocator; } +dim_t StorageView::size_in_bytes() const { + return size() * item_size(); +} + dim_t StorageView::item_size() const { dim_t size = 0; TYPE_DISPATCH(_dtype, size = sizeof (T)); @@ -366,6 +371,33 @@ namespace ctranslate2 { return data() + offset; } + template + const T* StorageView::index(const std::vector& indices) const { + const dim_t num_indices = indices.size(); + if 
(num_indices != rank()) + THROW_INVALID_ARGUMENT("number of indexed dimensions (" + + std::to_string(indices.size()) + + ") does not match the storage rank (" + + std::to_string(rank()) + ")"); + + dim_t offset = 0; + if (num_indices > 0) { + dim_t stride = 1; + auto index_it = std::crbegin(indices); + auto dim_it = std::crbegin(_shape); + for (; index_it != std::crend(indices); ++index_it, ++dim_it) { + offset += *index_it * stride; + stride *= *dim_it; + } + } + + if (offset >= _size) + THROW_INVALID_ARGUMENT("computed index is out of bounds (" + + std::to_string(offset) + " >= " + + std::to_string(_size) + ")"); + return data() + offset; + } + StorageView& StorageView::copy_from(const StorageView& other, bool synchronous) { resize_as(other); TYPE_DISPATCH(other._dtype, copy_from(other.data(), other._size, other._device, synchronous)); @@ -400,8 +432,8 @@ namespace ctranslate2 { return *this; } - StorageView& StorageView::zero() { - DEVICE_AND_TYPE_DISPATCH(_device, _dtype, primitives::fill(data(), T(0), _size)); + StorageView& StorageView::zero(bool synchronous) { + DEVICE_AND_TYPE_DISPATCH(_device, _dtype, primitives::zero(data(), _size, synchronous)); return *this; } @@ -417,6 +449,13 @@ namespace ctranslate2 { else cross_device_primitives::copy(data, this->data(), size); } else +#elif CT2_WITH_CANN + if (device != _device) { + if (device == Device::CANN) + cross_device_primitives::copy(data, this->data(), size); + else + cross_device_primitives::copy(data, this->data(), size); + } else #endif { DEVICE_DISPATCH(device, primitives::copy(data, this->data(), size)); @@ -495,6 +534,8 @@ namespace ctranslate2 { template T* StorageView::index(std::initializer_list indices); \ template const T* \ StorageView::index(std::initializer_list indices) const; \ + template const T* \ + StorageView::index(const std::vector& indices) const; \ template T \ StorageView::scalar_at(std::initializer_list indices) const; \ template StorageView& StorageView::view(T* data, Shape 
shape); \ diff --git a/src/thread_pool.cc b/src/thread_pool.cc index d0aad775b..317687c1e 100644 --- a/src/thread_pool.cc +++ b/src/thread_pool.cc @@ -1,6 +1,6 @@ #include "ctranslate2/thread_pool.h" - #include "ctranslate2/utils.h" +#include "ctranslate2/devices.h" namespace ctranslate2 { @@ -121,6 +121,7 @@ namespace ctranslate2 { finalize(); local_worker = nullptr; + finalize_device(); } diff --git a/src/types.cc b/src/types.cc index 2431bce66..5afb9b680 100644 --- a/src/types.cc +++ b/src/types.cc @@ -100,6 +100,15 @@ namespace ctranslate2 { #else (void)device_index; return false; +#endif + } + case Device::CANN: { +#ifdef CT2_WITH_CANN + (void)device_index; + return false; // may change later when operators support bfloat16 +#else + (void)device_index; + return false; #endif } default: @@ -118,6 +127,14 @@ namespace ctranslate2 { return false; #endif } + case Device::CANN: { +#ifdef CT2_WITH_CANN + return true; +#else + (void)device_index; + return false; +#endif + } default: return false; } @@ -140,6 +157,13 @@ namespace ctranslate2 { #else (void)device_index; return false; +#endif + case Device::CANN: +#ifdef CT2_WITH_CANN + return true; +#else + (void)device_index; + return false; #endif case Device::CPU: return cpu::has_gemm_backend(ComputeType::INT8); @@ -351,6 +375,11 @@ namespace ctranslate2 { && cuda::gpu_has_fp16_tensor_cores(device_index)) return 8; } +#elif CT2_WITH_CANN + if (device == Device::CANN) { + if (compute_type == ComputeType::FLOAT16 || compute_type == ComputeType::BFLOAT16) + return 8; + } #else (void)compute_type; (void)device; diff --git a/src/utils.cc b/src/utils.cc index f0eb29509..22ff0b312 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -6,6 +6,8 @@ #ifdef CT2_WITH_CUDA # include "./cuda/utils.h" +#elif CT2_WITH_CANN +# include "./cann/utils.h" #endif #include @@ -67,6 +69,12 @@ namespace ctranslate2 { cuda::gpu_has_fp16_tensor_cores(i)); spdlog::info(" - Allow BF16: {}", mayiuse_bfloat16(Device::CUDA, i)); } +#elif CT2_WITH_CANN + 
spdlog::info("NPU:"); + spdlog::info(" - Number of NPU cores: {}", cann::get_npu_count()); + aclrtRunMode runMode; + ACL_CALL(aclrtGetRunMode(&runMode)); + spdlog::info(" - aclrtRunMode: {}", runMode == ACL_DEVICE ? "ACL_DEVICE" : "ACL_HOST"); #endif } @@ -74,6 +82,10 @@ namespace ctranslate2 { return get_device_count(Device::CUDA); } + int get_npu_count() { + return get_device_count(Device::CANN); + } + static inline size_t get_default_num_threads() { constexpr size_t default_num_threads = 4; const size_t max_num_threads = std::thread::hardware_concurrency(); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 283c49db2..e57708a72 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -32,4 +32,9 @@ target_link_libraries(benchmark_ops if(WITH_CUDA) target_link_libraries(benchmark_ops ${CUDA_LIBRARIES}) +elseif(WITH_CANN) + set(ASCEND_DIR /usr/local/Ascend) + set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64) + set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so) + target_link_libraries(benchmark_ops ${ascendcl_lib}) endif() diff --git a/tests/benchmark_ops.cc b/tests/benchmark_ops.cc index df3f5d1ba..9f7bd51b2 100644 --- a/tests/benchmark_ops.cc +++ b/tests/benchmark_ops.cc @@ -123,7 +123,7 @@ int main(int argc, char* argv[]) { } std::string op = argv[1]; - Device device = std::string(argv[2]) == "cuda" ? Device::CUDA : Device::CPU; + Device device = ctranslate2::str_to_device(std::string(argv[2])); std::string dtype_str = argc > 3 ? argv[3] : "float32"; DataType dtype = DataType::FLOAT32; if (dtype_str == "int16") diff --git a/tests/benchmark_utils.h b/tests/benchmark_utils.h index 55ee77b0e..3a5567234 100644 --- a/tests/benchmark_utils.h +++ b/tests/benchmark_utils.h @@ -7,6 +7,10 @@ #ifdef CT2_WITH_CUDA # include # define SYNCHRONIZE cudaDeviceSynchronize() +#elif CT2_WITH_CANN +#include +// can also check failure ACL_CALL(..) 
+# define SYNCHRONIZE aclrtSynchronizeDevice() #else # define SYNCHRONIZE do {} while (false) #endif diff --git a/tests/layers_test.cc b/tests/layers_test.cc index 3a8e40958..989fcf14d 100644 --- a/tests/layers_test.cc +++ b/tests/layers_test.cc @@ -23,6 +23,8 @@ TEST(LayerTest, MakeRelativePositions2D) { TEST_P(LayerDeviceFPTest, Alibi) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = std::max(GetParam().error, float(1e-4)); @@ -56,6 +58,10 @@ TEST_P(LayerDeviceFPTest, Alibi) { -0.015625, -0.01171875, -0.0078125, -0.00390625, 0.0}); layers::Alibi alibi; + if(device == Device::CANN && dtype == DataType::BFLOAT16) { + ASSERT_RAISES(zero.to(dtype), std::runtime_error); + return; + } StorageView x = zero.to(dtype); alibi.apply(x); expect_storage_eq(x.to_float32(), expected, error); @@ -97,6 +103,8 @@ TEST_P(LayerDeviceFPTest, Alibi) { TEST_P(LayerDeviceFPTest, RotaryEmbedding) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; @@ -167,6 +175,10 @@ TEST_P(LayerDeviceFPTest, RotaryEmbedding) { { layers::RotaryEmbeddings rotary_embeddings; + if(device == Device::CANN && dtype == DataType::BFLOAT16) { + ASSERT_RAISES(input.to(dtype), std::runtime_error); + return; + } StorageView x = input.to(dtype); rotary_embeddings.apply(x, 2); expect_storage_eq(x.to_float32(), expected, error); @@ -180,6 +192,22 @@ TEST_P(LayerDeviceFPTest, RotaryEmbedding) { } } +TEST_P(LayerDeviceFPTest, PositionEncoder) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + // PositionEcoder is an abstract class, so we cannot instantiate a PositionEncoder object. 
+ // Instead, we create a SinusoidalPositionEncoder which inherits from PositionEncoder. + layers::SinusoidalPositionEncoder position_encoder(6, dtype, device); + StorageView input({1, 1, 6}, std::vector{-0.2, -1.3, 0.1, -0.6, 2.0, 1.1}, device); + StorageView expected({1, 1, 6}, std::vector{0.641471, -1.29, 0.1001, -0.0596977, 2.99995, 2.1}, device); + input = input.to(dtype); + position_encoder(input); + expect_storage_eq(input.to_float32(), expected, error); +} + TEST(LayerTest, Padder) { const StorageView lengths({3}, std::vector{2, 3, 1}); const Padder padder(lengths, /*max_time=*/4); @@ -257,4 +285,10 @@ INSTANTIATE_TEST_SUITE_P(CUDA, LayerDeviceFPTest, FloatType{Device::CUDA, DataType::FLOAT16, 1e-2}, FloatType{Device::CUDA, DataType::BFLOAT16, 1e-2}), fp_test_name); +#elif CT2_WITH_CANN +INSTANTIATE_TEST_SUITE_P(CANN, LayerDeviceFPTest, + ::testing::Values(FloatType{Device::CANN, DataType::FLOAT32, 1e-5}, + FloatType{Device::CANN, DataType::FLOAT16, 1e-2}, + FloatType{Device::CANN, DataType::BFLOAT16, 1e-2}), + fp_test_name); #endif diff --git a/tests/ops_test.cc b/tests/ops_test.cc index 1ceae1dfd..8ea921948 100644 --- a/tests/ops_test.cc +++ b/tests/ops_test.cc @@ -1,7 +1,7 @@ -#include #include "test_utils.h" #include "ctranslate2/layers/attention.h" #include "ctranslate2/ops/ops.h" +#include "ctranslate2/devices.h" TEST(OpTest, Transpose1D) { StorageView x({4}, std::vector{1, 2, 3, 4}); @@ -151,16 +151,208 @@ TEST_P(OpDeviceTest, Add) { expect_storage_eq(c, expected); } -TEST_P(OpDeviceTest, AddScalar) { +TEST_P(OpDeviceTest, AddTensors2D) { Device device = GetParam(); - StorageView a({4}, std::vector{1, 2, 3, 4}, device); - StorageView b(static_cast(3)); - StorageView expected({4}, std::vector{4, 5, 6, 7}, device); + StorageView a({4, 2}, std::vector{1.69, 2, 3, 4, 17.42, 2, 3, 4.333}, device); + StorageView b({4, 2}, std::vector{2, 3, 4, 5, 1.42, 2, 3, 4.232}, device); + StorageView expected({4, 2}, std::vector{3.69, 5, 7, 9, 18.84, 4, 6, 8.565}, 
device); StorageView c(a.device()); ops::Add()(a, b, c); expect_storage_eq(c, expected); } +TEST_P(OpDeviceTest, AddLargeTensors2D) { + Device device = GetParam(); + StorageView a({4000, 200}, 42, device); + StorageView b({4000, 200}, 17, device); + StorageView expected({4000, 200}, 59, device); + StorageView c(DataType::INT32, a.device()); + ops::Add()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, AddScalarTo1DTensor) { + Device device = GetParam(); + StorageView a({4}, std::vector{1.17, 2.17, 3.17, 4.17}, device); + StorageView b(static_cast(3.42)); + StorageView expected({4}, std::vector{4.59, 5.59, 6.59, 7.59}, device); + StorageView c(a.device()); + ops::Add()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, AddScalarTo2DTensor) { + Device device = GetParam(); + StorageView a({4, 2}, float16_t(42), device); + StorageView b(float16_t(17)); + StorageView expected({4, 2}, float16_t(59), device); + StorageView c(DataType::FLOAT16, a.device()); + ops::Add()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, AddScalarToLarge2DTensor) { + Device device = GetParam(); + StorageView a({4000, 200}, 42.f, device); + StorageView b(17.f); + StorageView expected({4000, 200}, 59.f, device); + StorageView c(DataType::FLOAT32, a.device()); + ops::Add()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, BiasAdd) { + Device device = GetParam(); + StorageView value({4, 3}, std::vector{1, 2, 3, + 4, 5, 6, + 7, 8, 9, + 10, 11, 12}, device); + StorageView bias({3}, std::vector{1, 2, 3}, device); + StorageView expected({4, 3}, std::vector{2, 4, 6, + 5, 7, 9, + 8, 10, 12, + 11, 13, 15}, device); + StorageView c(device); + ops::BiasAdd()(value, bias, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, BiasAdd3D) { + Device device = GetParam(); + StorageView value({4, 3, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + + 7, 8, + 9, 10.4265, + 11, 12, + + 2, 3, + 4, 5, + 6, 7.917, + + 8, 9, + 10, 
11, + 12, 13}, device); + StorageView bias({2}, std::vector{1, 2}, device); + StorageView expected({4, 3, 2}, std::vector{2, 4, + 4, 6, + 6, 8, + + 8, 10, + 10, 12.4265, + 12, 14, + + 3, 5, + 5, 7, + 7, 9.917, + + 9, 11, + 11, 13, + 13, 15}, device); + StorageView c(device); + ops::BiasAdd()(value, bias, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMul) { + Device device = GetParam(); + StorageView a({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView b({2, 4}, std::vector{1, 3, 5, 7, + 2, 4, 6, 8}, device); + StorageView expected({4,4}, std::vector{5, 11, 17, 23, + 11, 25, 39, 53, + 17, 39, 61, 83, + 23, 53, 83, 113}, device); + StorageView c(a.device()); + ops::MatMul()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulBatchLargerThanOne) { + // Invoke the case of batch_size > 1 + Device device = GetParam(); + StorageView a({2,2,3}, std::vector{1, 2, 3, 4, 5,6, 7, 8, 9, 10, 11, 12}, device); + StorageView b({2,3,2}, std::vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, device); + StorageView expected({2, 2, 2}, std::vector{94, 100, 229, 244, 508, 532, 697, 730}, device); + StorageView c(DataType::FLOAT32, device); + ops::MatMul()(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulBatchWithIntScaling) { + // Invoke the case of batch_size > 1 + Device device = GetParam(); + StorageView a({2,2,3}, std::vector{1, 2, 3, 4, 5,6, 7, 8, 9, 10, 11, 12}, device); + StorageView b({2,3,2}, std::vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, device); + StorageView expected({2, 2, 2}, std::vector{188, 200, 458, 488, 1016, 1064, 1394, 1460}, device); + StorageView c(DataType::FLOAT32, device); + ops::MatMul(false, false, 2)(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulBatchWithDecimalScaling) { + // Invoke the case of batch_size > 1 + Device device = GetParam(); + StorageView a({2,2,3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12}, device); + StorageView b({2,3,2}, std::vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}, device); + StorageView expected({2, 2, 2}, std::vector{11.75, 12.5, 28.625, 30.5, 63.5, 66.5, 87.125, 91.25}, device); + StorageView c(DataType::FLOAT32, device); + ops::MatMul(false, false, 0.125)(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulTransposeA) { + Device device = GetParam(); + StorageView a({2, 4}, std::vector{1, 3, 5, 7, 2, 4, 6, 8}, device); + StorageView b({2, 4}, std::vector{1, 3, 5, 7, 2, 4, 6, 8}, device); + StorageView expected({4,4}, std::vector{5, 11, 17, 23, 11, 25, 39, 53, 17, 39, 61, 83, 23, 53, 83, 113}, device); + StorageView c(a.device()); + ops::MatMul op(true); + op(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulTransposeB) { + Device device = GetParam(); + StorageView a({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView b({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView expected({4,4}, std::vector{5, 11, 17, 23, 11, 25, 39, 53, 17, 39, 61, 83, 23, 53, 83, 113}, device); + StorageView c(a.device()); + ops::MatMul op(false, true); + op(a, b, c); + expect_storage_eq(c, expected); +} + +TEST_P(OpDeviceTest, MatMulTransposeBWithDecimalScaling) { + Device device = GetParam(); + StorageView a({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView b({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView expected({4,4}, std::vector{0.625, 1.375, 2.125, 2.875, + 1.375, 3.125, 4.875, 6.625, + 2.125, 4.875, 7.625, 10.375, + 2.875, 6.625, 10.375, 14.125}, device); + StorageView c(a.device()); + ops::MatMul op(false, true, 0.125); + op(a, b, c); + expect_storage_eq(c, expected); +} + TEST_P(OpDeviceTest, Mul) { Device device = GetParam(); StorageView a({4}, std::vector{1, 2, 3, 4}, device); @@ -183,6 +375,8 @@ TEST_P(OpDeviceTest, MulScalar) { TEST_P(OpDeviceTest, Sub) { Device device = GetParam(); + 
if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView a({4}, std::vector{1, 2, 3, 4}, device); StorageView b({4}, std::vector{2, 3, 4, 5}, device); StorageView expected({4}, std::vector{-1, -1, -1, -1}, device); @@ -193,6 +387,8 @@ TEST_P(OpDeviceTest, Sub) { TEST_P(OpDeviceTest, TileFirstDim) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView input({2, 4}, std::vector{1, 2, 3, 4, 5, 6, 7, 8}, device); StorageView expected_output({4, 4}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}, device); StorageView output(device); @@ -202,6 +398,8 @@ TEST_P(OpDeviceTest, TileFirstDim) { TEST_P(OpDeviceTest, TileLastDim) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView input({2, 2}, std::vector{1, 2, 3, 4}, device); StorageView expected_output({2, 4}, std::vector{1, 2, 1, 2, 3, 4, 3, 4}, device); StorageView output(device); @@ -211,6 +409,8 @@ TEST_P(OpDeviceTest, TileLastDim) { TEST_P(OpDeviceTest, TileMiddleDim) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView input({2, 1, 3}, std::vector{1, 2, 3, 4, 5, 6}, device); StorageView expected_output({2, 3, 3}, std::vector{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6}, device); StorageView output(device); @@ -227,6 +427,26 @@ TEST_P(OpDeviceTest, ConcatEmpty) { expect_storage_eq(x, a); } +TEST_P(OpDeviceTest, ConcatBasic) { + Device device = GetParam(); + StorageView a({2, 3}, std::vector{1, 2, 3, 4, 5, 6}, device); + StorageView b({2, 3}, std::vector{7, 8, 9, 10, 11, 12}, device); + StorageView c({4,3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, device); + StorageView x(device); + ops::Concat(0)({&a, &b}, x); + expect_storage_eq(x, c); +} + +TEST_P(OpDeviceTest, ConcatNegativeAxis) { + Device device = GetParam(); + StorageView a({2, 2, 2}, std::vector{1, 2, 2, 3, 4, 4, 5, 3}, device); + StorageView b({2, 2, 2}, std::vector{7, 4, 8, 4, 2, 10, 15, 
11}, device); + StorageView c({2, 2, 4}, std::vector{1, 2, 7, 4, 2, 3, 8, 4, 4, 4, 2, 10, 5, 3, 15, 11}, device); + StorageView x(device); + ops::Concat(-1)({&a, &b}, x); + expect_storage_eq(x, c); +} + TEST_P(OpDeviceTest, ConcatSplitBatch) { Device device = GetParam(); StorageView a({2, 2}, std::vector{1, 2, 3, 4}, device); @@ -334,8 +554,142 @@ TEST_P(OpDeviceTest, SplitNoCopyEqualParts) { EXPECT_EQ(z.data(), x.data() + 4); } +TEST_P(OpDeviceTest, SplitAxis0EqualLengthParts2) { + Device device = GetParam(); + StorageView input({4, 2}, std::vector{1.42, -2.42, + 3.42, 4.42, + 5.42, 6.42, + 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + ops::Split(0)(input, output1, output2); + StorageView expected_output1({2, 2}, std::vector{1.42, -2.42, 3.42, 4.42}, device); + StorageView expected_output2({2, 2}, std::vector{5.42, 6.42, 7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); +} + +TEST_P(OpDeviceTest, SplitAxis0EqualLengthParts3) { + Device device = GetParam(); + StorageView input({6, 2}, std::vector{1.42, 17.24, + -2.42, 42.56, + 3.42, -101.6, + 4.42, -500.543, + 5.42, 6.42, + 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + StorageView output3(device); + ops::Split(0)(input, output1, output2, output3); + StorageView expected_output1({2, 2}, std::vector{1.42, 17.24, -2.42, 42.56}, device); + StorageView expected_output2({2, 2}, std::vector{3.42, -101.6, 4.42, -500.543}, device); + StorageView expected_output3({2, 2}, std::vector{5.42, 6.42, 7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); + expect_storage_eq(output3, expected_output3); +} + +TEST_P(OpDeviceTest, SplitAxis1EqualLengthParts2) { + Device device = GetParam(); + StorageView input({4, 2}, std::vector{1, 2, + 3, 4, + 5, 6, + 7, 8}, device); + StorageView output1(device); + StorageView 
output2(device); + ops::Split(1)(input, output1, output2); + StorageView expected_output1({4, 1}, std::vector{1, 3, 5, 7}, device); + StorageView expected_output2({4, 1}, std::vector{2, 4, 6, 8}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); +} + +TEST_P(OpDeviceTest, SplitAxis1EqualLengthParts3) { + Device device = GetParam(); + StorageView input({2, 6}, std::vector{1.42, 17.24, -2.42, 42.56, 3.42, -101.6, + 4.42, -500.543, 5.42, 6.42, 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + StorageView output3(device); + ops::Split(1)(input, output1, output2, output3); + StorageView expected_output1({2, 2}, std::vector{1.42, 17.24, 4.42, -500.543}, device); + StorageView expected_output2({2, 2}, std::vector{-2.42, 42.56, 5.42, 6.42}, device); + StorageView expected_output3({2, 2}, std::vector{3.42, -101.6, 7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); + expect_storage_eq(output3, expected_output3); +} + +TEST_P(OpDeviceTest, Axis0NonEqualLengthParts2) { + Device device = GetParam(); + StorageView input({4, 2}, std::vector{1.42, -2.42, + 3.42, 4.42, + 5.42, 6.42, + 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + ops::Split(0, {3, 1})(input, output1, output2); + StorageView expected_output1({3, 2}, std::vector{1.42, -2.42, 3.42, 4.42, 5.42, 6.42,}, device); + StorageView expected_output2({1, 2}, std::vector{7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); +} + +TEST_P(OpDeviceTest, Axis0NonEqualLengthParts3) { + Device device = GetParam(); + StorageView input({6, 2}, std::vector{1.42, 17.24, + -2.42, 42.56, + 3.42, -101.6, + 4.42, -500.543, + 5.42, 6.42, + 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + StorageView output3(device); + ops::Split(0, {1, 2, 
3})(input, output1, output2, output3); + StorageView expected_output1({1, 2}, std::vector{1.42, 17.24}, device); + StorageView expected_output2({2, 2}, std::vector{-2.42, 42.56, 3.42, -101.6}, device); + StorageView expected_output3({3, 2}, std::vector{4.42, -500.543, 5.42, 6.42, 7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); + expect_storage_eq(output3, expected_output3); +} + +TEST_P(OpDeviceTest, SplitAxis1NonEqualLengthParts2) { + Device device = GetParam(); + StorageView input({4, 3}, std::vector{1, 2, 3, + 4, 5, 6, + 7, 8, -5, + -6, -7, -8}, device); + StorageView output1(device); + StorageView output2(device); + ops::Split(1, {2, 1})(input, output1, output2); + StorageView expected_output1({4, 2}, std::vector{1, 2, 4, 5, 7, 8, -6, -7}, device); + StorageView expected_output2({4, 1}, std::vector{3, 6, -5, -8}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); +} + +TEST_P(OpDeviceTest, SplitAxis1NonEqualLengthParts3) { + Device device = GetParam(); + StorageView input({2, 6}, std::vector{1.42, 17.24, -2.42, 42.56, 3.42, -101.6, + 4.42, -500.543, 5.42, 6.42, 7.42, -8.42}, device); + StorageView output1(device); + StorageView output2(device); + StorageView output3(device); + ops::Split(1, {1, 2, 3})(input, output1, output2, output3); + StorageView expected_output1({2, 1}, std::vector{1.42, 4.42}, device); + StorageView expected_output2({2, 2}, std::vector{17.24, -2.42, -500.543, 5.42}, device); + StorageView expected_output3({2, 3}, std::vector{42.56, 3.42, -101.6, 6.42, 7.42, -8.42}, device); + expect_storage_eq(output1, expected_output1); + expect_storage_eq(output2, expected_output2); + expect_storage_eq(output3, expected_output3); +} + TEST_P(OpDeviceTest, Mean) { const Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const StorageView input({2, 3, 2}, std::vector{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
11, 12 @@ -498,9 +852,39 @@ TEST_P(OpDeviceFPTest, Gemm) { StorageView expected( {4, 4}, std::vector{3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3}, device); ops::Gemm op(1.0, 1.0, false, false); + + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; y = y.to(dtype); op(a.to(dtype), b.to(dtype), y); expect_storage_eq(y.to_float32(), expected, error); + +}; + +TEST_P(OpDeviceTest, GemmFloat16) { + Device device = GetParam(); + if (!mayiuse_float16(device)) + return; + StorageView a({8, 8}, float16_t(1.6), device); + StorageView b({8, 8}, float16_t(1.4), device); + StorageView c({8, 8}, float16_t(0.75), device); + StorageView expected({8, 8}, float16_t(20.92), device); + ops::Gemm op(1.0, 4, false, false); + op(a, b, c); + expect_storage_eq(c, expected); +}; + +TEST_P(OpDeviceTest, GemmFloat32) { + Device device = GetParam(); + StorageView a( + {2, 2}, std::vector{1, 1, 1, 1}, device); + StorageView b(a); + StorageView expected( + {2, 2}, std::vector{3, 3, 3, 3}, device); + StorageView c({2, 2}, std::vector{1, 1, 1, 1}, device); + ops::Gemm op(1.0, 1.0, false, false); + op(a, b, c); + expect_storage_eq(c, expected); }; TEST_P(OpDeviceTest, GemmInt8) { @@ -530,10 +914,58 @@ TEST_P(OpDeviceTest, GemmInt8) { expect_storage_eq(c, expected); }; -TEST_P(OpDeviceFPTest, TopK) { +TEST_P(OpDeviceFPTest, GemmTransposeB) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView a({2, 3}, std::vector{1, 2, 3, + 4, 5, 6}, device); + StorageView b({4, 3}, std::vector{1, 2, 3, + 4, 1, 2, + 3, 4, 1, + 2, 3, 4}, device); + // check multiple constructors for c. 
+ StorageView c({2, 4}, DataType::FLOAT32, device); + StorageView expected({2, 4}, std::vector{14, 12, 14, 20, + 32, 33, 38, 47}, device); + ops::Gemm op(1.0, 0, false, true); + + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + c = c.to(dtype); + op(a.to(dtype), b.to(dtype), c); + expect_storage_eq(c.to_float32(), expected, error); +}; + +TEST_P(OpDeviceFPTest, TopKBasic) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; const float error = GetParam().error; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const int k = 3; + StorageView input({1, 12}, std::vector{1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1}, device); + StorageView expected_values({1, 3}, std::vector{99, 98, 96}, device); + StorageView expected_indices({1, 3}, std::vector{5, 2, 9}, device); + StorageView values(dtype, device); + StorageView indices(expected_indices.dtype(), device); + ops::TopK op(k); + op(input.to(dtype), values, indices); + expect_storage_eq(values.to_float32(), expected_values, error); + expect_storage_eq(indices, expected_indices); +} + +TEST_P(OpDeviceFPTest, TopK) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + float error = GetParam().error; + if(device == Device::CANN) { + if(dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + else if(dtype == DataType::FLOAT32) { + error = 3.907e-4; // FLOAT32 case does not comply with predefined error value + } + } const int k = 3; StorageView input({2, 6}, std::vector{0.1, -0.5, 2.0, 0.0, 0.2, 0.6, 1.0, 1.1, 0.2, 0.3, -0.2, 0.0}, device); StorageView expected_values({2, 3}, std::vector{2.0, 0.6, 0.2, 1.1, 1.0, 0.3}, device); @@ -556,13 +988,19 @@ TEST_P(OpDeviceTest, TopKVariableDepth) { StorageView values(expected_values.dtype(), device); StorageView indices(expected_indices.dtype(), device); op(input, values, indices); - expect_storage_eq(values, expected_values); + { + const float 
cann_error = 3.907e-4; // FLOAT32 case presents error + device == Device::CANN ? expect_storage_eq(values, expected_values, cann_error) : expect_storage_eq(values, expected_values); + } expect_storage_eq(indices, expected_indices); StorageView input2({2, 4}, std::vector{0.1, 2.0, 0.2, 0.6, 1.0, 1.1, 0.2, 0.3}, device); StorageView expected_values2({2, 3}, std::vector{2.0, 0.6, 0.2, 1.1, 1.0, 0.3}, device); StorageView expected_indices2({2, 3}, std::vector{1, 3, 2, 1, 0, 3}, device); op(input2, values, indices); - expect_storage_eq(values, expected_values2); + { + const float cann_error = 3.907e-4; // FLOAT32 case presents error + device == Device::CANN ? expect_storage_eq(values, expected_values2, cann_error) : expect_storage_eq(values, expected_values2); + } expect_storage_eq(indices, expected_indices2); } @@ -580,7 +1018,10 @@ TEST_P(OpDeviceTest, TopKChangeK) { StorageView values_k2(expected_values_k2.dtype(), device); StorageView indices_k2(expected_indices_k2.dtype(), device); ops::TopK(2)(input, values_k2, indices_k2); - expect_storage_eq(values_k2, expected_values_k2); + { + const float cann_error = 3.907e-4; // FLOAT32 case presents error + device == Device::CANN ? expect_storage_eq(values_k2, expected_values_k2, cann_error) : expect_storage_eq(values_k2, expected_values_k2); + } expect_storage_eq(indices_k2, expected_indices_k2); const StorageView expected_values_k3({2, 3}, std::vector{2.0, 0.6, 0.2, 1.1, 1.0, 0.3}, device); @@ -588,12 +1029,17 @@ TEST_P(OpDeviceTest, TopKChangeK) { StorageView values_k3(expected_values_k3.dtype(), device); StorageView indices_k3(expected_indices_k3.dtype(), device); ops::TopK(3)(input, values_k3, indices_k3); - expect_storage_eq(values_k3, expected_values_k3); + { + const float cann_error = 3.907e-4; // FLOAT32 case presents error + device == Device::CANN ? 
expect_storage_eq(values_k3, expected_values_k3, cann_error) : expect_storage_eq(values_k3, expected_values_k3); + } expect_storage_eq(indices_k3, expected_indices_k3); } TEST_P(OpDeviceFPTest, TopPMask) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; constexpr float inf = std::numeric_limits::infinity(); @@ -617,7 +1063,10 @@ TEST_P(OpDeviceFPTest, SoftMax) { const float error = GetParam().error; StorageView x = StorageView({2, 5}, std::vector{ -0.2, 3.0, 1.2, -1.1, 0.0, - 4.6, 3.3, 0.2, -1.6, 1.0}, device).to(dtype); + 4.6, 3.3, 0.2, -1.6, 1.0}, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + x = x.to(dtype); StorageView expected({2, 5}, std::vector{ 0.032035, 0.785904, 0.129909, 0.013025, 0.039128, 0.760941, 0.207381, 0.009342, 0.001544, 0.020792}, device); @@ -628,13 +1077,51 @@ TEST_P(OpDeviceFPTest, SoftMax) { expect_storage_eq(x.to_float32(), expected, error); } +TEST_P(OpDeviceFPTest, SoftMax1D) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView x = StorageView({5}, std::vector{ + -0.2, 3.0, 1.2, -1.1, 0.0}, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + x = x.to(dtype); + StorageView expected({5}, std::vector{ + 0.032035, 0.785904, 0.129909, 0.013025, 0.039128}, device); + StorageView y(dtype, device); + ops::SoftMax()(x, y); + expect_storage_eq(y.to_float32(), expected, error); + ops::SoftMax()(x); + expect_storage_eq(x.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, SoftMax1DWithLength) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView x = StorageView({5}, std::vector{ + -0.2, 3.0, 1.2, -1.1, 42.17}, device); + if(device == 
Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + x = x.to(dtype); + StorageView lengths({1}, std::vector{4}, device); + StorageView expected({5}, std::vector{ + 0.0333396, 0.8179057, 0.1351989, 0.013554, 0}, device); + StorageView y(dtype, device); + ops::SoftMax()(x, lengths, y); + expect_storage_eq(y.to_float32(), expected, error); +} + TEST_P(OpDeviceFPTest, LogSoftMax) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView x = StorageView({2, 10}, std::vector{ -0.2, 3.0, 1.2, -1.1, 0.0, 0.2, -3.0, -1.2, 1.1, 0.0, - 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device).to(dtype); + 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + x = x.to(dtype); StorageView expected({2, 10}, std::vector{ -3.638294, -0.438294, -2.238294, -4.538294, -3.438294, -3.238294, -6.438294, -4.638294, -2.338294, -3.438294, -0.319434, -1.619434, -4.719434, -6.519434, -3.919434, -9.519434, -8.219434, -5.119434, -3.319434, -5.919434}, device); @@ -645,6 +1132,27 @@ TEST_P(OpDeviceFPTest, LogSoftMax) { expect_storage_eq(x.to_float32(), expected, error * 10); } +TEST_P(OpDeviceFPTest, MaskedLogSoftMax) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView x = StorageView({3, 10}, std::vector{ + -0.2, 3.0, 1.2, -1.1, 0.0, 0.2, -3.0, -1.2, 1.1, 0.0, + 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0, + -1.1, 0.0, 0.2, -3.0, -1.2, 4.6, 3.3, 0.2, -1.6, 1.0}, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + x = x.to(dtype); + StorageView lengths({3}, std::vector{3, 5, 7}, device); + StorageView expected({3, 10}, std::vector{ + -3.38735985, -0.18735980, -1.98735976, 0, 0, 0, 0, 0,0,0, + -0.27319955, -1.57319951, -4.67319965, -6.47319936, 
-3.87319946, 0, 0, 0,0,0, + -5.96369791, -4.86369800, -4.66369819, -7.86369800, -6.06369781, -0.26369810, -1.56369805, 0,0,0}, device); + StorageView y(dtype, device); + ops::LogSoftMax()(x, lengths, y); + expect_storage_eq(y.to_float32(), expected, error * 10); +} + TEST_P(OpDeviceFPTest, MaskedSoftMax) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; @@ -657,13 +1165,81 @@ TEST_P(OpDeviceFPTest, MaskedSoftMax) { 0.033797, 0.829145, 0.137056, 0, 0, 0.777098, 0.211783, 0.009540, 0.001577, 0}, device); StorageView y(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + ops::SoftMax()(x.to(dtype), lengths, y); + expect_storage_eq(y.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, MaskedSoftMaxWithLengthsEqualToLastDim) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView x({2, 10}, std::vector{ + -0.2, 3.0, 1.2, -1.1, 0.0, 4.6, 3.3, 0.2, 3.0, 1.21, + 4.6, 3.3, 0.2, -1.6, 1.0, 1.2, -1.1, 0.0, 0.17, 0.42}, device); + StorageView lengths({2}, std::vector{10, 10}, device); + StorageView expected({2, 10}, std::vector{ + 0.0046304, 0.1135965, 0.0187773, 0.0018825, 0.0056556, 0.5626471, 0.15333, 0.0069078, 0.11359, 0.01896, + 0.72038, 0.19632, 0.00884, 0.00146, 0.01968, 0.02404, 0.00241, 0.00724, 0.00858, 0.01102}, device); + StorageView y(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; ops::SoftMax()(x.to(dtype), lengths, y); + + expect_storage_eq(y.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, MaskedSoftMax4D) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const float error = GetParam().error; + StorageView x({2, 2, 3, 3}, std::vector{ + 0.08784354, 0.67030656, 0.8866086, + 0.08053982, 
0.9826797, 0.7965635, + 0.48865926, 0.8635745, 0.21703207, + 0.0742166, 0.0623771, 0.7590432, + 0.43742728, 0.12613738, 0.53697634, + 0.05396891, 0.04152167, 0.66332567, + 0.6386628, 0.23325896, 0.6977577, + 0.06948507, 0.10246396, 0.6232395, + 0.7822603, 0.3168552, 0.11804962, + 0.1133163, 0.29983068, 0.43074536, + 0.7321733, 0.48709297, 0.35727918, + 0.8421174, 0.9135181, 0.77135813 + }, device); + StorageView mask({2, 2, 3}, std::vector{ + 1, 2, 3, + 1, 2, 3, + 1, 2, 2, + 1, 2, 2 + }, device); + StorageView expected({2, 2, 3, 3}, std::vector{ + 1, 0, 0, + 0.28861094, 0.71138906, 0, + 0.310848, 0.45224282, 0.23690917, + 1, 0, 0, + 0.57720006, 0.42279992, 0, + 0.26130962, 0.25807717, 0.48061317, + 1, 0, 0, + 0.49175602, 0.508244, 0, + 0.61429566, 0.3857044, 0, + 1, 0, 0, + 0.56096524, 0.43903476, 0, + 0.48215744, 0.5178426, 0 + }, device); + StorageView y(dtype, device); + ops::SoftMax()(x.to(dtype), mask, y); expect_storage_eq(y.to_float32(), expected, error); } TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; const float error = GetParam().error; StorageView x({2, 2, 3, 3}, std::vector{ 0.08784354, 0.67030656, 0.8866086, @@ -703,6 +1279,8 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) { TEST_P(OpDeviceFPTest, LayerNorm) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; const float error = GetParam().error; StorageView gamma({5}, std::vector{0.2, 2.1, 1.1, -0.6, 0.7}, device); StorageView beta({5}, std::vector{-6.6, -5.7, 0.01, 2.0, 0}, device); @@ -717,10 +1295,102 @@ TEST_P(OpDeviceFPTest, LayerNorm) { expect_storage_eq(y.to_float32(), expected, error); } +TEST_P(OpDeviceFPTest, LayerNormZerosAndOnes) { + const Device device = GetParam().device; + const DataType dtype = 
GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const float error = GetParam().error; + StorageView gamma({2}, 0.f, device); + StorageView beta({2}, 1.f, device); + StorageView x({5, 2}, std::vector{ + 0, 10, + 20, 30, + 40, 50, + 60, 70, + 80, 90}, device); + StorageView expected({5, 2}, std::vector{ + 1, 1, + 1, 1, + 1, 1, + 1, 1, + 1, 1}, device); + StorageView y(dtype, device); + ops::LayerNorm()(beta.to(dtype), gamma.to(dtype), x.to(dtype), y); + expect_storage_eq(y.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, LayerNorm3DEasy) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const float error = GetParam().error; + StorageView gamma({2}, std::vector{0.2, 2.1}, device); + StorageView beta({2}, std::vector{-6.6, -5.7}, device); + StorageView x({2, 5, 2}, std::vector{ + 0, 10, + 20, 30, + 40, 50, + 60, 70, + 80, 90, + + 0, 10, + 20, 30, + 40, 50, + 60, 70, + 80, 90}, device); + StorageView expected({2, 5, 2}, std::vector{ + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6, + + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6, + -6.79999, -3.6}, device); + StorageView y(dtype, device); + ops::LayerNorm()(beta.to(dtype), gamma.to(dtype), x.to(dtype), y); + expect_storage_eq(y.to_float32(), expected, error); +} + + +TEST_P(OpDeviceFPTest, LayerNorm3DHard) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const float error = GetParam().error; + StorageView gamma({4}, std::vector{0.2, 2.1, -1.3, -4.2}, device); + StorageView beta({4}, std::vector{2.2, 4.43, -1.6, -1.7}, device); + StorageView x({2, 3, 4}, std::vector{ + 4.5, 0.6, 0.5, 0.6, + 0.5, 4.6, 5.5, 6.67, + -5.5, 0.2, 0.5, -7.46, + 
+ -0.5, 0.6, 0.4, 0.6, + 0.5, 17.6, 0.1, -0.62, + -42.4, 78.6, 0.5, 0.6}, device); + StorageView expected({2, 3, 4}, std::vector{ + 2.546310, 3.259002, -0.798791, 0.641994, + 1.871333, 4.685378, -2.261745, -5.953297, + 2.060307, 6.396746, -2.929379, 3.594856, + + 1.859225, 5.930507, -1.957263, -4.701014, + 2.097962, 8.062276, -0.868645, 1.058934, + 1.963113, 7.761240, -1.337295, -0.860878}, device); + StorageView y(dtype, device); + ops::LayerNorm()(beta.to(dtype), gamma.to(dtype), x.to(dtype), y); + expect_storage_eq(y.to_float32(), expected, error); +} + TEST_P(OpDeviceFPTest, LayerNormAxis) { const Device device = GetParam().device; - if (device == Device::CUDA) { - GTEST_SKIP() << "Generalized LayerNorm is not implemented on GPU"; + if (device == Device::CUDA || device == Device::CANN) { + GTEST_SKIP() << "Generalized LayerNorm is not implemented on GPU and NPU"; } const DataType dtype = GetParam().dtype; const float error = GetParam().error; @@ -745,6 +1415,8 @@ TEST_P(OpDeviceFPTest, LayerNormAxis) { TEST_P(OpDeviceFPTest, RMSNorm) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView gamma({5}, std::vector{0.2, 2.1, 1.1, -0.6, 0.7}, device); @@ -761,6 +1433,8 @@ TEST_P(OpDeviceFPTest, RMSNorm) { TEST_P(OpDeviceTest, QuantizeINT8) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView a({2, 4}, std::vector{-10, -3, 5, 2, 5, 21, -3, 0}, device); StorageView scale(DataType::FLOAT32, device); StorageView qa(DataType::INT8, device); @@ -785,6 +1459,8 @@ TEST_P(OpDeviceTest, QuantizeINT8) { TEST_P(OpDeviceTest, QuantizeINT8ZeroRow) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView a({2, 4}, std::vector{-10, -3, 5, 2, 0, 0, 0, 0}, device); StorageView scale(DataType::FLOAT32, device); StorageView qa(DataType::INT8, device); @@ -809,6 +1485,8 
@@ TEST_P(OpDeviceTest, QuantizeINT8ZeroRow) { TEST_P(OpDeviceFPTest, Multinomial) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; StorageView input({2, 4}, std::vector{0.2, 0.1, 0.6, 0.1, 0.7, 0.2, 0.0, 0.1}, device); StorageView output(DataType::INT32, device); @@ -836,12 +1514,40 @@ TEST_P(OpDeviceFPTest, ReLU) { StorageView input({2, 5}, std::vector{-1, 1, 2, -2, 2, 4, -3, 0, -1, -3}, device); StorageView expected({2, 5}, std::vector{0, 1, 2, 0, 2, 4, 0, 0, 0, 0}, device); StorageView output(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + ops::ReLU()(input.to(dtype), output); + expect_storage_eq(output.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, ReLULarge) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView input({2, 5, 6}, std::vector{-1.12, 1.55, 2.3, -2.42, 2.17, 4.5, -3.27, 0.12, -1.55, -3.17, + -1, 1, 2, -2, 2, 4, -3, 0, -1, -32.17, + -1, 1, 2, -2, 2, 4, -3, 0, -1, -3, + -5.12, 9.55, 2.3, -2.42, 2.17, 4.5, 3.27, 1.12, -8.55, -33.17, + -1, 1, 2, -2, 2, 4, -3, 0, -1, -3, + -1, 1, 2, -2, 2, 4, -3, 0.42, -1, -3.42}, device); + StorageView expected({2, 5, 6}, std::vector{0, 1.55, 2.3, 0, 2.17, 4.5, 0, 0.12, 0, 0, + 0, 1, 2, 0, 2, 4, 0, 0, 0, 0, + 0, 1, 2, 0, 2, 4, 0, 0, 0, 0, + 0, 9.55, 2.3, 0, 2.17, 4.5, 3.27, 1.12, 0, 0, + 0, 1, 2, 0, 2, 4, 0, 0, 0, 0, + 0, 1, 2, 0, 2, 4, 0, 0.42, 0, 0}, device); + + StorageView output(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; ops::ReLU()(input.to(dtype), output); expect_storage_eq(output.to_float32(), expected, error); } TEST_P(OpDeviceFPTest, GELU) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = 
GetParam().error; StorageView input({2}, std::vector{0.2, -1.3}, device); @@ -853,6 +1559,10 @@ TEST_P(OpDeviceFPTest, GELU) { TEST_P(OpDeviceFPTest, GELUTanh) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView input({2}, std::vector{0.2, -1.3}, device); @@ -865,6 +1575,8 @@ TEST_P(OpDeviceFPTest, GELUTanh) { TEST_P(OpDeviceFPTest, GELUSigmoid) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView input({2}, std::vector{0.2, -1.3}, device); @@ -877,6 +1589,8 @@ TEST_P(OpDeviceFPTest, GELUSigmoid) { TEST_P(OpDeviceFPTest, Swish) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView input({2}, std::vector{0.2, -1.3}, device); @@ -886,8 +1600,46 @@ TEST_P(OpDeviceFPTest, Swish) { expect_storage_eq(output.to_float32(), expected, error); } +TEST_P(OpDeviceFPTest, Cos) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + std::vector input_vec({0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4}); + std::vector expected_vec; + expected_vec.reserve(input_vec.size()); + std::transform(input_vec.begin(), input_vec.end(), std::back_inserter(expected_vec), + [](const float& i){return std::cos(i);}); + StorageView input({2, 4}, input_vec, device); + StorageView expected({2, 4}, expected_vec, device); + StorageView output(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + ops::Cos()(input.to(dtype), output); + expect_storage_eq(output.to_float32(), expected, error); +} + +TEST_P(OpDeviceFPTest, Sin) { + const 
Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + std::vector input_vec({0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4}); + std::vector expected_vec; + expected_vec.reserve(input_vec.size()); + std::transform(input_vec.begin(), input_vec.end(), std::back_inserter(expected_vec), + [](const float& i){return std::sin(i);}); + StorageView input({2, 4}, input_vec, device); + StorageView expected({2, 4}, expected_vec, device); + StorageView output(dtype, device); + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + ops::Sin()(input.to(dtype), output); + expect_storage_eq(output.to_float32(), expected, error); +} + TEST_P(OpDeviceFPTest, Tanh) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; StorageView x({1, 5}, std::vector{-2, -1.5, 0, 1.5, 2}, device); @@ -901,6 +1653,8 @@ TEST_P(OpDeviceFPTest, Tanh) { TEST_P(OpDeviceFPTest, Log) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; std::vector input_vec({0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4}); @@ -917,6 +1671,8 @@ TEST_P(OpDeviceFPTest, Log) { TEST_P(OpDeviceFPTest, LogLimits) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; StorageView values({2}, std::vector{0.f, -1.f}, device); @@ -961,6 +1717,8 @@ void TestMinMax(Device device, const Ops& ops, const Func& func){ TEST_P(OpDeviceTest, Min) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; auto ops = ops::Min(); TestMinMax(device, ops, [](float left, float right){ return left > right? 
right : left; @@ -969,6 +1727,8 @@ TEST_P(OpDeviceTest, Min) { TEST_P(OpDeviceTest, Max) { Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; auto ops = ops::Max(); TestMinMax(device, ops, [](float left, float right){ return left > right? left : right; @@ -998,6 +1758,8 @@ TEST_P(OpDeviceFPTest, Conv1D) { const Device device = GetParam().device; if (device == Device::CUDA) GUARD_CONV1D_GPU_TEST; + else if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 4, 2}, std::vector{ @@ -1018,6 +1780,8 @@ TEST_P(OpDeviceFPTest, Conv1DNoBias) { const Device device = GetParam().device; if (device == Device::CUDA) GUARD_CONV1D_GPU_TEST; + else if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 4, 2}, std::vector{ @@ -1037,6 +1801,8 @@ TEST_P(OpDeviceFPTest, Conv1DPadding) { const Device device = GetParam().device; if (device == Device::CUDA) GUARD_CONV1D_GPU_TEST; + else if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 4, 4}, std::vector{ @@ -1061,6 +1827,8 @@ TEST_P(OpDeviceFPTest, Conv1DStride) { const Device device = GetParam().device; if (device == Device::CUDA) GUARD_CONV1D_GPU_TEST; + else if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 4, 1}, std::vector{ @@ -1079,6 +1847,8 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) { const Device device = GetParam().device; if (device == Device::CUDA) GUARD_CONV1D_GPU_TEST; + else if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const float error = GetParam().error; const StorageView expected({2, 4, 2}, 
std::vector{ @@ -1095,6 +1865,27 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) { expect_storage_eq(output.to_float32(), expected, error); } +TEST_P(OpDeviceFPTest, SplitAxis0EqualLengthParts) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + if(device == Device::CANN && dtype == DataType::BFLOAT16) + GUARD_BFLOAT16_NPU_TEST; + const float error = GetParam().error; + StorageView input({4, 2}, std::vector{1.42, -2.42, + 3.42, 4.42, + 5.42, 6.42, + 7.42, -8.42}, device); + StorageView output1(dtype, device); + StorageView output2(dtype, device); + ops::Split(0)(input.to(dtype), output1, output2); + StorageView expected_output1({2, 2}, std::vector{1.42, -2.42, 3.42, 4.42}, device); + StorageView expected_output2({2, 2}, std::vector{5.42, 6.42, 7.42, -8.42}, device); + EXPECT_EQ(output1.dtype(), dtype); + EXPECT_EQ(output2.dtype(), dtype); + expect_storage_eq(output1.to_float32(), expected_output1, error); + expect_storage_eq(output2.to_float32(), expected_output2, error); +} + INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceTest, ::testing::Values(Device::CPU)); INSTANTIATE_TEST_SUITE_P(CPU, OpDeviceFPTest, @@ -1107,4 +1898,11 @@ INSTANTIATE_TEST_SUITE_P(CUDA, OpDeviceFPTest, FloatType{Device::CUDA, DataType::FLOAT16, 1e-2}, FloatType{Device::CUDA, DataType::BFLOAT16, 1e-2}), fp_test_name); +#elif CT2_WITH_CANN +INSTANTIATE_TEST_SUITE_P(CANN, OpDeviceTest, ::testing::Values(Device::CANN)); +INSTANTIATE_TEST_SUITE_P(CANN, OpDeviceFPTest, + ::testing::Values(FloatType{Device::CANN, DataType::FLOAT32, 1e-5}, + FloatType{Device::CANN, DataType::FLOAT16, 1e-2}, + FloatType{Device::CANN, DataType::BFLOAT16, 1e-2}), + fp_test_name); #endif diff --git a/tests/primitives_test.cc b/tests/primitives_test.cc index 9f603de33..10358bb54 100644 --- a/tests/primitives_test.cc +++ b/tests/primitives_test.cc @@ -5,8 +5,48 @@ class PrimitiveTest : public ::testing::TestWithParam { }; +TEST_P(PrimitiveTest, FillFloat16) { + const Device device = 
GetParam(); + StorageView x({2, 3}, DataType::FLOAT16, device); + auto fill_value = float16_t(42.23); + StorageView expected({2, 3}, std::vector{fill_value, fill_value, fill_value, + fill_value, fill_value, fill_value}, device); + DEVICE_DISPATCH(device, primitives::fill(x.data(), fill_value, x.size())); + expect_storage_eq(x, expected); +} + +TEST_P(PrimitiveTest, FillFloat32) { + const Device device = GetParam(); + StorageView x({2, 3}, DataType::FLOAT32, device); + auto fill_value = 42.23f; + StorageView expected({2, 3}, std::vector{fill_value, fill_value, fill_value, + fill_value, fill_value, fill_value}, device); + DEVICE_DISPATCH(device, primitives::fill(x.data(), fill_value, x.size())); + expect_storage_eq(x, expected); +} + +TEST_P(PrimitiveTest, ZeroFloat16) { + const Device device = GetParam(); + StorageView x({2, 3}, DataType::FLOAT16, device); + StorageView expected({2, 3}, std::vector{float16_t(0), float16_t(0), float16_t(0), + float16_t(0), float16_t(0), float16_t(0)}, device); + DEVICE_DISPATCH(device, primitives::zero(x.data(), x.size())); + expect_storage_eq(x, expected); +} + +TEST_P(PrimitiveTest, ZeroFloat32) { + const Device device = GetParam(); + StorageView x({2, 3}, DataType::FLOAT32, device); + StorageView expected({2, 3}, std::vector{0, 0, 0, + 0, 0, 0}, device); + DEVICE_DISPATCH(device, primitives::zero(x.data(), x.size())); + expect_storage_eq(x, expected); +} + TEST_P(PrimitiveTest, StridedFill) { const Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView x({3, 2}, float(0), device); StorageView expected({3, 2}, std::vector{1, 0, 1, 0, 1, 0}, device); DEVICE_DISPATCH(device, primitives::strided_fill(x.data(), 1.f, 2, 3)); @@ -14,16 +54,42 @@ TEST_P(PrimitiveTest, StridedFill) { } TEST_P(PrimitiveTest, IndexedFill) { + const Device device = GetParam(); + StorageView x({6}, float(0), device); + StorageView ids({3}, std::vector{0, 2, 5}, device); + StorageView expected({6}, std::vector{1, 0, 1, 
0, 0, 1}, device); + DEVICE_DISPATCH(device, primitives::indexed_fill(x.data(), 1.f, ids.data(), 3, x.size())); + expect_storage_eq(x, expected); +} + +TEST_P(PrimitiveTest, IndexedFill2D) { const Device device = GetParam(); - StorageView x({6}, float(0), device); - StorageView ids({3}, std::vector{0, 2, 5}, device); - StorageView expected({6}, std::vector{1, 0, 1, 0, 0, 1}, device); - DEVICE_DISPATCH(device, primitives::indexed_fill(x.data(), 1.f, ids.data(), 3)); + StorageView x({3, 3}, std::vector{1, 2, 3, + 4, 5, 6, + 7, 8, 9}, device); + StorageView ids({6}, std::vector{0, 2, 3, 5, 6, 8}, device); + StorageView expected({3, 3}, std::vector{-1, 2, -1, + -1, 5, -1, + -1, 8, -1}, device); + DEVICE_DISPATCH(device, primitives::indexed_fill(x.data(), -1.0f, ids.data(), 6, x.size())); + expect_storage_eq(x, expected); +} + +TEST_P(PrimitiveTest, IndexedFill2DComplexFloats) { + const Device device = GetParam(); + StorageView x({2, 3}, std::vector{-1.89935, -1.89909, 8.05185, + -1e+10, -1e+10, -1e+10}, device); + StorageView ids({2}, std::vector{2, 5}, device); + StorageView expected({2, 3}, std::vector{-1.89935, -1.89909, -3.40282e+38, + -1e+10, -1e+10, -3.40282e+38}, device); + DEVICE_DISPATCH(device, primitives::indexed_fill(x.data(), -3.40282e+38f, ids.data(), 2, x.size())); expect_storage_eq(x, expected); } TEST_P(PrimitiveTest, LogSumExp) { const Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; StorageView x({8}, std::vector{0.6, 0.2, -1.2, 0.1, 0.3, 0.5, -1.3, 0.2}, device); float result = 0; DEVICE_DISPATCH(device, result = primitives::logsumexp(x.data(), x.size())); @@ -32,6 +98,8 @@ TEST_P(PrimitiveTest, LogSumExp) { TEST_P(PrimitiveTest, PenalizePreviousTokens) { const Device device = GetParam(); + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const float penalty = 1.2f; StorageView scores({2, 4}, std::vector{0.6, 0.2, -1.2, 0.1, 0.3, 0.5, -1.3, 0.2}); StorageView previous_ids({2, 2}, std::vector{2, 2, 1, 2}, 
device); @@ -51,7 +119,150 @@ TEST_P(PrimitiveTest, PenalizePreviousTokens) { expect_storage_eq(scores, expected); } +TEST_P(PrimitiveTest, AddDepthBroadcast2DInput) { + const Device device = GetParam(); + StorageView x1({2, 8}, std::vector{12, -20, 31, 0.3, -42.17, 17.42, 40.5, -0.001, + 112, 20, -31, 40.3, -4.17, -7.42, -50.34, 2.031}, device); + StorageView x2({2}, std::vector{0.1, 0.2}, device); + StorageView expected({2, 8}, std::vector{12.1, -19.9, 31.1, 0.4, -42.07, 17.52, 40.6, 0.099, + 112.2, 20.2, -30.8, 40.5, -3.97, -7.22, -50.14, 2.231}, device); + StorageView output({2, 8}, DataType::FLOAT32, device); + DEVICE_DISPATCH(device, primitives::add_depth_broadcast(x2.data(), x1.data(), output.data(), x2.size(), x1.size())); + expect_storage_eq(output, expected); +} + +TEST_P(PrimitiveTest, AddDepthBroadcast3DInput) { + const Device device = GetParam(); + StorageView x1({2, 3, 4}, std::vector{12, -20.54, 31.1, 0.3, + 42.17, 17.42, 40.5, -0.001, + 12.2, -20, -31, 40.3, + + 7.4, -50.34, 2, 0.12, + 1, 20, -31, 40.3, + 2, -20, 31, 0.3}, device); + StorageView x2({2}, std::vector{1.1, 2.1}, device); + StorageView expected({2, 3, 4}, std::vector{13.1, -19.44, 32.2, 1.4, + 43.27, 18.52, 41.6, 1.099, + 13.3, -18.9, -29.9, 41.4, + + 9.5, -48.24, 4.1, 2.22, + 3.1, 22.1, -28.9, 42.4, + 4.1, -17.9, 33.1, 2.4}, device); + DEVICE_DISPATCH(device, primitives::add_depth_broadcast(x2.data(), x1.data(), x1.data(), x2.size(), x1.size())); + expect_storage_eq(x1, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMask) { + const Device device = GetParam(); + StorageView lengths({2}, std::vector{17, 42}, device); + StorageView mask({2, 3, 4}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ false, /*multi_query*/ false, mask.data())); + StorageView expected({2, 3, 4}, std::vector{17, 17, 17, 17, + 17, 17, 17, 17, + 17, 17, 17, 17, + + 42, 42, 
42, 42, + 42, 42, 42, 42, + 42, 42, 42, 42}, device); + expect_storage_eq(mask, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMaskMultiQuery) { + const Device device = GetParam(); + StorageView lengths({2}, std::vector{17, 42}, device); + StorageView mask({2, 4, 3}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ false, /*multi_query*/ true, mask.data())); + StorageView expected({2, 4, 3}, std::vector{17, 17, 17, + 17, 17, 17, + 17, 17, 17, + 17, 17, 17, + + 42, 42, 42, + 42, 42, 42, + 42, 42, 42, + 42, 42, 42}, device); + expect_storage_eq(mask, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMaskMultiQueryMaskFuture) { + const Device device = GetParam(); + StorageView lengths({2}, std::vector{17, 42}, device); + StorageView mask({2, 4, 3}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ true, /*multi_query*/ true, mask.data())); + StorageView expected({2, 4, 3}, std::vector{1, 1, 1, + 2, 2, 2, + 3, 3, 3, + 4, 4, 4, + + 1, 1, 1, + 2, 2, 2, + 3, 3, 3, + 4, 4, 4}, device); + expect_storage_eq(mask, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMaskMultiQueryMaskFutureSmallLength) { + const Device device = GetParam(); + StorageView lengths({2}, std::vector{3, 2}, device); + StorageView mask({2, 4, 3}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ true, /*multi_query*/ true, mask.data())); + StorageView expected({2, 4, 3}, std::vector{1, 1, 1, + 2, 2, 2, + 3, 3, 3, + 3, 3, 3, + + 1, 1, 1, + 2, 2, 2, + 2, 2, 2, + 2, 2, 2}, device); + expect_storage_eq(mask, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMaskMaskFuture) { + const 
Device device = GetParam(); + StorageView lengths({2}, std::vector{17, 42}, device); + StorageView mask({2, 3, 4}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ true, /*multi_query*/ false, mask.data())); + StorageView expected({2, 3, 4}, std::vector{1, 2, 3, 4, + 1, 2, 3, 4, + 1, 2, 3, 4, + + 1, 2, 3, 4, + 1, 2, 3, 4, + 1, 2, 3, 4}, device); + expect_storage_eq(mask, expected); +} + +TEST_P(PrimitiveTest, PrepareLengthMaskMaskFutureSmallLength) { + const Device device = GetParam(); + StorageView lengths({2}, std::vector{2, 3}, device); + StorageView mask({2, 3, 4}, DataType::INT32, device); + DEVICE_DISPATCH(device, primitives::prepare_length_mask( + lengths.data(), /*batch_size*/ lengths.size(), /*num_heads*/ 3, /*num_queries*/ 4, + /*mask_future*/ true, /*multi_query*/ false, mask.data())); + StorageView expected({2, 3, 4}, std::vector{1, 2, 2, 2, + 1, 2, 2, 2, + 1, 2, 2, 2, + + 1, 2, 3, 3, + 1, 2, 3, 3, + 1, 2, 3, 3}, device); + expect_storage_eq(mask, expected); +} + INSTANTIATE_TEST_SUITE_P(CPU, PrimitiveTest, ::testing::Values(Device::CPU)); #ifdef CT2_WITH_CUDA INSTANTIATE_TEST_SUITE_P(CUDA, PrimitiveTest, ::testing::Values(Device::CUDA)); +#elif CT2_WITH_CANN +INSTANTIATE_TEST_SUITE_P(CANN, PrimitiveTest, ::testing::Values(Device::CANN)); #endif diff --git a/tests/storage_view_test.cc b/tests/storage_view_test.cc index bc6da8825..b5661b816 100644 --- a/tests/storage_view_test.cc +++ b/tests/storage_view_test.cc @@ -64,6 +64,66 @@ TEST(StorageViewTest, ExpandDimsAndSqueeze) { class StorageViewDeviceTest : public ::testing::TestWithParam { }; +TEST_P(StorageViewDeviceTest, StorageViewConstruction) { + StorageView x({2, 2}, std::vector{1, 2, 3, 4}, GetParam()); +} + +TEST_P(StorageViewDeviceTest, StorageViewEquality) { + const Device device = GetParam(); + StorageView x({2, 2}, std::vector{1, 2, 3, 4}, device); + 
StorageView expected({2, 2}, std::vector{1, 2, 3, 4}, device); + expect_storage_eq(x, expected); +} + +TEST_P(StorageViewDeviceTest, Shape) { + StorageView a({2, 2}, std::vector{1, 2, 3, 4}, GetParam()); + EXPECT_EQ(a.size(), 4); + EXPECT_EQ(a.dim(0), 2); + EXPECT_EQ(a.dim(1), 2); +} + +TEST_P(StorageViewDeviceTest, ToVector) { + StorageView a({2, 2}, std::vector{1, 2, 3, 4}, GetParam()); + auto vec = a.to_vector(); + EXPECT_EQ(vec[0], 1.); + EXPECT_EQ(vec[1], 2.); + EXPECT_EQ(vec[2], 3.); + EXPECT_EQ(vec[3], 4.); +} + +TEST_P(StorageViewDeviceTest, ScalarAt) { + const Device device = GetParam(); + const dim_t index = 1; + { + StorageView values({2}, std::vector{22, 33}, device); + EXPECT_EQ(values.scalar_at({index}), 33); + } + { + StorageView values({2}, std::vector{'c', 'd'}, device); + EXPECT_EQ(values.scalar_at({index}), 'd'); + } + { + StorageView values({2}, std::vector{0.2f, -1.f}, device); + EXPECT_EQ(values.scalar_at({index}), -1.f); + } +} + +TEST_P(StorageViewDeviceTest, ScalarFill) { + const Device device = GetParam(); + StorageView a({2, 2}, 42.f, device); + StorageView expected({2, 2}, std::vector{42.f, 42.f, + 42.f, 42.f}, device); + expect_storage_eq(a, expected); + StorageView b({3, 4}, 17, device); + StorageView expected2({3, 4}, std::vector{17, 17, 17, 17, + 17, 17, 17, 17, + 17, 17, 17, 17,}, device); + expect_storage_eq(b, expected2); + StorageView c({200, 201}, (float16_t) 42, device); + StorageView expected3({200, 201}, (float16_t) 42, device); + expect_storage_eq(c, expected3); +} + TEST_P(StorageViewDeviceTest, HalfConversion) { const Device device = GetParam(); const StorageView a({4}, std::vector{1, 2, 3, 4}, device); @@ -77,4 +137,6 @@ TEST_P(StorageViewDeviceTest, HalfConversion) { INSTANTIATE_TEST_SUITE_P(CPU, StorageViewDeviceTest, ::testing::Values(Device::CPU)); #ifdef CT2_WITH_CUDA INSTANTIATE_TEST_SUITE_P(CUDA, StorageViewDeviceTest, ::testing::Values(Device::CUDA)); +#elif CT2_WITH_CANN +INSTANTIATE_TEST_SUITE_P(CANN, 
StorageViewDeviceTest, ::testing::Values(Device::CANN)); #endif diff --git a/tests/test.cc b/tests/test.cc index 9c9a03c4f..7407f3620 100644 --- a/tests/test.cc +++ b/tests/test.cc @@ -1,13 +1,29 @@ #include +std::string g_data_dir; + +#ifdef CT2_WITH_CANN #include "test_utils.h" -std::string g_data_dir; +class CannTestEnvironment : public ::testing::Environment { +public: + void SetUp() override { + cann_test_setup(); + } + void TearDown() override { + cann_test_tear_down(); + } +}; +#endif int main(int argc, char *argv[]) { testing::InitGoogleTest(&argc, argv); if (argc < 2) throw std::invalid_argument("missing data directory"); g_data_dir = argv[1]; + +#ifdef CT2_WITH_CANN + ::testing::AddGlobalTestEnvironment(new CannTestEnvironment); +#endif return RUN_ALL_TESTS(); } diff --git a/tests/test_utils.h b/tests/test_utils.h index a44724e0a..0dcd061a0 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -1,10 +1,14 @@ #pragma once #include +#include #include "ctranslate2/storage_view.h" #include "type_dispatch.h" +#ifdef CT2_WITH_CANN +#include "cann/utils.h" +#endif using namespace ctranslate2; @@ -79,3 +83,32 @@ struct FloatType { inline std::string fp_test_name(::testing::TestParamInfo param_info) { return dtype_name(param_info.param.dtype); } +#ifdef CT2_WITH_CANN +# define GUARD_BFLOAT16_NPU_TEST GTEST_SKIP() << "BFLOAT16 not supported" +# define GUARD_OPERATOR_NPU_TEST GTEST_SKIP() << "Operator not implemented in CANN" + +inline void cann_test_setup() { + ctranslate2::initialize_device(); + // set_device_index has to always be called at least once before get_device_index + const auto device_index = 0; + ctranslate2::set_device_index(Device::CANN, device_index); +} + +inline void cann_test_allow_acl_finalize() { + ctranslate2::cann::AclDeviceEnabler::set_allow_acl_finalize(true); +} + +inline void cann_test_disallow_acl_finalize() { + ctranslate2::cann::AclDeviceEnabler::set_allow_acl_finalize(false); +} + +inline void cann_test_tear_down() { + 
cann_test_allow_acl_finalize(); // Make sure that acl_finalize can be invoked + ctranslate2::cann::AclDeviceEnabler::acl_finalize(); +} + +#else +# define GUARD_BFLOAT16_NPU_TEST do {} while (0) +# define GUARD_OPERATOR_NPU_TEST do {} while (0) +#endif + diff --git a/tests/translator_test.cc b/tests/translator_test.cc index c3edc3b5f..5212dc944 100644 --- a/tests/translator_test.cc +++ b/tests/translator_test.cc @@ -100,6 +100,11 @@ class SearchVariantTest : public ::testing::TestWithParam { }; static Translator default_translator(Device device = Device::CPU) { +#ifdef CT2_WITH_CANN + // Do not allow acl_finalize on Translator destruction to avoid premature acl_finalize call. + // Gtest TearDown takes care of that. + cann_test_disallow_acl_finalize(); +#endif return Translator(default_model_dir(), device); } @@ -513,6 +518,8 @@ class BiasedDecodingDeviceFPTest : public ::testing::TestWithParam { TEST_P(BiasedDecodingDeviceFPTest, OneBatchOneBeam) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 1; @@ -547,6 +554,8 @@ TEST_P(BiasedDecodingDeviceFPTest, OneBatchOneBeam) { TEST_P(BiasedDecodingDeviceFPTest, TwoBatchesTwoBeams) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 2; @@ -597,6 +606,8 @@ TEST_P(BiasedDecodingDeviceFPTest, TwoBatchesTwoBeams) { TEST_P(BiasedDecodingDeviceFPTest, BeamDiverged) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 1; @@ -625,6 +636,8 @@ TEST_P(BiasedDecodingDeviceFPTest, BeamDiverged) { TEST_P(BiasedDecodingDeviceFPTest, TimeStepPastPrefix) { const Device device = GetParam().device; + if(device == 
Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 1; @@ -653,6 +666,8 @@ TEST_P(BiasedDecodingDeviceFPTest, TimeStepPastPrefix) { TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepBias) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 1; @@ -686,6 +701,8 @@ TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepBias) { TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepDiverge) { const Device device = GetParam().device; + if(device == Device::CANN) + GUARD_OPERATOR_NPU_TEST; const DataType dtype = GetParam().dtype; const dim_t vocab_size = 2; const dim_t batch_size = 1; @@ -720,6 +737,11 @@ INSTANTIATE_TEST_SUITE_P(CUDA, BiasedDecodingDeviceFPTest, ::testing::Values(FloatType{Device::CUDA, DataType::FLOAT32}, FloatType{Device::CUDA, DataType::FLOAT16}), fp_test_name); +#elif CT2_WITH_CANN +INSTANTIATE_TEST_SUITE_P(CANN, BiasedDecodingDeviceFPTest, + ::testing::Values(FloatType{Device::CANN, DataType::FLOAT32}, + FloatType{Device::CANN, DataType::FLOAT16}), + fp_test_name); #endif TEST(TranslatorTest, TranslatePrefixWithLargeBeam) {