diff --git a/.gitmodules b/.gitmodules index bab0da131..487539485 100644 --- a/.gitmodules +++ b/.gitmodules @@ -19,3 +19,6 @@ [submodule "src/09python_ffi/pybind11"] path = src/09python_ffi/pybind11 url = git@github.com:pybind/pybind11.git +[submodule "3rd-party/cccl"] + path = 3rd-party/cccl + url = git@github.com:NVIDIA/cccl.git diff --git a/3rd-party/backward-cpp b/3rd-party/backward-cpp index 3bb9240cb..51f070045 160000 --- a/3rd-party/backward-cpp +++ b/3rd-party/backward-cpp @@ -1 +1 @@ -Subproject commit 3bb9240cb15459768adb3e7d963a20e1523a6294 +Subproject commit 51f0700452cf71c57d43c2d028277b24cde32502 diff --git a/3rd-party/cccl b/3rd-party/cccl new file mode 160000 index 000000000..b7d4228ab --- /dev/null +++ b/3rd-party/cccl @@ -0,0 +1 @@ +Subproject commit b7d4228ab7268ed928984cd61096079bd671d25d diff --git a/CMakeLists.txt b/CMakeLists.txt index fe305d124..49ddcda61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,4 +105,5 @@ add_subdirectory(src/05computation) add_subdirectory(src/06frontend) add_subdirectory(src/07onnx) add_subdirectory(src/08communication) +add_subdirectory(src/08-01llm) add_subdirectory(src/09python_ffi) diff --git a/README.md b/README.md index 2e7ca92a6..5ed6b27d3 100644 --- a/README.md +++ b/README.md @@ -168,7 +168,7 @@ executor = compiler.compile("cuda", "default", []) # -------- 编译模型 - [fmt 10.1.1](https://github.com/fmtlib/fmt/releases/tag/10.1.0) - [fmtlog v2.2.1](https://github.com/MengRao/fmtlog/releases/tag/v2.2.1) - [googletest v1.14.0](https://github.com/google/googletest/releases/tag/v1.14.0) -- [backward-cpp v1.6](https://github.com/bombela/backward-cpp/releases/tag/v1.6) +- [backward-cpp master](https://github.com/bombela/backward-cpp) - [result master](https://github.com/willowell/result) - [abseil-cpp 20230802.1](https://github.com/abseil/abseil-cpp/releases/tag/20230802.1) diff --git a/src/00common/CMakeLists.txt b/src/00common/CMakeLists.txt index ec315d965..8cdca4c9f 100644 --- a/src/00common/CMakeLists.txt +++ b/src/00common/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE COMMON_TEST test/*.cpp) if(COMMON_TEST) add_executable(common_test ${COMMON_TEST}) add_test(common_test common_test) - target_link_libraries(common_test common GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(common_test) + target_link_libraries(common_test common GTest::gtest_main Backward::Object) endif() diff --git a/src/00common/include/common/rc.hpp b/src/00common/include/common/rc.hpp index f696c3329..00faafcbe 100644 --- a/src/00common/include/common/rc.hpp +++ b/src/00common/include/common/rc.hpp @@ -2,6 +2,7 @@ #define RC_HPP #include +#include namespace refactor { @@ -18,7 +19,7 @@ namespace refactor { T *_value; struct Counter { size_t strong, weak; - } * _counter; + } *_counter; Rc(T *ptr, Counter *counter) noexcept : _value(ptr), _counter(counter) { inc(); } diff --git a/src/01graph_topo/CMakeLists.txt b/src/01graph_topo/CMakeLists.txt index 34f1973c8..99316ab33 100644 --- a/src/01graph_topo/CMakeLists.txt +++ b/src/01graph_topo/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE GRAPH_TOPO_TEST test/*.cpp) if(GRAPH_TOPO_TEST) add_executable(graph_topo_test ${GRAPH_TOPO_TEST}) add_test(graph_topo_test graph_topo_test) - target_link_libraries(graph_topo_test graph_topo GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(graph_topo_test) + target_link_libraries(graph_topo_test graph_topo GTest::gtest_main Backward::Object) endif() diff --git a/src/02hardware/CMakeLists.txt b/src/02hardware/CMakeLists.txt index b42ef6327..7f2c53ef6 100644 --- 
a/src/02hardware/CMakeLists.txt +++ b/src/02hardware/CMakeLists.txt @@ -2,21 +2,18 @@ cmake_minimum_required(VERSION 3.12 FATAL_ERROR) project(hardware VERSION 0.0.0 LANGUAGES CXX) message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION}) -# Source files file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp) +add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC}) +target_link_libraries(hardware PUBLIC common) +target_include_directories(hardware PUBLIC include) if(USE_CUDA) - file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu) + target_include_directories(hardware PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) endif() -add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC} ${HARDWARE_BANG_SRC}) -target_link_libraries(hardware PUBLIC common) -target_include_directories(hardware PUBLIC include) - file(GLOB_RECURSE HARDWARE_TEST test/*.cpp) if(HARDWARE_TEST) add_executable(hardware_test ${HARDWARE_TEST}) add_test(hardware_test hardware_test) - target_link_libraries(hardware_test hardware GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(hardware_test) + target_link_libraries(hardware_test hardware GTest::gtest_main Backward::Object) endif() diff --git a/src/02hardware/include/hardware/device.h b/src/02hardware/include/hardware/device.h index cb65a2730..91e5c4509 100644 --- a/src/02hardware/include/hardware/device.h +++ b/src/02hardware/include/hardware/device.h @@ -52,7 +52,7 @@ namespace refactor::hardware { virtual ~Device() = default; virtual Type type() const noexcept = 0; - virtual void setContext() const noexcept; + virtual void setContext() const; Arc malloc(size_t); Arc absorb(Arc &&); diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h index 1facba2c3..d19dd3152 100644 --- a/src/02hardware/include/hardware/devices/nvidia.h +++ b/src/02hardware/include/hardware/devices/nvidia.h @@ -8,7 +8,7 @@ namespace refactor::hardware { class Nvidia final : public Device { public: explicit Nvidia(int32_t card); - void setContext() const noexcept final; + void setContext() const final; Type type() const noexcept final { return Type::Nvidia; } diff --git a/src/02hardware/src/device.cc b/src/02hardware/src/device.cc index 08c094994..29ac122e0 100644 --- a/src/02hardware/src/device.cc +++ b/src/02hardware/src/device.cc @@ -56,7 +56,7 @@ namespace refactor::hardware { Device::Device(decltype(_card) card, decltype(_mem) mem) : _card(card), _mem(std::move(mem)) {} - void Device::setContext() const noexcept {} + void Device::setContext() const {} auto Device::malloc(size_t size) -> Arc { return Arc(new Blob(this, size)); } diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc index 0f0eb5f68..e298d378a 100644 --- a/src/02hardware/src/devices/nvidia/device.cc +++ b/src/02hardware/src/devices/nvidia/device.cc @@ -1,22 +1,39 @@ #include "functions.cuh" #include "hardware/devices/nvidia.h" #include "hardware/mem_pool.h" -#include "memory.cuh" + +#ifdef USE_CUDA +#include "memory.hh" +#include + +#define CUDA_ASSERT(STATUS) \ + if (auto status = (STATUS); status != cudaSuccess) { \ + RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ + cudaGetErrorString(status), (int) status)); \ + } +#endif namespace refactor::hardware { static Arc cudaMemory(int32_t card) { #ifdef USE_CUDA - ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card); - setDevice(card); - auto [free, total] = getMemInfo(); - auto size = 
std::min(free, std::max(5ul << 30, total * 4 / 5)); - fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}", - card, free, total, size); + int deviceCount; + CUDA_ASSERT(cudaGetDeviceCount(&deviceCount)); + ASSERT(0 <= card && card < deviceCount, "Invalid card id: {}", card); + CUDA_ASSERT(cudaSetDevice(card)); + + size_t free, total; + CUDA_ASSERT(cudaMemGetInfo(&free, &total)); + auto size = free * 9 / 10; + cudaDeviceProp prop; + CUDA_ASSERT(cudaGetDeviceProperties(&prop, 0)); + size_t alignment = prop.textureAlignment; + fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}, alignment {}", + card, free, total, size, alignment); return std::make_shared( std::make_shared(), size, - 256ul); + alignment); #else return nullptr; #endif @@ -24,8 +41,10 @@ namespace refactor::hardware { Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {} - void Nvidia::setContext() const noexcept { - setDevice(_card); + void Nvidia::setContext() const { +#ifdef USE_CUDA + CUDA_ASSERT(cudaSetDevice(_card)); +#endif } }// namespace refactor::hardware diff --git a/src/02hardware/src/devices/nvidia/functions.cu b/src/02hardware/src/devices/nvidia/functions.cu deleted file mode 100644 index 844ef388c..000000000 --- a/src/02hardware/src/devices/nvidia/functions.cu +++ /dev/null @@ -1,19 +0,0 @@ -#include "functions.cuh" - -namespace refactor::hardware { - - int getDeviceCount() { - int deviceCount; - CUDA_ASSERT(cudaGetDeviceCount(&deviceCount)); - return deviceCount; - } - void setDevice(int device) { - CUDA_ASSERT(cudaSetDevice(device)); - } - MemInfo getMemInfo() { - MemInfo memInfo; - CUDA_ASSERT(cudaMemGetInfo(&memInfo.free, &memInfo.total)); - return memInfo; - } - -}// namespace refactor::hardware diff --git a/src/02hardware/src/devices/nvidia/functions.cuh b/src/02hardware/src/devices/nvidia/functions.cuh deleted file mode 100644 index 0a47d4492..000000000 --- a/src/02hardware/src/devices/nvidia/functions.cuh +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH -#define HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH - -#include "common.h" - -#define CUDA_ASSERT(STATUS) \ - if (auto status = (STATUS); status != cudaSuccess) { \ - RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ - cudaGetErrorString(status), (int) status)); \ - } - -namespace refactor::hardware { - - struct MemInfo { - size_t free, total; - }; - - int getDeviceCount(); - void setDevice(int device); - MemInfo getMemInfo(); - -}// namespace refactor::hardware - -#endif// HARDWARE_DEVICES_NVIDIA_FUNCTIONS_CUH diff --git a/src/02hardware/src/devices/nvidia/memory.cu b/src/02hardware/src/devices/nvidia/memory.cc similarity index 69% rename from src/02hardware/src/devices/nvidia/memory.cu rename to src/02hardware/src/devices/nvidia/memory.cc index b3c5fe3d3..42310196c 100644 --- a/src/02hardware/src/devices/nvidia/memory.cu +++ b/src/02hardware/src/devices/nvidia/memory.cc @@ -1,5 +1,14 @@ -#include "functions.cuh" -#include "memory.cuh" +#ifdef USE_CUDA + +#include "memory.hh" +#include "common.h" +#include + +#define CUDA_ASSERT(STATUS) \ + if (auto status = (STATUS); status != cudaSuccess) { \ + RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ + cudaGetErrorString(status), (int) status)); \ + } namespace refactor::hardware { using M = NvidiaMemory; @@ -29,3 +38,5 @@ namespace refactor::hardware { } }// namespace refactor::hardware + +#endif diff --git a/src/02hardware/src/devices/nvidia/memory.cuh 
b/src/02hardware/src/devices/nvidia/memory.hh similarity index 100% rename from src/02hardware/src/devices/nvidia/memory.cuh rename to src/02hardware/src/devices/nvidia/memory.hh diff --git a/src/03runtime/CMakeLists.txt b/src/03runtime/CMakeLists.txt index 96256b255..9fd5a99f3 100644 --- a/src/03runtime/CMakeLists.txt +++ b/src/03runtime/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE RUNTIME_TEST test/*.cpp) if(RUNTIME_TEST) add_executable(runtime_test ${RUNTIME_TEST}) add_test(runtime_test runtime_test) - target_link_libraries(runtime_test runtime GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(runtime_test) + target_link_libraries(runtime_test runtime GTest::gtest_main Backward::Object) endif() diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h index 24d9fe73e..4f9e34f6c 100644 --- a/src/03runtime/include/runtime/stream.h +++ b/src/03runtime/include/runtime/stream.h @@ -42,9 +42,11 @@ namespace refactor::runtime { decltype(_device)); decltype(_graph) const &graph() const noexcept { return _graph; } - void setData(count_t, void const *, size_t); + auto setData(count_t, size_t) -> Arc; void setData(count_t, Arc); - bool getData(count_t, void *, size_t) const; + auto getData(count_t) const -> Arc; + void setData(count_t, void const *, size_t); + bool copyData(count_t, void *, size_t) const; void run(); auto bench(void (*sync)()) -> std::vector; void trace(std::function); diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc index 569769c1d..570563bc0 100644 --- a/src/03runtime/src/stream.cc +++ b/src/03runtime/src/stream.cc @@ -18,15 +18,21 @@ namespace refactor::runtime { std::move(edges), } {} + auto Stream::setData(count_t i, size_t size) -> Arc { + return _graph.edges[i].blob = _device->malloc(size); + } + void Stream::setData(count_t i, Arc blob) { + _graph.edges[i].blob = std::move(blob); + } void Stream::setData(count_t i, void const *data, size_t size) { auto blob = _device->malloc(size); blob->copyFromHost(data, size); _graph.edges[i].blob = std::move(blob); } - void Stream::setData(count_t i, Arc blob) { - _graph.edges[i].blob = std::move(blob); + auto Stream::getData(count_t i) const -> Arc { + return _graph.edges[i].blob; } - bool Stream::getData(count_t i, void *data, size_t size) const { + bool Stream::copyData(count_t i, void *data, size_t size) const { if (!_graph.edges[i].blob) { return false; } _graph.edges[i].blob->copyToHost(data, size); return true; diff --git a/src/04kernel/CMakeLists.txt b/src/04kernel/CMakeLists.txt index 1349193e5..3a401ac35 100644 --- a/src/04kernel/CMakeLists.txt +++ b/src/04kernel/CMakeLists.txt @@ -43,6 +43,5 @@ file(GLOB_RECURSE KERNEL_TEST test/*.cpp) if(KERNEL_TEST) add_executable(kernel_test ${KERNEL_TEST}) add_test(kernel_test kernel_test) - target_link_libraries(kernel_test kernel GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(kernel_test) + target_link_libraries(kernel_test kernel GTest::gtest_main Backward::Object) endif() diff --git a/src/04kernel/cuda/CMakeLists.txt b/src/04kernel/cuda/CMakeLists.txt index 3cbf8d38e..4c976e33d 100644 --- a/src/04kernel/cuda/CMakeLists.txt +++ b/src/04kernel/cuda/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE KERNEL_CUDA_TEST test/*.cu) if(KERNEL_CUDA_TEST) add_executable(kernel_cuda_test ${KERNEL_CUDA_TEST}) add_test(kernel_cuda_test kernel_cuda_test) - target_link_libraries(kernel_cuda_test kernel_cuda GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(kernel_cuda_test) + target_link_libraries(kernel_cuda_test kernel_cuda 
GTest::gtest_main Backward::Object) endif() diff --git a/src/04kernel/include/kernel/collectors/hard_sigmoid.h b/src/04kernel/include/kernel/collectors/hard_sigmoid.h new file mode 100644 index 000000000..2395b51bd --- /dev/null +++ b/src/04kernel/include/kernel/collectors/hard_sigmoid.h @@ -0,0 +1,20 @@ +#ifndef KERNEL_HARD_SIGMOIG_H +#define KERNEL_HARD_SIGMOIG_H + +#include "../collector.h" + +namespace refactor::kernel { + + struct HardSigmoidCollector final : public InfoCollector { + float alpha, beta; + + constexpr HardSigmoidCollector(decltype(_target) target, float alpha_, float beta_) noexcept + : InfoCollector(target), alpha(alpha_), beta(beta_) {} + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; +}// namespace refactor::kernel + +#endif// KERNEL_HARD_SIGMOIG_H + diff --git a/src/04kernel/include/kernel/collectors/rms_normalization.h b/src/04kernel/include/kernel/collectors/rms_normalization.h new file mode 100644 index 000000000..9d3de6e4d --- /dev/null +++ b/src/04kernel/include/kernel/collectors/rms_normalization.h @@ -0,0 +1,20 @@ +#ifndef KERNEL_RMS_NORMALIZATION_H +#define KERNEL_RMS_NORMALIZATION_H + +#include "../collector.h" + +namespace refactor::kernel { + + struct RmsNormalizationCollector final : public InfoCollector { + float epsilon; + + constexpr RmsNormalizationCollector(decltype(_target) target, float epsilon_) noexcept + : InfoCollector(target), epsilon(epsilon_) {} + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_RMS_NORMALIZATION_H diff --git a/src/04kernel/include/kernel/collectors/simple_binary.h b/src/04kernel/include/kernel/collectors/simple_binary.h index 87d1b5f6e..7423ee703 100644 --- a/src/04kernel/include/kernel/collectors/simple_binary.h +++ b/src/04kernel/include/kernel/collectors/simple_binary.h @@ -14,6 +14,8 @@ namespace refactor::kernel { And, Or, Xor, + Mod, + Fmod, }; std::string_view opName(SimpleBinaryType type); diff --git a/src/04kernel/include/kernel/collectors/simple_unary.h b/src/04kernel/include/kernel/collectors/simple_unary.h index ee190ee17..913c0095a 100644 --- a/src/04kernel/include/kernel/collectors/simple_unary.h +++ b/src/04kernel/include/kernel/collectors/simple_unary.h @@ -25,6 +25,7 @@ namespace refactor::kernel { Erf, Neg, Not, + HardSwish, }; std::string_view unaryName(SimpleUnaryType type); diff --git a/src/04kernel/src/attributes/slice_info.cc b/src/04kernel/src/attributes/slice_info.cc index a09146c61..a3397c827 100644 --- a/src/04kernel/src/attributes/slice_info.cc +++ b/src/04kernel/src/attributes/slice_info.cc @@ -39,41 +39,36 @@ namespace refactor::kernel { } dims_.resize(rank = shape.size()); } - dims.reserve(rank); + // 合并末尾的连续维度 + for (auto i : range0_(rank).rev()) { + if (auto d = shape[i]; dims_[i].length == d) { + blockSize *= d; + shape.pop_back(); + dims_.pop_back(); + } else { + dims.resize(rank = shape.size()); + if (auto &dim = dims_[i]; dim.step == 1) { + if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) { + blockSize *= times; + dim.start /= times; + dim.length /= times; + shape[i] /= times; + } + } + break; + } + } dim_t strideI = 1; for (auto i : range0_(rank).rev()) { auto const &dim = dims_[i]; - dims.push_back({ + dims[i] = { .strideO = blockCount, .skip = static_cast(strideI * dim.start), .strideI = static_cast(strideI * dim.step), - }); + }; blockCount *= dim.length; strideI *= shape[i]; } - std::reverse(dims.begin(), dims.end()); - - while 
(!dims.empty()) { - auto const &dim = dims.back(); - if (dim.strideI == static_cast(dim.strideO) && !dim.skip) { - dims.pop_back(); - } else { - long times = std::gcd(std::gcd(dim.strideI, dim.strideO), dim.skip); - blockCount /= times; - blockSize *= times; - if (!dims.empty()) { - for (auto &dim : dims) { - dim.strideO /= times; - dim.skip /= times; - dim.strideI /= times; - } - if (dims.back().strideO != 1) { - dims.push_back({1, 0, 1}); - } - } - break; - } - } } SliceInfo SliceInfo::reform(dim_t maxblockSize) const noexcept { @@ -97,5 +92,4 @@ namespace refactor::kernel { dims.back() = {1, 0, 1}; } - }// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/hard_sigmoid.cc b/src/04kernel/src/collectors/hard_sigmoid.cc new file mode 100644 index 000000000..69d2f9d1e --- /dev/null +++ b/src/04kernel/src/collectors/hard_sigmoid.cc @@ -0,0 +1,29 @@ +#include "kernel/collectors/hard_sigmoid.h" +#include "../kernels/hard_sigmoid/cpu_kernel.hh" +#include "../kernels/hard_sigmoid/cuda_kernel.hh" + +namespace refactor::kernel { + + std::vector + HardSigmoidCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + auto const &a = inputs[0]; + + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + if (auto ptr = HardSigmoidCpu::build(alpha, beta, a); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + case decltype(_target)::Nvidia: + if (auto ptr = HardSigmoidCuda::build(alpha, beta, a); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/rms_normalization.cc b/src/04kernel/src/collectors/rms_normalization.cc new file mode 100644 index 000000000..f8251c542 --- /dev/null +++ b/src/04kernel/src/collectors/rms_normalization.cc @@ -0,0 +1,29 @@ +#include "kernel/collectors/rms_normalization.h" +#include "../kernels/rms_normalization/cpu_kernel.hh" +#include "../kernels/rms_normalization/cuda_kernel.hh" + +namespace refactor::kernel { + +#define REGISTER(T) \ + if (auto ptr = T::build(epsilon, inputs[0]); ptr) { \ + ans.emplace_back(std::move(ptr)); \ + } + + std::vector + RmsNormalizationCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + REGISTER(RmsNormalizationCpu) + break; + case decltype(_target)::Nvidia: + REGISTER(RmsNormalizationCuda) + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/simple_binary.cc b/src/04kernel/src/collectors/simple_binary.cc index e2c001ff7..53ae6723c 100644 --- a/src/04kernel/src/collectors/simple_binary.cc +++ b/src/04kernel/src/collectors/simple_binary.cc @@ -19,6 +19,8 @@ namespace refactor::kernel { CASE(And); CASE(Or); CASE(Xor); + CASE(Mod); + CASE(Fmod); default: UNREACHABLE(); } diff --git a/src/04kernel/src/collectors/simple_unary.cc b/src/04kernel/src/collectors/simple_unary.cc index de9e0bb07..51a334c91 100644 --- a/src/04kernel/src/collectors/simple_unary.cc +++ b/src/04kernel/src/collectors/simple_unary.cc @@ -31,6 +31,7 @@ namespace refactor::kernel { CASE(Erf); CASE(Neg); CASE(Not); + CASE(HardSwish); default: UNREACHABLE(); } diff --git a/src/04kernel/src/generator/nvrtc_repo.cc b/src/04kernel/src/generator/nvrtc_repo.cc index 1767a869a..ed2e5c0b5 100644 --- a/src/04kernel/src/generator/nvrtc_repo.cc +++ b/src/04kernel/src/generator/nvrtc_repo.cc @@ -2,6 +2,7 @@ 
#include "nvrtc_repo.h" #include "hardware/device_manager.h" +#include #include #define NVRTC_ASSERT(CALL) \ @@ -38,9 +39,24 @@ namespace refactor::kernel::nvrtc { NVRTC_ASSERT(nvrtcCreateProgram(&prog, code.data(), name.data(), 0, nullptr, nullptr)); std::vector opts{"--std=c++17", "--gpu-architecture=compute_80"}; + { + auto proj = std::filesystem::path(__FILE__) + .parent_path() + .parent_path() + .parent_path() + .parent_path() + .parent_path(); + auto cccl = proj / "3rd-party/cccl"; + auto cudacxx = cccl / "libcudacxx/include"; + auto cub = cccl / "cub"; + ASSERT(std::filesystem::is_directory(cub), "cub not exist"); + opts.emplace_back(fmt::format("-I{}", cudacxx.c_str())); + opts.emplace_back(fmt::format("-I{}", cub.c_str())); + } #ifdef CUDA_INCLUDE_PATH opts.emplace_back(fmt::format("-I{}", CUDA_INCLUDE_PATH)); #endif + std::vector optsPtr(opts.size()); std::transform(opts.begin(), opts.end(), optsPtr.begin(), [](auto &s) { return s.c_str(); }); diff --git a/src/04kernel/src/graph.cc b/src/04kernel/src/graph.cc index 973595c0c..1ab4ab83e 100644 --- a/src/04kernel/src/graph.cc +++ b/src/04kernel/src/graph.cc @@ -1,5 +1,22 @@ #include "kernel/graph.h" +namespace refactor { + struct DataKey { + Arc dev; + Arc blob; + bool operator==(const DataKey &) const = default;// since C++20 + }; +}// namespace refactor + +template<> +struct std::hash { + std::size_t operator()(refactor::DataKey const &s) const noexcept { + auto hd = std::hash()(s.dev), + hb = std::hash()(s.blob); + return hd ^ (hb << 1); + } +}; + namespace refactor::kernel { Graph::Graph(graph_topo::GraphTopo topology, @@ -31,13 +48,19 @@ namespace refactor::kernel { _internal.edges, 32); + static std::unordered_map> CACHE; + for (auto i : range0_(edges_.size())) { auto const &edge = _internal.edges[i]; edges_[i].name = edge.name; if (edge.data) { - auto blob = device->malloc(edge.size); - blob->copyFromHost(edge.data->get()); - edges_[i].blob = std::move(blob); + auto it = CACHE.find({device, edge.data}); + if (it == CACHE.end()) { + auto blob = device->malloc(edge.size); + blob->copyFromHost(edge.data->get()); + std::tie(it, std::ignore) = CACHE.emplace(DataKey{device, edge.data}, std::move(blob)); + } + edges_[i].blob = it->second; } else if (edges_[i].stackOffset == SIZE_MAX - 1) { edges_[i].blob = device->malloc(edge.size); } diff --git a/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.cc b/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.cc new file mode 100644 index 000000000..8d6835622 --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.cc @@ -0,0 +1,54 @@ +#include "cpu_kernel.hh" +#include + +namespace refactor::kernel { + using K = HardSigmoidCpu; + using DT = DataType; + + K::HardSigmoidCpu(float alpha_, float beta_, DT dataType_, size_t size_) noexcept + : Kernel(), alpha(alpha_), beta(beta_), dataType(dataType_), size(size_) {} + + auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { + if (!a.dataType.isCpuNumberic()) { + return nullptr; + } + return std::make_unique(alpha_, beta_, a.dataType, a.elementsSize()); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing HardSigmoid using CPU"; + } + + template + static Routine lowerTyped(float alpha_, float beta_, size_t size) { + using namespace runtime; + + return [=](Resources &, void *workspace, void const *const 
*inputs, void *const *outputs) { + auto x = reinterpret_cast(inputs[0]); + auto y = reinterpret_cast(outputs[0]); + std::for_each_n(std::execution::par_unseq, + natural_t(0), size, + [&](auto i) { + y[i] = std::clamp(alpha_ * x[i] + beta_, static_cast(0), static_cast(1)); + }); + }; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + switch (dataType) { + case DT::F32: + return lowerTyped(alpha, beta, size); + case DT::F64: + return lowerTyped(alpha, beta, size); + default: + UNREACHABLE(); + } + } +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh b/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh new file mode 100644 index 000000000..ef65a4534 --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_HARD_SIGMOID_CPU_KERNEL_HH +#define KERNEL_HARD_SIGMOID_CPU_KERNEL_HH + +#include "kernel/collectors/hard_sigmoid.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct HardSigmoidCpu final : public Kernel { + float alpha, beta; + DataType dataType; + size_t size; + + explicit HardSigmoidCpu(float, float, DataType, size_t) noexcept; + + static KernelBox build(float, float, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_HARD_SIGMOID_CPU_KERNEL_HH + diff --git a/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.cc b/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.cc new file mode 100644 index 000000000..338e3a867 --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.cc @@ -0,0 +1,88 @@ +#include "cuda_kernel.hh" + +#ifdef USE_CUDA +#include "../../generator/nvrtc_repo.h" +#include "kernel/cuda/threads_distributer.cuh" +#include +#endif + +namespace refactor::kernel { + using K = HardSigmoidCuda; + using DT = DataType; + + K::HardSigmoidCuda(float alpha_, float beta_, DT dt_, size_t size_) noexcept + : Kernel(), alpha(alpha_), beta(beta_), dataType(dt_), size(size_) {} + + auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + return std::make_unique(alpha_, beta_, a.dataType, a.elementsSize()); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing hardsigmoid operation on Nvidia GPU"; + } + +#ifdef USE_CUDA + constexpr static const char *TEMPLATE = R"~( +__device__ __forceinline__ static {0:} fn({0:} x) {{ + return {1:}; +}} + +extern "C" __global__ void kernel( + {0:} *__restrict__ y, + {0:} const *__restrict__ x, + size_t n +) {{ + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step) + y[tid] = fn(x[tid]); +}} + )~"; + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace runtime; + + std::string op = ""; + switch (dataType) { + case DT::F32: + op = fmt::format("fmaxf(0.f, fminf(1.f, fmaf({}, x, {})))", alpha, beta); + break; + case DT::F64: + op = fmt::format("fmax(0.0, fmin(1.0, fma({}, x, {})))", + static_cast(alpha), static_cast(beta)); + break; + case DT::FP16: + op = fmt::format("__hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, 
(__float2half({}) * x + __float2half({}))))", + alpha, beta); + break; + default: + UNREACHABLE(); + } + auto name = fmt::format("hardsigmoid_{}_{}_{}", dataType.name(), alpha, beta); + auto code = fmt::format(TEMPLATE, nvrtc::dataType(dataType), op); + return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), + params = cuda::ThreadsDistributer()(size)]( + Resources &, void *, void const *const *inputs, void *const *outputs) { + auto y = outputs[0]; + auto x = inputs[0]; + auto n = params.n; + void *args[]{&y, &x, &n}; + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); + }; + } +#endif + +}// namespace refactor::kernel + diff --git a/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh b/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh new file mode 100644 index 000000000..99d5439e2 --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh @@ -0,0 +1,28 @@ +#ifndef KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH +#define KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH + +#include "kernel/collectors/hard_sigmoid.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct HardSigmoidCuda final : public Kernel { + float alpha, beta; + DataType dataType; + size_t size; + + explicit HardSigmoidCuda(float, float, DataType, size_t) noexcept; + + static KernelBox build(float, float, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH diff --git a/src/04kernel/src/kernels/rms_normalization/cpu_kernel.cc b/src/04kernel/src/kernels/rms_normalization/cpu_kernel.cc new file mode 100644 index 000000000..4d1a060aa --- /dev/null +++ b/src/04kernel/src/kernels/rms_normalization/cpu_kernel.cc @@ -0,0 +1,75 @@ +#include "cpu_kernel.hh" +#include +#include + +namespace refactor::kernel { + using K = RmsNormalizationCpu; + + K::RmsNormalizationCpu( + decltype(epsilon) epsilon_, + decltype(dataType) dataType_, + decltype(blockCount) blockCount_, + decltype(blockSize) blockSize_) noexcept + : Kernel(), + epsilon(epsilon_), + dataType(dataType_), + blockCount(blockCount_), + blockSize(blockSize_) {} + + auto K::build(float epsilon, Tensor const &x) noexcept -> KernelBox { + if (x.dataType != DataType::F32 && x.dataType != DataType::F64) { + return nullptr; + } + auto it = x.shape.rbegin(); + dim_t blockSize = *it++; + dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies()); + return std::make_unique(epsilon, x.dataType, blockCount, blockSize); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing rms normalization on generic cpu"; + } + + template + static Routine lowerTyped(float epsilon, dim_t blockCount, dim_t blockSize) { + using namespace runtime; + + return [epsilon, blockCount, blockSize]// + (Resources &, void *, void const *const *inputs, void *const *outputs) { + auto x = reinterpret_cast(inputs[0]); + auto w = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + std::for_each_n( + std::execution::par_unseq, + natural_t(0), + blockCount, + [blockSize, epsilon, x, w, y](auto i) { + auto x_ = x + i * blockSize; + auto y_ = y + i * blockSize; + + 
auto ss = std::accumulate( + x_, x_ + blockSize, static_cast(0), + [](auto acc, auto it) { return acc + it * it; }); + ss /= blockSize; + ss += epsilon; + ss = 1. / std::sqrt(ss); + + for (auto j : range0_(blockSize)) { + y_[j] = x_[j] * ss * w[j]; + } + }); + }; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + return dataType == DataType::F32 + ? lowerTyped(epsilon, blockCount, blockSize) + : lowerTyped(epsilon, blockCount, blockSize); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/rms_normalization/cpu_kernel.hh b/src/04kernel/src/kernels/rms_normalization/cpu_kernel.hh new file mode 100644 index 000000000..7a3e74a77 --- /dev/null +++ b/src/04kernel/src/kernels/rms_normalization/cpu_kernel.hh @@ -0,0 +1,30 @@ +#ifndef KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH +#define KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct RmsNormalizationCpu final : public Kernel { + float epsilon; + DataType dataType; + dim_t blockCount, blockSize; + + RmsNormalizationCpu( + decltype(epsilon), + decltype(dataType), + decltype(blockCount), + decltype(blockSize)) noexcept; + + static KernelBox build(float, Tensor const &x) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH diff --git a/src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc b/src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc new file mode 100644 index 000000000..b051f42cc --- /dev/null +++ b/src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc @@ -0,0 +1,122 @@ +#include "cuda_kernel.hh" +#include + +#ifdef USE_CUDA +#include "../../generator/nvrtc_repo.h" +#include +#include +#endif + +namespace refactor::kernel { + using K = RmsNormalizationCuda; + + K::RmsNormalizationCuda( + decltype(epsilon) epsilon_, + decltype(dataType) dataType_, + decltype(blockCount) blockCount_, + decltype(blockSize) blockSize_) noexcept + : Kernel(), + epsilon(epsilon_), + dataType(dataType_), + blockCount(blockCount_), + blockSize(blockSize_) {} + + auto K::build(float epsilon, Tensor const &x) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + + if (!x.dataType.isFloat()) { + return nullptr; + } + auto it = x.shape.rbegin(); + dim_t blockSize = *it++; + dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies()); + return std::make_unique(epsilon, x.dataType, blockCount, blockSize); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing rms normalization using CUDA"; + } + +#ifdef USE_CUDA + + // 0: data type + // 1: block size + // 2: T -> float + // 3: T <- float + constexpr static const char *TEMPLATE = R"~( +#include + +extern "C" __global__ void kernel( + {0:} *__restrict__ y, + {0:} const *__restrict__ x, + {0:} const *__restrict__ w, + float epsilon) {{ + + x += blockIdx.x * blockDim.x + threadIdx.x; + y += blockIdx.x * blockDim.x + threadIdx.x; + w += threadIdx.x; + + using BlockReduce = cub::BlockReduce<{0:}, {1:}>; + __shared__ typename BlockReduce::TempStorage tempStorage; + __shared__ {0:} rms; + auto acc = 
BlockReduce(tempStorage).Reduce(*x * *x, cub::Sum()); + if (threadIdx.x == 0) {{ + rms = {3:}(rsqrt({2:}(acc) / blockDim.x + epsilon)); + }} + __syncthreads(); + + *y = *x * rms * *w; +}} +)~"; + + auto K::lower(Resources &) const -> RoutineWorkspace { + using namespace runtime; + + std::stringstream ss; + ss << "RmsNorm" << nvrtc::dataType(dataType) << blockSize; + ss << ".cu"; + auto name = ss.str(); + auto code = fmt::format( + TEMPLATE, + nvrtc::dataType(dataType),// 0 + blockSize, // 1 + // clang-format off + dataType == DataType::F32 ? "" + : dataType == DataType::F64 ? "static_cast" + : dataType == DataType::FP16 ? "__half2float" + : dataType == DataType::BF16 ? "__bfloat162float" + : UNREACHABLEX(const char*, "unreachable"), + dataType == DataType::F32 ? "" + : dataType == DataType::F64 ? "" + : dataType == DataType::FP16 ? "__float2half" + : dataType == DataType::BF16 ? "__float2bfloat16" + : UNREACHABLEX(const char*, "unreachable") + // clang-format on + ); + + return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), + epsilon_ = this->epsilon, + blockCount = this->blockCount, + blockSize = this->blockSize]// + (Resources &, void *, void const *const *inputs, void *const *outputs) { + auto y = outputs[0]; + auto x = inputs[0]; + auto w = inputs[1]; + auto epsilon = epsilon_; + void *args[]{&y, &x, &w, &epsilon}; + h->launch(blockCount, 1, 1, + blockSize, 1, 1, + 0, args); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/rms_normalization/cuda_kernel.hh b/src/04kernel/src/kernels/rms_normalization/cuda_kernel.hh new file mode 100644 index 000000000..1cd266ebf --- /dev/null +++ b/src/04kernel/src/kernels/rms_normalization/cuda_kernel.hh @@ -0,0 +1,31 @@ +#ifndef KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH +#define KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + struct RmsNormalizationCuda final : public Kernel { + float epsilon; + DataType dataType; + dim_t blockCount, blockSize; + + RmsNormalizationCuda( + decltype(epsilon), + decltype(dataType), + decltype(blockCount), + decltype(blockSize)) noexcept; + + static KernelBox build(float, Tensor const &x) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH diff --git a/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc index 6a737f3ae..fad5478cb 100644 --- a/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc +++ b/src/04kernel/src/kernels/simple_binary/cpu_kernel.cc @@ -1,4 +1,5 @@ #include "cpu_kernel.hh" +#include #include namespace refactor::kernel { @@ -118,8 +119,38 @@ namespace refactor::kernel { UNREACHABLE(); } } - default: - UNREACHABLE(); + case Op::Mod: { + switch (dataType.internal) { + CASE_DT(a % b, U8); + CASE_DT(a % b, I8); + CASE_DT(a % b, U16); + CASE_DT(a % b, I16); + CASE_DT(a % b, I32); + CASE_DT(a % b, I64); + CASE_DT(a % b, U32); + CASE_DT(a % b, U64); + default: + UNREACHABLE(); + } + } + case Op::Fmod: { + switch (dataType.internal) { + CASE_DT(std::fmod(a, b), F32); + CASE_DT(a % b, U8); + CASE_DT(static_cast(std::fmod(a, b)), I8); + CASE_DT(a % b, U16); + CASE_DT(static_cast(std::fmod(a, b)), I16); + CASE_DT(static_cast(std::fmod(a, b)), I32); + 
CASE_DT(static_cast(std::fmod(a, b)), I64); + CASE_DT(std::fmod(a, b), F64); + CASE_DT(a % b, U32); + CASE_DT(a % b, U64); + default: + UNREACHABLE(); + } + default: + UNREACHABLE(); + } } } diff --git a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc index 58d5f677e..25184617d 100644 --- a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc @@ -59,7 +59,28 @@ extern "C" __global__ void kernel( }} )~"; - constexpr static const char *SCALAR = R"~( + constexpr static const char *SCALAR_A = R"~( +__device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{ + return {1:}; +}} + +extern "C" __global__ void kernel( + {0:} *__restrict__ y, + {0:} const *__restrict__ s, + {0:} const *__restrict__ v, + size_t n +) {{ + auto num = *s; + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step) {{ + y[tid] = fn(num, v[tid]); + }} +}} +)~"; + + constexpr static const char *SCALAR_B = R"~( __device__ __forceinline__ static {0:} fn({0:} a, {0:} b) {{ return {1:}; }} @@ -135,12 +156,50 @@ extern "C" __global__ void kernel( case DataType::F32: return "powf(a, b)"; case DataType::FP16: - return "__float2half(__powf(__half2float(a), __half2float(b)))"; + return "__float2half(powf(__half2float(a), __half2float(b)))"; case DataType::BF16: return "__float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b)))"; default: return "pow(a, b)"; } + case SimpleBinaryType::Mod: + switch (dt) { + case DataType::U8: + case DataType::I8: + case DataType::U16: + case DataType::I16: + case DataType::I32: + case DataType::I64: + case DataType::U32: + case DataType::U64: + return "a % b"; + default: + UNREACHABLE(); + } + case SimpleBinaryType::Fmod: + switch (dt) { + case DataType::U8: + case DataType::U16: + case DataType::U32: + case DataType::U64: + return "a % b"; + case DataType::I8: + return "static_cast(fmodf(a, b))"; + case DataType::I16: + return "static_cast(fmodf(a, b))"; + case DataType::I32: + return "static_cast(fmodf(a, b))"; + case DataType::I64: + return "static_cast(fmodf(a, b))"; + case DataType::F32: + return "fmodf(a, b)"; + case DataType::FP16: + return "__float2half(fmodf(__half2float(a), __half2float(b)))"; + case DataType::BF16: + return "__float2bfloat16(fmodf(__bfloat162float(a), __bfloat162float(b)))"; + default: + UNREACHABLE(); + } default: UNREACHABLE(); } @@ -171,19 +230,17 @@ extern "C" __global__ void kernel( } else if (auto rank = broadcaster.strides.size() / (broadcaster.inputsCount + 1); rank == 1) { static const std::vector S0{0, 1, 1}, S1{1, 0, 1}; - auto name = fmt::format("binaryScalar{}", postfix); - auto code = fmt::format(SCALAR, dt_, op_); - return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"), - // clang-format off - scalar = broadcaster.strides == S0 ? 0 - : broadcaster.strides == S1 ? 1 - : UNREACHABLEX(int, "Unreachable")]// clang-format on + auto scalar_a = broadcaster.strides == S0; + auto name = fmt::format("binaryScalar{}{}", postfix, scalar_a ? "A" : "B"); + auto code = scalar_a ? 
fmt::format(SCALAR_A, dt_, op_) + : fmt::format(SCALAR_B, dt_, op_); + return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]// (Resources &, void *, void const *const *inputs, void *const *outputs) { auto c = outputs[0]; - auto s = inputs[scalar], - v = inputs[1 - scalar]; + auto a = inputs[0], + b = inputs[1]; auto n = params.n; - void *args[]{&c, &v, &s, &n}; + void *args[]{&c, &a, &b, &n}; h->launch(params.gridSize, 1, 1, params.blockSize, 1, 1, 0, args); diff --git a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc index d34528569..8b6a0d135 100644 --- a/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cpu_kernel.cc @@ -18,6 +18,8 @@ namespace refactor::kernel { Op::Sigmoid, Op::Tanh, Op::Neg, + Op::Erf, + Op::HardSwish, }; return supportedOp.contains(op) && a.dataType.isCpuNumberic() ? std::make_unique(op, a.dataType, a.elementsSize()) @@ -48,6 +50,12 @@ namespace refactor::kernel { using M = std::conditional_t; return static_cast(std::tanh(static_cast(x))); } + template auto hardswishFun(T x) noexcept -> T { + auto mid = x / 6.f + .5f; + return (mid <= 0) ? 0 + : (1 <= mid) ? x + : x * mid; + } auto copyForUnsigned(size_t n) noexcept -> Routine { return [n](runtime::Resources &, void *workspace, void const *const *inputs, void *const *outputs) { std::memcpy(outputs[0], inputs[0], n); @@ -155,6 +163,28 @@ namespace refactor::kernel { default: UNREACHABLE(); } + case Op::Erf: + switch (dataType) { + CASE(std::erf, F32); + CASE(std::erf, F64); + CASE(std::erf, I8); + CASE(std::erf, I16); + CASE(std::erf, I32); + CASE(std::erf, I64); + CASE(std::erf, U8); + CASE(std::erf, U16); + CASE(std::erf, U32); + CASE(std::erf, U64); + default: + UNREACHABLE(); + } + case Op::HardSwish: + switch (dataType) { + CASE(hardswishFun, F32); + CASE(hardswishFun, F64); + default: + UNREACHABLE(); + } default: UNREACHABLE(); } diff --git a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc index e3c260dbc..e883374a4 100644 --- a/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cuda_kernel.cc @@ -18,7 +18,8 @@ namespace refactor::kernel { auto K::build(Op op, Tensor const &a) noexcept -> KernelBox { static const std::unordered_set supportedOp{Op::Abs, Op::Relu, Op::Sqrt, - Op::Sigmoid, Op::Tanh, Op::Neg}; + Op::Sigmoid, Op::Tanh, Op::Neg, + Op::Erf, Op::HardSwish}; #ifndef USE_CUDA return nullptr; #endif @@ -140,6 +141,24 @@ extern "C" __global__ void kernel( {__(Op::Neg, DT::BF16), "-x"}, {__(Op::Neg, DT::F32 ), "-x"}, {__(Op::Neg, DT::F64 ), "-x"}, + + {__(Op::Erf, DT::F32 ), "erff(x)"}, + {__(Op::Erf, DT::F64 ), "erf(x)"}, + {__(Op::Erf, DT::U8 ), "erff(static_cast(x))"}, + {__(Op::Erf, DT::I8 ), "erff(static_cast(x))"}, + {__(Op::Erf, DT::U16 ), "erff(static_cast(x))"}, + {__(Op::Erf, DT::I16 ), "erff(static_cast(x))"}, + {__(Op::Erf, DT::U32 ), "erf(static_cast(x))"}, + {__(Op::Erf, DT::I32 ), "erf(static_cast(x))"}, + {__(Op::Erf, DT::U64 ), "erf(static_cast(x))"}, + {__(Op::Erf, DT::I64 ), "erf(static_cast(x))"}, + {__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"}, + {__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"}, + + {__(Op::HardSwish, DT::F32 ), "x * fmaxf(0.f, fminf(1.f, fmaf(1.f/6.f, x, 0.5f)))"}, + {__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + 
hrcp(__float2half(2.f))))"}, + {__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"}, + }; // clang-format on diff --git a/src/04kernel/test/kernels/hard_sigmoid/test_cpu.cpp b/src/04kernel/test/kernels/hard_sigmoid/test_cpu.cpp new file mode 100644 index 000000000..65e480e73 --- /dev/null +++ b/src/04kernel/test/kernels/hard_sigmoid/test_cpu.cpp @@ -0,0 +1,31 @@ +#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh" +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, HardSigmoidCpu) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); + float alpha = 0.2f, beta = 0.5f; + auto kernel = HardSigmoidCpu::build(alpha, beta, *dataTensor); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + std::vector result(dataTensor->elementsSize()); + for (auto i : range0_(result.size())) { result[i] = i; } + // inference + { + void const *inputs[]{result.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + std::vector output = {0.5, 0.7, 0.9, 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1.}; + // check + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(output[i], result[i]); + } +} diff --git a/src/04kernel/test/kernels/hard_sigmoid/test_cuda.cpp b/src/04kernel/test/kernels/hard_sigmoid/test_cuda.cpp new file mode 100644 index 000000000..f83182536 --- /dev/null +++ b/src/04kernel/test/kernels/hard_sigmoid/test_cuda.cpp @@ -0,0 +1,49 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh" +#include "../../../src/kernels/hard_sigmoid/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, HardSigmoidCuda) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); + float alpha = 0.2f, beta = 0.5f; + auto kernel = HardSigmoidCuda::build(alpha, beta, *dataTensor); + auto kCpu = HardSigmoidCpu::build(alpha, beta, *dataTensor); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto gpuMem = dev.malloc(dataTensor->bytesSize()); + // put input data + std::vector data(dataTensor->elementsSize()); + for (auto i : range0_(data.size())) { data[i] = i; } + gpuMem->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + { + void const *inputs[]{*gpuMem}; + void *outputs[]{*gpuMem}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{data.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector result(dataTensor->elementsSize()); + gpuMem->copyToHost(result.data(), dataTensor->bytesSize()); + // check + for (auto i : range0_(data.size())) { + EXPECT_FLOAT_EQ(data[i], result[i]); + } +} + +#endif diff --git a/src/04kernel/test/kernels/rms_normalization/test_cpu.cpp b/src/04kernel/test/kernels/rms_normalization/test_cpu.cpp new file mode 100644 index 000000000..f72072317 --- /dev/null +++ b/src/04kernel/test/kernels/rms_normalization/test_cpu.cpp @@ -0,0 +1,40 @@ +#include "../../../src/kernels/rms_normalization/cpu_kernel.hh" +#include +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, RmsNormalizationCpu) 
{ + // build routine + auto y = Tensor::share(DataType::F32, Shape{2, 3, 4}); + auto x = Tensor::share(DataType::F32, Shape{2, 3, 4}); + auto w = Tensor::share(DataType::F32, Shape{4}); + auto kernel = RmsNormalizationCpu::build(0, *x); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + std::vector y_(y->elementsSize()); + std::vector x_(x->elementsSize()); + std::vector w_(w->elementsSize()); + std::iota(x_.begin(), x_.end(), 0); + std::iota(w_.begin(), w_.end(), 1); + // inference + { + void const *inputs[]{x_.data(), w_.data()}; + void *outputs[]{y_.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + for (auto i : range0_(2 * 3)) { + auto x__ = x_.data() + i * 4; + auto acc = std::accumulate(x__, x__ + 4, 0.f, [&](auto acc, auto it) { + return acc + it * it; + }); + auto rms = 1. / std::sqrt(acc / 4); + for (auto j : range0_(4)) { + EXPECT_FLOAT_EQ(y_[i * 4 + j], x_[i * 4 + j] * rms * w_[j]); + } + } +} diff --git a/src/04kernel/test/kernels/rms_normalization/test_cuda.cpp b/src/04kernel/test/kernels/rms_normalization/test_cuda.cpp new file mode 100644 index 000000000..23319e8e1 --- /dev/null +++ b/src/04kernel/test/kernels/rms_normalization/test_cuda.cpp @@ -0,0 +1,56 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/rms_normalization/cpu_kernel.hh" +#include "../../../src/kernels/rms_normalization/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, RmsNormalizationCuda) { + // build routine + auto y = Tensor::share(DataType::F32, Shape{2, 3, 4}); + auto x = Tensor::share(DataType::F32, Shape{2, 3, 4}); + auto w = Tensor::share(DataType::F32, Shape{4}); + auto kernel = RmsNormalizationCuda::build(0, *x), + kCpu = RmsNormalizationCpu::build(0, *x); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto yGpu = dev.malloc(y->bytesSize()), + xGpu = dev.malloc(x->bytesSize()), + wGpu = dev.malloc(w->bytesSize()); + // put input data + std::vector y_(y->elementsSize()); + std::vector x_(x->elementsSize()); + std::vector w_(w->elementsSize()); + std::iota(x_.begin(), x_.end(), 0); + std::iota(w_.begin(), w_.end(), 1); + xGpu->copyFromHost(x_.data(), x->bytesSize()); + wGpu->copyFromHost(w_.data(), w->bytesSize()); + // inference + { + void const *inputs[]{*xGpu, *wGpu}; + void *outputs[]{*yGpu}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{x_.data(), w_.data()}; + void *outputs[]{y_.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + std::vector result(y->elementsSize()); + yGpu->copyToHost(result.data(), y->bytesSize()); + for (auto i : range0_(y_.size())) { + EXPECT_FLOAT_EQ(result[i], y_[i]); + } +} + +#endif diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp index 0247a7f39..e2a840fc9 100644 --- a/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp +++ b/src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp @@ -1,4 +1,5 @@ #include "../src/kernels/simple_binary/cpu_kernel.hh" +#include #include using namespace refactor; @@ -27,11 +28,60 @@ void testBinaryCPU(SimpleBinaryType binaryOPT, std::function operation) { + // Create Tensor and build kernels + auto aTensor = 
Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor); + ASSERT_TRUE(cpuKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector a(aTensor->elementsSize(), -3); + std::vector b(bTensor->elementsSize(), 2); + std::vector c(cTensor->elementsSize()); + // Compute + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + // Compare + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], operation(a[i], b[i])); + } +} + +void testFmodWithI32CPU(SimpleBinaryType binaryOPT, std::function operation) { + // Create Tensor and build kernels + auto aTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto bTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cTensor = Tensor::share(DataType::I32, Shape{10, 20, 30, 40}, LayoutType::NCHW); + auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor); + ASSERT_TRUE(cpuKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector a(aTensor->elementsSize(), -3); + std::vector b(bTensor->elementsSize(), 2); + std::vector c(cTensor->elementsSize()); + // Compute + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + // Compare + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], operation(a[i], b[i])); + } +} + TEST(kernel, BinaryCpu) { testBinaryCPU(SimpleBinaryType::Add, [](float a, float b) { return a + b; }); testBinaryCPU(SimpleBinaryType::Sub, [](float a, float b) { return a - b; }); testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; }); testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; }); + testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; }); + testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return static_cast(std::fmod(a, b)); }); + testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); }); } TEST(kernel, BinaryCpuBroadcast) { diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp index 901af265b..b34b5a169 100644 --- a/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp +++ b/src/04kernel/test/kernels/simple_binary/test_binary_cuda.cpp @@ -9,12 +9,13 @@ using namespace refactor; using namespace kernel; using namespace hardware; +template void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) { // Create Tensor and build kernels - using T_ = primitive::type; - auto aTensor = Tensor::share(DataType::I8, dimA, LayoutType::NCHW); - auto bTensor = Tensor::share(DataType::I8, dimB, LayoutType::NCHW); - auto cTensor = Tensor::share(DataType::I8, dimC, LayoutType::NCHW); + using T_ = primitive::type; + auto aTensor = Tensor::share(T, dimA, LayoutType::NCHW); + auto bTensor = Tensor::share(T, dimB, LayoutType::NCHW); + auto cTensor = Tensor::share(T, dimC, LayoutType::NCHW); auto cpuKernel = BinaryCpu::build(binaryOPT, *aTensor, *bTensor), cudaKernel = BinaryCuda::build(binaryOPT, *aTensor, *bTensor); @@ -24,8 +25,8 
@@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di auto cudaRoutine = cudaKernel->lower(res).routine; // Init inputs and outputs - std::vector a(aTensor->elementsSize(), 3.0f); - std::vector b(bTensor->elementsSize(), 2.0f); + std::vector a(aTensor->elementsSize(), 3); + std::vector b(bTensor->elementsSize(), 2); std::vector c(cTensor->elementsSize()); auto &dev = *device::init(Device::Type::Nvidia, 0, ""); auto aGPU = dev.malloc(aTensor->bytesSize()), @@ -53,35 +54,63 @@ void testBinaryCuda(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape di } TEST(kernel, BinaryCudaAdd) { - testBinaryCuda(SimpleBinaryType::Add, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}); + testBinaryCuda(SimpleBinaryType::Add, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); } TEST(kernel, BinaryCudaMul) { - testBinaryCuda(SimpleBinaryType::Mul, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}); + testBinaryCuda(SimpleBinaryType::Mul, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); } TEST(kernel, BinaryCudaSub) { - testBinaryCuda(SimpleBinaryType::Sub, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}); + testBinaryCuda(SimpleBinaryType::Sub, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); } TEST(kernel, BinaryCudaDiv) { - testBinaryCuda(SimpleBinaryType::Div, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}, - Shape{2, 5, 10, 20, 3, 4}); + testBinaryCuda(SimpleBinaryType::Div, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); +} + +TEST(kernel, BinaryCudaMod) { + testBinaryCuda(SimpleBinaryType::Mod, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); +} + +TEST(kernel, BinaryCudaFmodI8) { + testBinaryCuda(SimpleBinaryType::Fmod, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); +} + +TEST(kernel, BinaryCudaFmodF32) { + testBinaryCuda(SimpleBinaryType::Fmod, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}, + Shape{2, 5, 10, 20, 3, 4}); } TEST(kernel, BinaryCudaBroadcast) { - testBinaryCuda(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6}); + testBinaryCuda(SimpleBinaryType::Sub, + Shape{1, 2, 3, 4, 5, 6}, + Shape{}, + Shape{1, 2, 3, 4, 5, 6}); + testBinaryCuda(SimpleBinaryType::Div, + Shape{}, + Shape{1, 2, 3, 4, 5, 6}, + Shape{1, 2, 3, 4, 5, 6}); } #endif diff --git a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp index da1cb6f83..47249281e 100644 --- a/src/04kernel/test/kernels/simple_unary/test_cpu.cpp +++ b/src/04kernel/test/kernels/simple_unary/test_cpu.cpp @@ -4,6 +4,8 @@ using namespace refactor; using namespace kernel; +using VecFloat = std::vector; + static void testOp(SimpleUnaryType opType, float check(float)) { // build routine auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50}); @@ -12,7 +14,7 @@ static void testOp(SimpleUnaryType opType, float check(float)) { auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; // put input data - std::vector data(dataTensor->elementsSize()); + VecFloat data(dataTensor->elementsSize()); for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; } auto result = data; // inference @@ -27,8 +29,34 @@ static void testOp(SimpleUnaryType opType, float 
check(float)) { } } +static void testOpWithData(SimpleUnaryType opType, const VecFloat &data) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3}); + auto kernel = SimpleUnaryCpu::build(opType, *dataTensor); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + VecFloat inputdata(dataTensor->elementsSize()); + for (auto i : range0_(inputdata.size())) { inputdata[i] = i; } + auto result = inputdata; + // inference + { + void const *inputs[]{result.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + for (auto i : range0_(inputdata.size())) { + EXPECT_NEAR(data[i], result[i], 1e-5); + } +} + TEST(kernel, SimpleUnaryCpu) { testOp(SimpleUnaryType::Abs, std::abs); testOp(SimpleUnaryType::Sqrt, std::sqrt); testOp(SimpleUnaryType::Tanh, std::tanh); + testOp(SimpleUnaryType::Erf, std::erf); + testOpWithData(SimpleUnaryType::HardSwish, + VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000}); } diff --git a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp index 6ff5d798b..72ebff72a 100644 --- a/src/04kernel/test/kernels/simple_unary/test_cuda.cpp +++ b/src/04kernel/test/kernels/simple_unary/test_cuda.cpp @@ -51,6 +51,8 @@ TEST(kernel, SimpleUnaryCuda) { testOp(SimpleUnaryType::Sqrt); testOp(SimpleUnaryType::Sigmoid); testOp(SimpleUnaryType::Tanh); + testOp(SimpleUnaryType::Erf); + testOp(SimpleUnaryType::HardSwish); } #endif diff --git a/src/04kernel/test/kernels/slice/test_cpu.cpp b/src/04kernel/test/kernels/slice/test_cpu.cpp index 6e12c5437..d574d0e22 100644 --- a/src/04kernel/test/kernels/slice/test_cpu.cpp +++ b/src/04kernel/test/kernels/slice/test_cpu.cpp @@ -59,7 +59,8 @@ TEST(kernel, SliceCpu) { }; auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}), output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3}); - auto kernel = SliceCpu::build(SliceInfo(dims, *input)); + auto info = SliceInfo(dims, *input); + auto kernel = SliceCpu::build(info); ASSERT_TRUE(kernel); auto res = runtime::Resources(); auto routine = kernel->lower(res).routine; @@ -94,7 +95,7 @@ TEST(kernel, SliceCpu) { } } // test reform - auto kernelReformed = SliceCpu::build(SliceInfo(dims, *input).reform(16)); + auto kernelReformed = SliceCpu::build(info.reform(16)); ASSERT_TRUE(kernelReformed); auto routineReformed = kernelReformed->lower(res).routine; std::vector resultReformed(result.size()); diff --git a/src/05computation/CMakeLists.txt b/src/05computation/CMakeLists.txt index c57859e24..2a42cc831 100644 --- a/src/05computation/CMakeLists.txt +++ b/src/05computation/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE COMPUTATION_TEST test/*.cpp) if(COMPUTATION_TEST) add_executable(computation_test ${COMPUTATION_TEST}) add_test(computation_test computation_test) - target_link_libraries(computation_test computation GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(computation_test) + target_link_libraries(computation_test computation GTest::gtest_main Backward::Object) endif() diff --git a/src/05computation/include/computation/operators/hard_sigmoid.h b/src/05computation/include/computation/operators/hard_sigmoid.h new file mode 100644 index 000000000..f0f3fbea5 --- /dev/null +++ b/src/05computation/include/computation/operators/hard_sigmoid.h @@ -0,0 +1,23 @@ +#ifndef COMPUTATION_HARD_SIGMOID_H +#define COMPUTATION_HARD_SIGMOID_H + +#include "../operator.h" + +namespace 
refactor::computation { + + struct HardSigmoid final : public Operator { + float alpha, beta; + + constexpr HardSigmoid(float alpha_, float beta_) noexcept + : Operator(), alpha(alpha_), beta(beta_){}; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_HARD_SIGMOID_H diff --git a/src/05computation/include/computation/operators/rms_normalization.h b/src/05computation/include/computation/operators/rms_normalization.h new file mode 100644 index 000000000..a6b153a6a --- /dev/null +++ b/src/05computation/include/computation/operators/rms_normalization.h @@ -0,0 +1,23 @@ +#ifndef COMPUTATION_RMS_NORMALIZATION_H +#define COMPUTATION_RMS_NORMALIZATION_H + +#include "../operator.h" + +namespace refactor::computation { + + struct RmsNormalization final : public Operator { + float epsilon; + + constexpr explicit RmsNormalization(float epsilon_) noexcept + : Operator(), epsilon(epsilon_) {} + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_RMS_NORMALIZATION_H diff --git a/src/05computation/src/operators/hard_sigmoid.cc b/src/05computation/src/operators/hard_sigmoid.cc new file mode 100644 index 000000000..74cb0e11d --- /dev/null +++ b/src/05computation/src/operators/hard_sigmoid.cc @@ -0,0 +1,23 @@ +#include "computation/operators/hard_sigmoid.h" +#include "kernel/collectors/hard_sigmoid.h" + +namespace refactor::computation { + using Op = HardSigmoid; + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "HardSigmoid"; } + + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector_ = kernel::HardSigmoidCollector; + return std::make_unique(target, alpha, beta); + } + auto Op::serialize() const noexcept -> std::string { + return fmt::format("{}()", name()); + } + +}// namespace refactor::computation + diff --git a/src/05computation/src/operators/rms_normalization.cc b/src/05computation/src/operators/rms_normalization.cc new file mode 100644 index 000000000..ff4659252 --- /dev/null +++ b/src/05computation/src/operators/rms_normalization.cc @@ -0,0 +1,27 @@ +#include "computation/operators/rms_normalization.h" +#include "kernel/collectors/rms_normalization.h" + +namespace refactor::computation { + using Op = RmsNormalization; + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "RmsNormalization"; } + auto Op::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::RmsNormalizationCollector; + return std::make_unique(target, epsilon); + } + auto Op::serialize() const noexcept -> std::string { + union code { + float f; + int32_t i; + }; + return fmt::format(("{}({:e}={:#010x})"), + name(), epsilon, + code{epsilon}.i); + } + +}// namespace refactor::computation diff 
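For reference, the element-wise math behind the two activations this patch wires up: `computation::HardSigmoid` above and the `HardSwish` unary added to `SimpleUnary` later in the diff. This is a minimal sketch assuming the standard ONNX definitions (the `onnx::HardSigmoid` builder further down defaults to `alpha = 0.2`, `beta = 0.5`, and HardSwish is `x * HardSigmoid(x; 1/6, 1/2)`); the kernel implementations themselves are not shown in this section. The loop reproduces the expected values `{0, 0.666667, 1.666667, 3, 4, 5}` used by the HardSwish CPU test above.

```cpp
#include <algorithm>
#include <cstdio>

// Reference math only; the real kernels live under src/04kernel.
float hardSigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
    return std::min(1.0f, std::max(0.0f, alpha * x + beta));
}
float hardSwish(float x) {// HardSwish(x) = x * HardSigmoid(x; 1/6, 1/2)
    return x * hardSigmoid(x, 1.0f / 6.0f, 0.5f);
}

int main() {
    for (int i = 0; i < 6; ++i) {// reproduces the expectations in the HardSwish CPU test
        std::printf("hardSwish(%d) = %f\n", i, hardSwish(static_cast<float>(i)));
    }
    return 0;
}
```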
--git a/src/05computation/src/operators/simple_binary.cc b/src/05computation/src/operators/simple_binary.cc index 31831e7c4..90f7ac028 100644 --- a/src/05computation/src/operators/simple_binary.cc +++ b/src/05computation/src/operators/simple_binary.cc @@ -39,6 +39,14 @@ namespace refactor::computation { static uint8_t ID = 8; return reinterpret_cast(&ID); } + case Ty::Mod: { + static uint8_t ID = 9; + return reinterpret_cast(&ID); + } + case Ty::Fmod: { + static uint8_t ID = 10; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -64,6 +72,10 @@ namespace refactor::computation { return "Or"; case Ty::Xor: return "Xor"; + case Ty::Mod: + return "Mod"; + case Ty::Fmod: + return "Fmod"; default: UNREACHABLE(); } diff --git a/src/05computation/src/operators/simple_unary.cc b/src/05computation/src/operators/simple_unary.cc index 23deaece8..d43aa5aca 100644 --- a/src/05computation/src/operators/simple_unary.cc +++ b/src/05computation/src/operators/simple_unary.cc @@ -81,6 +81,10 @@ namespace refactor::computation { static uint8_t ID = 19; return reinterpret_cast(&ID); } + case SimpleUnaryType::HardSwish: { + static uint8_t ID = 20; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -128,6 +132,8 @@ namespace refactor::computation { return "Neg"; case SimpleUnaryType::Not: return "Not"; + case SimpleUnaryType::HardSwish: + return "HardSwish"; default: UNREACHABLE(); } diff --git a/src/06frontend/CMakeLists.txt b/src/06frontend/CMakeLists.txt index 170183289..d771b4b94 100644 --- a/src/06frontend/CMakeLists.txt +++ b/src/06frontend/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE FRONTEND_TEST test/*.cpp) if(FRONTEND_TEST) add_executable(frontend_test ${FRONTEND_TEST}) add_test(frontend_test frontend_test) - target_link_libraries(frontend_test frontend GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(frontend_test) + target_link_libraries(frontend_test frontend GTest::gtest_main Backward::Object) endif() diff --git a/src/06frontend/include/frontend/operator.h b/src/06frontend/include/frontend/operator.h index 8af7dd115..965d87f36 100644 --- a/src/06frontend/include/frontend/operator.h +++ b/src/06frontend/include/frontend/operator.h @@ -49,7 +49,18 @@ namespace refactor::frontend { Tensor_ const &tensor() const; Tensors const &tensors() const; }; - using Attributes = std::unordered_map; + class Attributes { + std::unordered_map map; + + public: + void insert(std::string, Attribute); + bool empty() const; + Attribute &operator[](const char *); + Attribute const &operator[](const char *) const; + std::optional> get(const char *); + std::optional> get(const char *) const; + Attribute &getOrInsert(const char *, Attribute); + }; using ModelContext = std::unordered_map; class Operator; @@ -99,6 +110,46 @@ namespace refactor::frontend { } }; + using ShapeResult = Result; + using ShapeRefs = std::vector>; + + /// @brief 多方向形状广播。 + /// @param inputs 所有输入的形状。 + /// @return 广播后的形状。 + ShapeResult multidirBroadcast(ShapeRefs const &); + + /// @brief 单方向形状广播。 + /// @param target 目标形状。 + /// @param test 测试形状。 + /// @return 测试形状是否可以广播到目标形状。 + bool unidirBroadcast(Shape const &target, Shape const &test); + +#define EXPECT_NO_ATTRI \ + ASSERT(attributes.empty(), "{} operator should not have attributes", opType) + +#define EXPECT_SIZE(N) \ + if (inputs.size() != (N)) { \ + return Err(InferError(ERROR_MSG("Input size error"))); \ + } + +#define EXPECT_VAL(DIM, VAL) \ + int64_t VAL; \ + if ((DIM).hasValue()) { \ + VAL = (DIM).value(); \ + } else { \ + return 
Err(InferError(UnknownVariable{(DIM.variable()->name)})); \ + } + +#define MULTIDIR_BROADCAST(SHAPES) \ + Shape output; \ + { \ + auto res = multidirBroadcast(SHAPES); \ + if (res.isErr()) { \ + return Err(InferError(ERROR_MSG(res.unwrapErr()))); \ + } \ + output = std::move(res.unwrap()); \ + } + }// namespace refactor::frontend #endif// FRONTEND_OPERATOR_H diff --git a/src/06frontend/src/common.cpp b/src/06frontend/src/common.cpp new file mode 100644 index 000000000..1917e8259 --- /dev/null +++ b/src/06frontend/src/common.cpp @@ -0,0 +1,82 @@ +#include "frontend/operator.h" +#include + +namespace refactor::frontend { + + void Attributes::insert(std::string key, Attribute value) { + map.insert({std::move(key), std::move(value)}); + } + bool Attributes::empty() const { + return map.empty(); + } + auto Attributes::operator[](const char *key) -> Attribute & { + return map.at(key); + } + auto Attributes::operator[](const char *key) const -> Attribute const & { + return map.at(key); + } + auto Attributes::get(const char *key) -> std::optional> { + auto it = map.find(key); + return it != map.end() ? std::make_optional(std::ref(it->second)) : std::nullopt; + } + auto Attributes::get(const char *key) const -> std::optional> { + auto it = map.find(key); + return it != map.end() ? std::make_optional(std::cref(it->second)) : std::nullopt; + } + auto Attributes::getOrInsert(const char *key, Attribute otherwise) -> Attribute & { + auto [it, ok] = map.try_emplace(key, std::move(otherwise)); + return it->second; + } + + ShapeResult multidirBroadcast(ShapeRefs const &inputs) { + using Iter = std::reverse_iterator; + std::vector> iters; + iters.reserve(inputs.size()); + for (auto const &input : inputs) { + iters.emplace_back(input.get().rbegin(), input.get().rend()); + } + Shape ans; + while (true) { + std::optional dim = std::nullopt; + for (size_t i = 0; i < iters.size();) { + if (iters[i].first != iters[i].second) { + auto new_ = *iters[i].first++; + if (!dim || *dim == DimExpr(1)) { + dim = std::move(new_); + } else if (new_ != DimExpr(1) && new_ != *dim) { + loge("shape broadcast failed"); + for (auto input : inputs) { + loge("{}", shapeFormat(input.get())); + } + return Err(ERROR_MSG("Shape broadcast failed")); + } + ++i; + } else { + std::swap(iters[i], iters.back()); + iters.pop_back(); + } + } + if (dim) { + ans.emplace_back(std::move(*dim)); + } else { + break; + } + } + std ::reverse(ans.begin(), ans.end()); + return Ok(ans); + } + + bool unidirBroadcast(Shape const &target, Shape const &test) { + if (target.size() < test.size()) { + return false; + } else { + for (auto i = target.rbegin(), j = test.rbegin(); j != test.rend(); ++i, ++j) { + if (*j != *i && *j != DimExpr(1)) { + return false; + } + } + return true; + } + } + +}// namespace refactor::frontend diff --git a/src/06frontend/src/graph.cc b/src/06frontend/src/graph.cc index 713624ac5..083822e87 100644 --- a/src/06frontend/src/graph.cc +++ b/src/06frontend/src/graph.cc @@ -175,42 +175,30 @@ namespace refactor::frontend { std::vector nodes(_internal.nodes.size()); std::vector edges(_internal.edges.size()); + + auto fn = [&edges, this](auto i) { + if (edges[i].tensor) { + return; + } + auto const &[tensor, name] = _internal.edges[i]; + computation::Shape shape(tensor->shape.size()); + std::transform(std::execution::unseq, + tensor->shape.begin(), tensor->shape.end(), shape.begin(), + [](auto const &dim) { return dim.value(); }); + auto layout = shape.size() == 4 ? 
computation::LayoutType::NCHW : computation::LayoutType::Others; + edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data); + edges[i].name = name; + }; + std::transform(_internal.topology.begin(), _internal.topology.end(), nodes.begin(), - [&edges, this](auto const &nodeRef) { + [&fn, this](auto const &nodeRef) { auto const &[op, name] = _internal.nodes[nodeRef.idx]; + std::for_each(nodeRef.inputs.begin(), nodeRef.inputs.end(), fn); + std::for_each(nodeRef.outputs.begin(), nodeRef.outputs.end(), fn); auto constant = std::all_of(std::execution::unseq, nodeRef.outputs.begin(), nodeRef.outputs.end(), [this](auto i) { return _internal.edges[i].tensor->data; }); - if (constant) { - return computation::Node{nullptr, name}; - } - auto fn = [&edges, &nodeRef, this](auto i) { - if (edges[i].tensor) { - return; - } - auto const &[tensor, name] = _internal.edges[i]; - computation::Shape shape(tensor->shape.size()); - std::transform(std::execution::unseq, - tensor->shape.begin(), tensor->shape.end(), shape.begin(), - [](auto const &dim) { return dim.value(); }); - auto layout = shape.size() == 4 ? computation::LayoutType::NCHW : computation::LayoutType::Others; - edges[i].tensor = computation::Tensor::share(tensor->dataType, std::move(shape), layout, tensor->data); - edges[i].name = name; - }; - auto op_ = op->lower(TensorRefs(_internal.edges, nodeRef.inputs)); - auto valueDependentInputs = op->valueDependentInputs(); - auto it = valueDependentInputs.begin(); - for (auto i : range0_(nodeRef.inputs.size())) { - auto input = nodeRef.inputs[i]; - if (it != valueDependentInputs.end() && i == *it) { - edges[input].name = _internal.edges[input].name; - ++it; - continue; - } - fn(input); - } - std::for_each(std::execution::unseq, nodeRef.outputs.begin(), nodeRef.outputs.end(), fn); - return computation::Node{std::move(op_), name}; + return computation::Node{constant ? 
nullptr : op->lower(TensorRefs(_internal.edges, nodeRef.inputs)), name}; }); auto const endTime = high_resolution_clock::now(); diff --git a/src/07onnx/CMakeLists.txt b/src/07onnx/CMakeLists.txt index fd4570b4d..547bdf645 100644 --- a/src/07onnx/CMakeLists.txt +++ b/src/07onnx/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE ONNX_TEST test/*.cpp) if(ONNX_TEST) add_executable(onnx_test ${ONNX_TEST}) add_test(onnx_test onnx_test) - target_link_libraries(onnx_test onnx GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(onnx_test) + target_link_libraries(onnx_test onnx GTest::gtest_main Backward::Object) endif() diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index 6e657f272..a565a8d3e 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -17,6 +17,7 @@ #include "operators/gather_elements.hh" #include "operators/gemm.hh" #include "operators/global_pool.hh" +#include "operators/hard_sigmoid.hh" #include "operators/mat_mul.hh" #include "operators/mat_mul_integer.hh" #include "operators/pool.hh" @@ -95,6 +96,7 @@ namespace refactor::onnx { REGISTER(And , SimpleBinary ); REGISTER(Or , SimpleBinary ); REGISTER(Xor , SimpleBinary ); + REGISTER(Mod , SimpleBinary ); REGISTER(Abs , SimpleUnary ); REGISTER(Acos , SimpleUnary ); REGISTER(Acosh , SimpleUnary ); @@ -116,6 +118,7 @@ namespace refactor::onnx { REGISTER(Not , SimpleUnary ); REGISTER(Neg , SimpleUnary ); REGISTER(Identity , SimpleUnary ); + REGISTER(HardSwish , SimpleUnary ); REGISTER(Slice , Slice ); REGISTER(Softmax , Softmax ); REGISTER(Split , Split ); @@ -124,6 +127,7 @@ namespace refactor::onnx { REGISTER(Transpose , Transpose ); REGISTER(Unsqueeze , Unsqueeze ); REGISTER(Where , Where ); + REGISTER(HardSigmoid , HardSigmoid ); #undef REGISTER // clang-format on } diff --git a/src/07onnx/src/operators/batch_normalization.cc b/src/07onnx/src/operators/batch_normalization.cc index afea890bc..44a280c42 100644 --- a/src/07onnx/src/operators/batch_normalization.cc +++ b/src/07onnx/src/operators/batch_normalization.cc @@ -12,8 +12,8 @@ namespace refactor::onnx { epsilon(epsilon_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto trainingMode = defaultOr(attributes, "training_mode", {0}).int_() != 0; - auto epsilon = defaultOr(attributes, "epsilon", {1e-5f}).float_(); + auto trainingMode = attributes.getOrInsert( "training_mode", {0}).int_() != 0; + auto epsilon = attributes.getOrInsert( "epsilon", {1e-5f}).float_(); return OpBox(std::make_unique(trainingMode, epsilon)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/cast.cc b/src/07onnx/src/operators/cast.cc index 6b63031ea..ac96c04db 100644 --- a/src/07onnx/src/operators/cast.cc +++ b/src/07onnx/src/operators/cast.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), to(to_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto to = *DataType::parse(attributes.at("to").int_()); + auto to = *DataType::parse(attributes["to"].int_()); return OpBox(std::make_unique(to)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/common.cpp b/src/07onnx/src/operators/common.cpp index 119e48a0b..5d2c57d10 100644 --- a/src/07onnx/src/operators/common.cpp +++ b/src/07onnx/src/operators/common.cpp @@ -1,61 +1,8 @@ #include "common.h" -#include #include -#include namespace refactor::onnx { - ShapeResult multidirBroadcast(ShapeRefs const &inputs) { - using Iter = std::reverse_iterator; - std::vector> iters; - 
iters.reserve(inputs.size()); - for (auto const &input : inputs) { - iters.emplace_back(input.get().rbegin(), input.get().rend()); - } - Shape ans; - while (true) { - std::optional dim = std::nullopt; - for (size_t i = 0; i < iters.size();) { - if (iters[i].first != iters[i].second) { - auto new_ = *iters[i].first++; - if (!dim || *dim == DimExpr(1)) { - dim = std::move(new_); - } else if (new_ != DimExpr(1) && new_ != *dim) { - loge("shape broadcast failed"); - for (auto input : inputs) { - loge("{}", shapeFormat(input.get())); - } - return Err(ERROR_MSG("Shape broadcast failed")); - } - ++i; - } else { - std::swap(iters[i], iters.back()); - iters.pop_back(); - } - } - if (dim) { - ans.emplace_back(std::move(*dim)); - } else { - break; - } - } - std ::reverse(ans.begin(), ans.end()); - return Ok(ans); - } - - bool unidirBroadcast(Shape const &target, Shape const &test) { - if (target.size() < test.size()) { - return false; - } else { - for (auto i = target.rbegin(), j = test.rbegin(); j != test.rend(); ++i, ++j) { - if (*j != *i && *j != DimExpr(1)) { - return false; - } - } - return true; - } - } - ShapeResult pool(SmallInts<4> const &input, Ints const &kernel, OptionalIntsRef const &dilations, @@ -100,9 +47,4 @@ namespace refactor::onnx { return Ok(std::move(ans)); } - Attribute defaultOr(Attributes &attrs, std::string const &name, Attribute defaultValue) { - auto iter = attrs.find(name); - return iter == attrs.end() ? defaultValue : std::move(iter->second); - } - }// namespace refactor::onnx diff --git a/src/07onnx/src/operators/common.h b/src/07onnx/src/operators/common.h index e962499a6..ca24fdbcb 100644 --- a/src/07onnx/src/operators/common.h +++ b/src/07onnx/src/operators/common.h @@ -8,24 +8,11 @@ namespace refactor::onnx { using namespace frontend; - using ShapeResult = Result; - using ShapeRefs = std::vector>; using OptionalInts = std::optional; using OptionalIntsRef = std::optional>; constexpr Int StandardOpsetVersion = 18; - /// @brief 多方向形状广播。 - /// @param inputs 所有输入的形状。 - /// @return 广播后的形状。 - ShapeResult multidirBroadcast(ShapeRefs const &); - - /// @brief 单方向形状广播。 - /// @param target 目标形状。 - /// @param test 测试形状。 - /// @return 测试形状是否可以广播到目标形状。 - bool unidirBroadcast(Shape const &target, Shape const &test); - /// @brief 池化形状推断。 /// @param data 输入张量的形状。 /// @param kernel kernel 的形状。 @@ -39,32 +26,6 @@ namespace refactor::onnx { OptionalIntsRef const &pads, OptionalIntsRef const &strides); - Attribute defaultOr(Attributes &attrs, - std::string const &name, - Attribute defaultValue); - -#define EXPECT_SIZE(N) \ - if (inputs.size() != (N)) { \ - return Err(InferError(ERROR_MSG("Input size error"))); \ - } - -#define EXPECT_VAL(DIM, VAL) \ - int64_t VAL; \ - if ((DIM).hasValue()) { \ - VAL = (DIM).value(); \ - } else { \ - return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \ - } - -#define MULTIDIR_BROADCAST(SHAPES) \ - Shape output; \ - { \ - auto res = multidirBroadcast(SHAPES); \ - if (res.isErr()) { \ - return Err(InferError(ERROR_MSG(res.unwrapErr()))); \ - } \ - output = std::move(res.unwrap()); \ - } }// namespace refactor::onnx #endif// ONNX_INFER_H diff --git a/src/07onnx/src/operators/compair.cc b/src/07onnx/src/operators/compair.cc index d19fc78ed..11c63f33d 100644 --- a/src/07onnx/src/operators/compair.cc +++ b/src/07onnx/src/operators/compair.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), type(type_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Compair 
operator should not have attributes"); + EXPECT_NO_ATTRI; if (opType == "onnx::Equal") { return OpBox(std::make_unique(Ty::EQ)); diff --git a/src/07onnx/src/operators/concat.cc b/src/07onnx/src/operators/concat.cc index 2e3b4e66d..7f894d5a9 100644 --- a/src/07onnx/src/operators/concat.cc +++ b/src/07onnx/src/operators/concat.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), axis(axis_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = attributes.at("axis").int_(); + auto axis = attributes["axis"].int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/constant.cc b/src/07onnx/src/operators/constant.cc index 87137994d..11655ea8c 100644 --- a/src/07onnx/src/operators/constant.cc +++ b/src/07onnx/src/operators/constant.cc @@ -10,20 +10,20 @@ namespace refactor::onnx { auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { Attribute value; - if (auto it = attributes.find("value"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_float"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_floats"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_int"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_ints"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_string"); it != attributes.end()) { - value = std::move(it->second); - } else if (auto it = attributes.find("value_strings"); it != attributes.end()) { - value = std::move(it->second); + if (auto opt = attributes.get("value"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_float"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_floats"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_int"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_ints"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_string"); opt) { + value = std::move(opt->get()); + } else if (auto opt = attributes.get("value_strings"); opt) { + value = std::move(opt->get()); } else { RUNTIME_ERROR("Constant value not support"); } diff --git a/src/07onnx/src/operators/constant_of_shape.cc b/src/07onnx/src/operators/constant_of_shape.cc index 7ad72dde6..c30721072 100644 --- a/src/07onnx/src/operators/constant_of_shape.cc +++ b/src/07onnx/src/operators/constant_of_shape.cc @@ -9,8 +9,8 @@ namespace refactor::onnx { : Operator(), value(std::move(value_)) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto it = attributes.find("value"); - auto value = it != attributes.end() ? std::move(it->second.tensor()) : nullptr; + auto it = attributes.get("value"); + auto value = it ? 
std::move(it->get().tensor()) : nullptr; return OpBox(std::make_unique(std::move(value))); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/conv.cc b/src/07onnx/src/operators/conv.cc index c612d378c..4ed454817 100644 --- a/src/07onnx/src/operators/conv.cc +++ b/src/07onnx/src/operators/conv.cc @@ -20,14 +20,14 @@ namespace refactor::onnx { dilations = std::nullopt, pads = std::nullopt, strides = std::nullopt; - if (auto it = attributes.find("dilations"); it != attributes.end()) { - dilations.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("dilations"); opt) { + dilations.emplace(std::move(opt->get().ints())); } - if (auto it = attributes.find("pads"); it != attributes.end()) { - pads.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("pads"); opt) { + pads.emplace(std::move(opt->get().ints())); } - if (auto it = attributes.find("strides"); it != attributes.end()) { - strides.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("strides"); opt) { + strides.emplace(std::move(opt->get().ints())); } return OpBox(std::make_unique(std::move(dilations), std::move(pads), std::move(strides))); } diff --git a/src/07onnx/src/operators/cum_sum.cc b/src/07onnx/src/operators/cum_sum.cc index da52292a0..f99dea8a3 100644 --- a/src/07onnx/src/operators/cum_sum.cc +++ b/src/07onnx/src/operators/cum_sum.cc @@ -13,8 +13,8 @@ namespace refactor::onnx { reverse(reverse_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto exclusive = defaultOr(attributes, "exclusive", {0}).int_() != 0; - auto reverse = defaultOr(attributes, "reverse", {0}).int_() != 0; + auto exclusive = attributes.getOrInsert("exclusive", {0}).int_() != 0; + auto reverse = attributes.getOrInsert("reverse", {0}).int_() != 0; return OpBox(std::make_unique(exclusive, reverse)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/dequantize_linear.cc b/src/07onnx/src/operators/dequantize_linear.cc index 5dc691513..9f540fb4a 100644 --- a/src/07onnx/src/operators/dequantize_linear.cc +++ b/src/07onnx/src/operators/dequantize_linear.cc @@ -8,8 +8,8 @@ namespace refactor::onnx { Op::DequantizeLinear(Int axis_) noexcept : Operator(), axis(axis_) {} - auto Op::build(ModelContext const &, std::string_view, Attributes attrs) -> OpBox { - auto axis = defaultOr(attrs, "axis", {1}).int_(); + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto axis = attributes.getOrInsert("axis", {1}).int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/einsum.cc b/src/07onnx/src/operators/einsum.cc index 48e8ad740..27cfb1d48 100644 --- a/src/07onnx/src/operators/einsum.cc +++ b/src/07onnx/src/operators/einsum.cc @@ -11,7 +11,7 @@ namespace refactor::onnx { : Operator(), equation(std::move(equation_)) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - return OpBox(std::make_unique(std::move(attributes.at("equation").string()))); + return OpBox(std::make_unique(std::move(attributes["equation"].string()))); } auto Op::typeId() -> size_t { static uint8_t ID = 1; diff --git a/src/07onnx/src/operators/expand.cc b/src/07onnx/src/operators/expand.cc index b5e729810..28d6d5499 100644 --- a/src/07onnx/src/operators/expand.cc +++ b/src/07onnx/src/operators/expand.cc @@ -6,8 +6,8 @@ namespace refactor::onnx { using Op = Expand; - auto Op::build(ModelContext const &, std::string_view, Attributes 
attributes) -> OpBox { - ASSERT(attributes.empty(), "Expand operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/flatten.cc b/src/07onnx/src/operators/flatten.cc index 439f17ebb..7ae0484c5 100644 --- a/src/07onnx/src/operators/flatten.cc +++ b/src/07onnx/src/operators/flatten.cc @@ -8,7 +8,7 @@ namespace refactor::onnx { Op::Flatten(Int axis_) : Operator(), axis(axis_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = defaultOr(attributes, "axis", {1}).int_(); + auto axis = attributes.getOrInsert( "axis", {1}).int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/gather.cc b/src/07onnx/src/operators/gather.cc index c0d3165ff..3aa088909 100644 --- a/src/07onnx/src/operators/gather.cc +++ b/src/07onnx/src/operators/gather.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), axis(axis_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = defaultOr(attributes, "axis", {0}).int_(); + auto axis = attributes.getOrInsert("axis", {0}).int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/gather_elements.cc b/src/07onnx/src/operators/gather_elements.cc index 8a932d70f..c9dce5b10 100644 --- a/src/07onnx/src/operators/gather_elements.cc +++ b/src/07onnx/src/operators/gather_elements.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), axis(axis_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = defaultOr(attributes, "axis", {0}).int_(); + auto axis = attributes.getOrInsert( "axis", {0}).int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/gemm.cc b/src/07onnx/src/operators/gemm.cc index 7210bc0c9..634004517 100644 --- a/src/07onnx/src/operators/gemm.cc +++ b/src/07onnx/src/operators/gemm.cc @@ -14,10 +14,10 @@ namespace refactor::onnx { transB(transB_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto alpha = defaultOr(attributes, "alpha", {1.0f}).float_(); - auto beta = defaultOr(attributes, "beta", {1.0f}).float_(); - auto transA = defaultOr(attributes, "transA", {0}).int_() != 0; - auto transB = defaultOr(attributes, "transB", {0}).int_() != 0; + auto alpha = attributes.getOrInsert( "alpha", {1.0f}).float_(); + auto beta = attributes.getOrInsert( "beta", {1.0f}).float_(); + auto transA = attributes.getOrInsert( "transA", {0}).int_() != 0; + auto transB = attributes.getOrInsert( "transB", {0}).int_() != 0; return OpBox(std::make_unique(alpha, beta, transA, transB)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/global_pool.cc b/src/07onnx/src/operators/global_pool.cc index 2bdbda7ac..253dd3dfa 100644 --- a/src/07onnx/src/operators/global_pool.cc +++ b/src/07onnx/src/operators/global_pool.cc @@ -11,7 +11,7 @@ namespace refactor::onnx { : Operator(), type(type_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Global pool operator should not have attributes"); + EXPECT_NO_ATTRI; if (opType == "onnx::GlobalAveragePool") { return OpBox(std::make_unique(Ty::Average)); diff --git 
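These builders (Gemm, GlobalPool, and the others in this part of the diff) all swap the old `defaultOr(attributes, ...)` helper for `Attributes::getOrInsert`, the method added to the new `Attributes` wrapper in `frontend/operator.h`. A self-contained sketch of the `try_emplace` behaviour that method relies on, with `int` standing in for the project's `Attribute` type:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Stand-in for Attributes::getOrInsert (see src/06frontend/src/common.cpp above):
// try_emplace constructs the default only when the key is absent, so an attribute
// that was present in the model always wins over the supplied default.
int &getOrInsert(std::unordered_map<std::string, int> &map, const char *key, int otherwise) {
    return map.try_emplace(key, otherwise).first->second;
}

int main() {
    std::unordered_map<std::string, int> attrs{{"axis", 1}};
    assert(getOrInsert(attrs, "axis", 0) == 1);    // present: existing value kept
    assert(getOrInsert(attrs, "keepdims", 1) == 1);// absent: default inserted and returned
    assert(attrs.size() == 2);
    return 0;
}
```

One behavioural difference worth noting: unlike `defaultOr`, which only returned the default, `getOrInsert` also stores it in the attribute map.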
a/src/07onnx/src/operators/hard_sigmoid.cc b/src/07onnx/src/operators/hard_sigmoid.cc new file mode 100644 index 000000000..96c4f2d8b --- /dev/null +++ b/src/07onnx/src/operators/hard_sigmoid.cc @@ -0,0 +1,40 @@ +#include "computation/operators/hard_sigmoid.h" +#include "common.h" +#include "hard_sigmoid.hh" +#include + +namespace refactor::onnx { + using Op = HardSigmoid; + + Op::HardSigmoid(Float alpha, Float beta) + : Operator(), alpha(alpha), beta(beta) {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto alpha = attributes.getOrInsert( "alpha", {0.2f}).float_(); + auto beta = attributes.getOrInsert( "beta", {0.5f}).float_(); + return OpBox(std::make_unique(alpha, beta)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::HardSigmoid"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + EXPECT_SIZE(1) + auto dataType = inputs[0].dataType; + if (!dataType.isIeee754()) { + return Err(InferError(ERROR_MSG("Data type not support"))); + } + auto ans = Tensor::share(dataType, inputs[0].shape, extractDependency(inputs)); + return Ok(Tensors{std::move(ans)}); + } + auto Op::lower(TensorRefs) const -> computation::OpBox { + using Op_ = computation::HardSigmoid; + return std::make_unique(alpha, beta); + } + + +}// namespace refactor::onnx diff --git a/src/07onnx/src/operators/hard_sigmoid.hh b/src/07onnx/src/operators/hard_sigmoid.hh new file mode 100644 index 000000000..35590a57f --- /dev/null +++ b/src/07onnx/src/operators/hard_sigmoid.hh @@ -0,0 +1,25 @@ +#ifndef ONNX_HARD_SIGMOID_HH +#define ONNX_HARD_SIGMOID_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct HardSigmoid final : public Operator { + Float alpha, beta; + + explicit HardSigmoid(Float, Float); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_HARD_SIGMOID_HH diff --git a/src/07onnx/src/operators/mat_mul.cc b/src/07onnx/src/operators/mat_mul.cc index 0850ba6f0..7eb263761 100644 --- a/src/07onnx/src/operators/mat_mul.cc +++ b/src/07onnx/src/operators/mat_mul.cc @@ -5,8 +5,8 @@ namespace refactor::onnx { using Op = MatMul; - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "MatMul operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/mat_mul_integer.cc b/src/07onnx/src/operators/mat_mul_integer.cc index de5de0525..96443a467 100644 --- a/src/07onnx/src/operators/mat_mul_integer.cc +++ b/src/07onnx/src/operators/mat_mul_integer.cc @@ -6,8 +6,8 @@ namespace refactor::onnx { using Op = MatMulInteger; - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "MatMulInteger operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes 
attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/pool.cc b/src/07onnx/src/operators/pool.cc index 76a52266f..76fe9fa02 100644 --- a/src/07onnx/src/operators/pool.cc +++ b/src/07onnx/src/operators/pool.cc @@ -21,19 +21,19 @@ namespace refactor::onnx { strides(std::move(strides_)) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - auto kernelShape = std::move(attributes.at("kernel_shape").ints()); + auto kernelShape = std::move(attributes["kernel_shape"].ints()); OptionalInts dilations = std::nullopt, pads = std::nullopt, strides = std::nullopt; - if (auto it = attributes.find("dilations"); it != attributes.end()) { - dilations.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("dilations"); opt) { + dilations.emplace(std::move(opt->get().ints())); } - if (auto it = attributes.find("pads"); it != attributes.end()) { - pads.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("pads"); opt) { + pads.emplace(std::move(opt->get().ints())); } - if (auto it = attributes.find("strides"); it != attributes.end()) { - strides.emplace(std::move(it->second.ints())); + if (auto opt = attributes.get("strides"); opt) { + strides.emplace(std::move(opt->get().ints())); } Ty ty; diff --git a/src/07onnx/src/operators/reduce.cc b/src/07onnx/src/operators/reduce.cc index 6d9f0f033..a8ae15b5e 100644 --- a/src/07onnx/src/operators/reduce.cc +++ b/src/07onnx/src/operators/reduce.cc @@ -21,12 +21,12 @@ namespace refactor::onnx { auto noopWithEmptyAxes = false; decltype(Op::axes) axes = std::nullopt; if (opsetVer >= 18) { - noopWithEmptyAxes = defaultOr(attributes, "noop_with_empty_axes", {0}).int_() != 0; + noopWithEmptyAxes = attributes.getOrInsert( "noop_with_empty_axes", {0}).int_() != 0; } else { - axes.emplace(defaultOr(attributes, "axes", {{}}).ints()); + axes.emplace(attributes.getOrInsert( "axes", {{}}).ints()); } - auto keepDims = defaultOr(attributes, "keepdims", {1}).int_(); + auto keepDims = attributes.getOrInsert( "keepdims", {1}).int_(); Ty ty; if (opType == "onnx::ReduceMean") { ty = Ty::Mean; diff --git a/src/07onnx/src/operators/reshape.cc b/src/07onnx/src/operators/reshape.cc index 9a72672c5..cfa443ba5 100644 --- a/src/07onnx/src/operators/reshape.cc +++ b/src/07onnx/src/operators/reshape.cc @@ -9,7 +9,7 @@ namespace refactor::onnx { : Operator(), allowzero(allowzero_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto allowzero = defaultOr(attributes, "allowzero", {0}).int_() != 0; + auto allowzero = attributes.getOrInsert( "allowzero", {0}).int_() != 0; return OpBox(std::make_unique(allowzero)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/scatter_nd.cc b/src/07onnx/src/operators/scatter_nd.cc index b6d40cb9b..f4876e8b1 100644 --- a/src/07onnx/src/operators/scatter_nd.cc +++ b/src/07onnx/src/operators/scatter_nd.cc @@ -5,9 +5,9 @@ namespace refactor::onnx { using Op = ScatterND; - auto Op::build(ModelContext const &ctx, std::string_view, Attributes attrs) -> OpBox { - if (auto it = attrs.find("reduction"); it != attrs.end()) { - ASSERT(it->second.isString() && it->second.string() == "none", + auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox { + if (auto opt = attributes.get("reduction"); opt) { + ASSERT(opt->get().isString() && opt->get().string() == "none", "currently only support `reduction = none`"); 
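For the Mod support in the `SimpleBinary` changes below: the ONNX `Mod` operator maps to `Ty::Mod` when the `fmod` attribute is 0 (integer modulus, remainder takes the sign of the divisor) and to `Ty::Fmod` when it is 1 (C `fmod`-style, remainder takes the sign of the dividend). A small worked example of the difference; note that C++'s built-in `%` also follows the dividend's sign, which is what the Mod CPU test earlier in this patch uses to build its expectation:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    int a = -3, b = 2;
    std::printf("a %% b          = %d\n", a % b);          // -1: C++ %, sign of dividend
    std::printf("std::fmod(a, b) = %g\n", std::fmod(a, b));// -1: C fmod, sign of dividend
    std::printf("floor modulus   = %d\n", ((a % b) + b) % b);// 1: sign of divisor
    return 0;
}
```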
} return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/select.cc b/src/07onnx/src/operators/select.cc index b215f2e4e..864ba9be2 100644 --- a/src/07onnx/src/operators/select.cc +++ b/src/07onnx/src/operators/select.cc @@ -11,7 +11,7 @@ namespace refactor::onnx { : Operator(), type(type_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Select operator should not have attributes"); + EXPECT_NO_ATTRI; auto type = opType == "onnx::Max" ? Ty::Max : opType == "onnx::Min" ? Ty::Min diff --git a/src/07onnx/src/operators/shape.cc b/src/07onnx/src/operators/shape.cc index 4881299f6..299b3a74d 100644 --- a/src/07onnx/src/operators/shape.cc +++ b/src/07onnx/src/operators/shape.cc @@ -10,10 +10,10 @@ namespace refactor::onnx { end(end_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto start = defaultOr(attributes, "start", {0}).int_(); + auto start = attributes.getOrInsert("start", {0}).int_(); std::optional end = std::nullopt; - if (auto it = attributes.find("end"); it != attributes.end()) { - end.emplace(it->second.int_()); + if (auto opt = attributes.get("end"); opt) { + end.emplace(opt->get().int_()); } return OpBox(std::make_unique(start, end)); } diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc index a1bd5b24d..2db99bdd3 100644 --- a/src/07onnx/src/operators/simple_binary.cc +++ b/src/07onnx/src/operators/simple_binary.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), type(type_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Simple binary operator should not have attributes"); + auto fmod = attributes.getOrInsert( "fmod", {0}).int_(); // clang-format off auto type = opType == "onnx::Add" ? Ty::Add : @@ -21,6 +21,7 @@ namespace refactor::onnx { opType == "onnx::And" ? Ty::And : opType == "onnx::Or" ? Ty::Or : opType == "onnx::Xor" ? Ty::Xor : + opType == "onnx::Mod" ? (fmod == 0 ? 
Ty::Mod : Ty::Fmod) : UNREACHABLEX(Ty, "Unsupported binary operator: {}", opType); // clang-format on return OpBox(std::make_unique(type)); @@ -48,6 +49,26 @@ namespace refactor::onnx { static uint8_t ID = 5; return reinterpret_cast(&ID); } + case Ty::And: { + static uint8_t ID = 6; + return reinterpret_cast(&ID); + } + case Ty::Or: { + static uint8_t ID = 7; + return reinterpret_cast(&ID); + } + case Ty::Xor: { + static uint8_t ID = 8; + return reinterpret_cast(&ID); + } + case Ty::Mod: { + static uint8_t ID = 9; + return reinterpret_cast(&ID); + } + case Ty::Fmod: { + static uint8_t ID = 10; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -65,6 +86,8 @@ namespace refactor::onnx { case Ty::And: return "onnx::And"; case Ty::Or : return "onnx::Or" ; case Ty::Xor: return "onnx::Xor"; + case Ty::Mod: return "onnx::Mod"; + case Ty::Fmod: return "onnx::Mod"; default: UNREACHABLE(); } // clang-format on @@ -162,6 +185,8 @@ namespace refactor::onnx { case Ty::And : type_ = Ty_::And; break; case Ty::Or : type_ = Ty_::Or ; break; case Ty::Xor : type_ = Ty_::Xor; break; + case Ty::Mod : type_ = Ty_::Mod; break; + case Ty::Fmod : type_ = Ty_::Fmod; break; default: UNREACHABLE(); } // clang-format on diff --git a/src/07onnx/src/operators/simple_binary.hh b/src/07onnx/src/operators/simple_binary.hh index dfcacc17d..4c948f5fc 100644 --- a/src/07onnx/src/operators/simple_binary.hh +++ b/src/07onnx/src/operators/simple_binary.hh @@ -15,6 +15,8 @@ namespace refactor::onnx { And, Or, Xor, + Mod, + Fmod, }; struct SimpleBinary final : public Operator { diff --git a/src/07onnx/src/operators/simple_unary.cc b/src/07onnx/src/operators/simple_unary.cc index 8b51319a1..8ce5e14fd 100644 --- a/src/07onnx/src/operators/simple_unary.cc +++ b/src/07onnx/src/operators/simple_unary.cc @@ -12,7 +12,7 @@ namespace refactor::onnx { : Operator(), type(type_) {} auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Simple binary operator should not have attributes"); + EXPECT_NO_ATTRI; // clang-format off auto type = @@ -37,6 +37,7 @@ namespace refactor::onnx { opType == "onnx::Not" ? Ty::Not : opType == "onnx::Neg" ? Ty::Neg : opType == "onnx::Identity"? Ty::Identity: + opType == "onnx::HardSwish" ? 
Ty::HardSwish : UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType); // clang-format on @@ -129,6 +130,10 @@ namespace refactor::onnx { static uint8_t ID = 21; return reinterpret_cast(&ID); } + case Ty::HardSwish: { + static uint8_t ID = 22; + return reinterpret_cast(&ID); + } default: UNREACHABLE(); } @@ -159,6 +164,7 @@ namespace refactor::onnx { case Ty::Not : return "onnx::Not"; case Ty::Neg : return "onnx::Neg"; case Ty::Identity : return "onnx::Identity"; + case Ty::HardSwish : return "onnx::HardSwish"; default: UNREACHABLE(); } // clang-format on @@ -187,7 +193,7 @@ namespace refactor::onnx { Ty::Atan, Ty::Atanh, Ty::Cos, Ty::Cosh, Ty::Sin, Ty::Sinh, - Ty::Tan}, + Ty::Tan, Ty::HardSwish}, {Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log}, {Ty::Neg}, {Ty::Identity}}; @@ -287,6 +293,7 @@ namespace refactor::onnx { case Ty::Not : type_ = Ty_::Not ; break; case Ty::Neg : type_ = Ty_::Neg ; break; case Ty::Identity : return std::make_unique(); + case Ty::HardSwish : type_ = Ty_::HardSwish ; break; default: UNREACHABLE(); } // clang-format on diff --git a/src/07onnx/src/operators/simple_unary.hh b/src/07onnx/src/operators/simple_unary.hh index e0f8275fa..746a17752 100644 --- a/src/07onnx/src/operators/simple_unary.hh +++ b/src/07onnx/src/operators/simple_unary.hh @@ -16,18 +16,19 @@ namespace refactor::onnx { Atanh, Cos, Cosh, - Sin, - Sinh, - Tan, - Tanh, - Relu, - Sqrt, - Sigmoid, Erf, + HardSwish, + Identity, Log, Not, Neg, - Identity, + Relu, + Sin, + Sinh, + Sqrt, + Sigmoid, + Tan, + Tanh, }; struct SimpleUnary final : public Operator { diff --git a/src/07onnx/src/operators/slice.cc b/src/07onnx/src/operators/slice.cc index 0a0853a85..d7e9e5a0c 100644 --- a/src/07onnx/src/operators/slice.cc +++ b/src/07onnx/src/operators/slice.cc @@ -7,8 +7,8 @@ namespace refactor::onnx { using computation::Dimensions; using Op = Slice; - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Slice operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/softmax.cc b/src/07onnx/src/operators/softmax.cc index 5cc54e62a..21264f51b 100644 --- a/src/07onnx/src/operators/softmax.cc +++ b/src/07onnx/src/operators/softmax.cc @@ -9,7 +9,7 @@ namespace refactor::onnx { : Operator(), axis(axis_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = defaultOr(attributes, "axis", {-1}).int_(); + auto axis = attributes.getOrInsert( "axis", {-1}).int_(); return OpBox(std::make_unique(axis)); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/split.cc b/src/07onnx/src/operators/split.cc index a63733dd7..bfd0fbcd0 100644 --- a/src/07onnx/src/operators/split.cc +++ b/src/07onnx/src/operators/split.cc @@ -12,8 +12,8 @@ namespace refactor::onnx { numOutputs(numOutputs_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto axis = defaultOr(attributes, "axis", {0}).int_(); - auto numOutputs = defaultOr(attributes, "num_outputs", {0}).int_(); + auto axis = attributes.getOrInsert( "axis", {0}).int_(); + auto numOutputs = attributes.getOrInsert( "num_outputs", {0}).int_(); return OpBox(std::make_unique(axis, numOutputs)); } auto Op::typeId() -> size_t { @@ -39,7 +39,7 @@ namespace refactor::onnx { auto dependencies = extractDependency(inputs); if 
(inputs.size() == 1) { Tensors ans(numOutputs, nullptr); - auto each = total + numOutputs - 1 / numOutputs; + auto each = (total + numOutputs - 1) / numOutputs; for (auto i : range0_(numOutputs)) { if (total > each) { ans[i] = Tensor::share(input.dataType, input.shape, dependencies); diff --git a/src/07onnx/src/operators/squeeze.cc b/src/07onnx/src/operators/squeeze.cc index e89f5e4db..63d843cf0 100644 --- a/src/07onnx/src/operators/squeeze.cc +++ b/src/07onnx/src/operators/squeeze.cc @@ -7,19 +7,19 @@ namespace refactor::onnx { Op::Squeeze(decltype(axes) axes_) : Operator(), axes(std::move(axes_)) {} - auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox { + auto Op::build(ModelContext const &ctx, std::string_view opType, Attributes attributes) -> OpBox { auto iter = ctx.find("opset_version"); auto opsetVer = iter != ctx.end() ? iter->second.int_() : StandardOpsetVersion; if (opsetVer >= 13) { - ASSERT(attributes.empty(), "Squeeze operator should not have attributes"); + EXPECT_NO_ATTRI; return OpBox(std::make_unique( std::nullopt)); - } else if (auto it = attributes.find("axes"); it != attributes.end()) { + } else if (auto opt = attributes.get("axes"); opt) { return OpBox(std::make_unique( std::make_optional( std::make_optional( - std::move(it->second.ints()))))); + std::move(opt->get().ints()))))); } else { return OpBox(std::make_unique( std::make_optional>( diff --git a/src/07onnx/src/operators/tile.cc b/src/07onnx/src/operators/tile.cc index 14a13fb2f..f132b9a5f 100644 --- a/src/07onnx/src/operators/tile.cc +++ b/src/07onnx/src/operators/tile.cc @@ -5,8 +5,8 @@ namespace refactor::onnx { using Op = Tile; - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Tile operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/transpose.cc b/src/07onnx/src/operators/transpose.cc index 6a61cef75..733b06f37 100644 --- a/src/07onnx/src/operators/transpose.cc +++ b/src/07onnx/src/operators/transpose.cc @@ -10,7 +10,7 @@ namespace refactor::onnx { : Operator(), perm(perm_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto perm = defaultOr(attributes, "perm", {}).ints(); + auto perm = attributes.getOrInsert( "perm", {}).ints(); return OpBox(std::make_unique(std::move(perm))); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/unsqueeze.cc b/src/07onnx/src/operators/unsqueeze.cc index 496a988d7..9a546eaf2 100644 --- a/src/07onnx/src/operators/unsqueeze.cc +++ b/src/07onnx/src/operators/unsqueeze.cc @@ -7,15 +7,15 @@ namespace refactor::onnx { Op::Unsqueeze(decltype(axes) axes_) : Operator(), axes(std::move(axes_)) {} - auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox { + auto Op::build(ModelContext const &ctx, std::string_view opType, Attributes attributes) -> OpBox { auto iter = ctx.find("opset_version"); auto opsetVer = iter != ctx.end() ? 
iter->second.int_() : StandardOpsetVersion; if (opsetVer >= 13) { - ASSERT(attributes.empty(), "Unsqueeze operator should not have attributes"); + EXPECT_NO_ATTRI; return OpBox(std::make_unique(std::nullopt)); } else { - return OpBox(std::make_unique(std::make_optional(attributes.at("axes").ints()))); + return OpBox(std::make_unique(std::make_optional(attributes["axes"].ints()))); } } auto Op::typeId() -> size_t { diff --git a/src/07onnx/src/operators/where.cc b/src/07onnx/src/operators/where.cc index eea30e691..4aed5dffa 100644 --- a/src/07onnx/src/operators/where.cc +++ b/src/07onnx/src/operators/where.cc @@ -6,8 +6,8 @@ namespace refactor::onnx { using Op = Where; - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - ASSERT(attributes.empty(), "Where operator should not have attributes"); + auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox { + EXPECT_NO_ATTRI; return OpBox(std::make_unique()); } auto Op::typeId() -> size_t { diff --git a/src/07onnx/test/test_hard_sigmoid.cpp b/src/07onnx/test/test_hard_sigmoid.cpp new file mode 100644 index 000000000..846f2d945 --- /dev/null +++ b/src/07onnx/test/test_hard_sigmoid.cpp @@ -0,0 +1,23 @@ +#include "../src/operators/hard_sigmoid.hh" +#include "onnx/operators.h" +#include + +using namespace refactor; +using namespace onnx; + +TEST(infer, HardSigmoid) { + onnx::register_(); + + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, + }; + count_t inputs[]{0}; + auto infered = HardSigmoid(0.2f, 0.5f).infer(TensorRefs(edges, inputs), {false}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); +} + diff --git a/src/07onnx/test/test_simple_unary.cpp b/src/07onnx/test/test_simple_unary.cpp new file mode 100644 index 000000000..bd1e7959a --- /dev/null +++ b/src/07onnx/test/test_simple_unary.cpp @@ -0,0 +1,39 @@ +#include "../src/operators/simple_unary.hh" +#include "onnx/operators.h" +#include + +using namespace refactor; +using namespace onnx; + +TEST(infer, SimpleUnary) { + onnx::register_(); + + { + // Erf Test + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, + }; + count_t inputs[]{0}; + auto infered = SimpleUnary(SimpleUnaryType::Erf).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); + } + { + // HardSwish Test + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""}, + }; + count_t inputs[]{0}; + auto infered = SimpleUnary(SimpleUnaryType::HardSwish).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)})); + } +} diff --git a/src/08-01llm/CMakeLists.txt b/src/08-01llm/CMakeLists.txt new file mode 100644 index 000000000..f95e1fbeb --- /dev/null +++ b/src/08-01llm/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) +project(llm VERSION 0.0.0 LANGUAGES CXX) +message(STATUS "Project 
" ${PROJECT_NAME} " version " ${PROJECT_VERSION}) + +file(GLOB_RECURSE LLM_SRC src/*.cc src/*.cpp) +add_library(llm STATIC ${LLM_SRC}) +target_link_libraries(llm PUBLIC frontend) +target_include_directories(llm PUBLIC include) + +file(GLOB_RECURSE LLM_TEST test/*.cpp) +if(LLM_TEST) + add_executable(llm_test ${LLM_TEST}) + add_test(llm_test llm_test) + target_link_libraries(llm_test llm GTest::gtest_main Backward::Object) +endif() diff --git a/src/08-01llm/README.md b/src/08-01llm/README.md new file mode 100644 index 000000000..85b672311 --- /dev/null +++ b/src/08-01llm/README.md @@ -0,0 +1,63 @@ +# 大模型自定义算子 + +## RMS Normalization + +### Summary + +```plaintext + ___ → → +y = (x^2 + δ)^(-1/2) * w * x +``` + +### Attributes + +- **epsilon - FLOAT** (default is `1e-5`): 防止除 0 异常的小数字 ε。 + +### Inputs + +2 Inputs: + +- **X(heterogeneous) - T**: 来自之前算子的输入数据张量。形状为 `N1 x N2 ... D`,`Nx` 可以为任意维度,将在长度为 `D` 的最后一个维度上标准化。 +- **W(heterogeneous) - T**: 权重张量。形状为 `D`,`D` 为 `X` 的最后一个维度的长度。 + +### Outputs + +1 Output: + +- **Y(heterogeneous) - T**: 输出张量。形状与 `X` 相同。 + +## Attention + +### Summary + +Multi-head Self Attention 的封装形式,用于 transformer 模型。 + +支持使用 kv cache,使用条件由输入和属性综合决定。有以下 6 种情况: + +| 序号 | 输入数量 | `max_seq_len` | 使用 kv cache | 输出数量 | cache s 维度 | 备注 +|:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:- +| 1 | 3 | 0 | none | 1 | - | +| 2 | 3 | S > 0 | init | 3 | `S` | `assert(S >= seq_len)` +| 3 | 4 | 0 | inplace | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量 +| 4 | 4 | S > 0 | inplace | 3 | `S` | `assert(S >= past_seq_len + seq_len)` +| 5 | 6 | 0 | copy | 3 | `past_seq_len + seq_len` | `past_seq_len` 必须是常量 +| 6 | 6 | S > 0 | copy | 3 | `S` | `assert(S >= past_seq_len + seq_len)` + +### Attributes + +- **max_seq_len - INT** (default is `0`): 最大序列长度,用于初始化 kv cache。 + +### Inputs + +- **query(heterogeneous) - T**: 形状为 `N x n_head x seq_len x head_dim`。 +- **key(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`。 +- **value(heterogeneous) - T**: 形状为 `N x n_kv_head x seq_len x head_dim`。 +- **past_seq_len(optional) -int64**: 要连接的历史序列长度,必须为标量。不使用 kv cache 时留空。 +- **k_cache(optional, heterogeneous) -T**: k 缓存的初始值,形状为 `N x n_kv_head x s x head_dim`,`s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。 +- **v_cache(optional, heterogeneous) -T**: v 缓存的初始值,形状为 `N x n_kv_head x s x head_dim`,`s` 为不小于 `past_seq_len` 的任意值。不使用或不重置 kv cache 时留空。 + +### Outputs + +- **output(heterogeneous) - T**: 形状与 `query` 相同。 +- **k_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim`。`s` 的值根据 `Summary` 的描述计算。 +- **v_cache(optional, heterogeneous) - T**: 形状为 `N x n_kv_head x s x head_dim`。`s` 的值根据 `Summary` 的描述计算。 diff --git a/src/08-01llm/include/llm/operators.h b/src/08-01llm/include/llm/operators.h new file mode 100644 index 000000000..d589c2415 --- /dev/null +++ b/src/08-01llm/include/llm/operators.h @@ -0,0 +1,10 @@ +#ifndef LLM_OPERATORS_H +#define LLM_OPERATORS_H + +namespace refactor::llm { + + void register_(); + +}// namespace refactor::llm + +#endif// LLM_OPERATORS_H diff --git a/src/08-01llm/src/operators.cpp b/src/08-01llm/src/operators.cpp new file mode 100644 index 000000000..be610eab5 --- /dev/null +++ b/src/08-01llm/src/operators.cpp @@ -0,0 +1,15 @@ +#include "llm/operators.h" +#include "operators/mat_mul.hh" + +namespace refactor::llm { + using namespace frontend; + + void register_() { + // clang-format off + #define REGISTER(NAME, CLASS) Operator::register_("llm::" #NAME) + REGISTER(MatMul, MatMul); + #undef REGISTER + // clang-format on + } + +}// namespace 
diff --git a/src/08-01llm/include/llm/operators.h b/src/08-01llm/include/llm/operators.h new file mode 100644 index 000000000..d589c2415 --- /dev/null +++ b/src/08-01llm/include/llm/operators.h @@ -0,0 +1,10 @@ +#ifndef LLM_OPERATORS_H +#define LLM_OPERATORS_H + +namespace refactor::llm { + + void register_(); + +}// namespace refactor::llm + +#endif// LLM_OPERATORS_H diff --git a/src/08-01llm/src/operators.cpp b/src/08-01llm/src/operators.cpp new file mode 100644 index 000000000..be610eab5 --- /dev/null +++ b/src/08-01llm/src/operators.cpp @@ -0,0 +1,15 @@ +#include "llm/operators.h" +#include "operators/mat_mul.hh" + +namespace refactor::llm { + using namespace frontend; + + void register_() { + // clang-format off + #define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME) + REGISTER(MatMul, MatMul); + #undef REGISTER + // clang-format on + } + +}// namespace refactor::llm diff --git a/src/08-01llm/src/operators/common.h b/src/08-01llm/src/operators/common.h new file mode 100644 index 000000000..d62053d94 --- /dev/null +++ b/src/08-01llm/src/operators/common.h @@ -0,0 +1,19 @@ +#ifndef LLM_COMMON_H +#define LLM_COMMON_H + +#include "common.h" + +#define EXPECT_SIZE(N) \ + if (inputs.size() != (N)) { \ + return Err(InferError(ERROR_MSG("Input size error"))); \ + } + +#define EXPECT_VAL(DIM, VAL) \ + int64_t VAL; \ + if ((DIM).hasValue()) { \ + VAL = (DIM).value(); \ + } else { \ + return Err(InferError(UnknownVariable{(DIM.variable()->name)})); \ + } + +#endif// LLM_COMMON_H diff --git a/src/08-01llm/src/operators/mat_mul.cc b/src/08-01llm/src/operators/mat_mul.cc new file mode 100644 index 000000000..99f123cdc --- /dev/null +++ b/src/08-01llm/src/operators/mat_mul.cc @@ -0,0 +1,87 @@ +#include "computation/operators/mat_mul.h" +#include "common.h" +#include "mat_mul.hh" + +namespace refactor::llm { + using Op = MatMul; + + Op::MatMul( + decltype(transA) transA_, + decltype(transB) transB_) + : Operator(), + transA(transA_), + transB(transB_) {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto transA = attributes.getOrInsert("transA", {0}).int_() != 0; + auto transB = attributes.getOrInsert("transB", {0}).int_() != 0; + return OpBox(std::make_unique<Op>(transA, transB)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "llm::MatMul"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + EXPECT_SIZE(2) + + auto const &a = inputs[0]; + auto const &b = inputs[1]; + auto dataType = a.dataType; + if (!dataType.isNumberic() || b.dataType != dataType) { + return Err(InferError(ERROR_MSG("Input data type not support"))); + } + auto sa = a.shape, sb = b.shape; + switch (sa.size()) { + case 1: + sa.insert(sa.begin(), DimExpr(1)); + break; + case 0: + return Err(InferError(ERROR_MSG("Input shape not support"))); + default: + break; + } + switch (sb.size()) { + case 1: + sb.emplace_back(1); + break; + case 0: + return Err(InferError(ERROR_MSG("Input shape not support"))); + default: + break; + } + DimExpr m(1), n(1), ka(1), kb(1); + if (!transA) { + m = sa.rbegin()[1]; + ka = sa.rbegin()[0]; + } else { + m = sa.rbegin()[0]; + ka = sa.rbegin()[1]; + } + sa.pop_back(); + sa.pop_back(); + if (!transB) { + kb = sb.rbegin()[1]; + n = sb.rbegin()[0]; + } else { + kb = sb.rbegin()[0]; + n = sb.rbegin()[1]; + } + sb.pop_back(); + sb.pop_back(); + ASSERT(ka == kb, "Input shape not support"); + MULTIDIR_BROADCAST((ShapeRefs{sa, sb})) + output.emplace_back(std::move(m)); + output.emplace_back(std::move(n)); + return Ok(Tensors{Tensor::share(dataType, std::move(output), extractDependency(inputs))}); + } + + auto Op::lower(TensorRefs) const -> computation::OpBox { + using Op_ = computation::MatMul; + return std::make_unique<Op_>(1., 1., transA, transB); + } + +}// namespace refactor::llm diff --git a/src/08-01llm/src/operators/mat_mul.hh b/src/08-01llm/src/operators/mat_mul.hh new file mode 100644 index 000000000..ef1ca4bba --- /dev/null +++ b/src/08-01llm/src/operators/mat_mul.hh @@ -0,0 +1,25 @@ +#ifndef LLM_MAT_MUL_HH +#define LLM_MAT_MUL_HH + +#include "frontend/operator.h" + +namespace refactor::llm { + using namespace frontend; + + struct MatMul final : public Operator { + bool transA, transB; + +
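// When transA/transB is set, the last two dimensions of that input are read as transposed before the product (see infer in mat_mul.cc). +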
MatMul(decltype(transA), decltype(transB)); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::llm + +#endif// LLM_MAT_MUL_HH diff --git a/src/08-01llm/src/operators/rms_normalization.cc b/src/08-01llm/src/operators/rms_normalization.cc new file mode 100644 index 000000000..ca4a6a95e --- /dev/null +++ b/src/08-01llm/src/operators/rms_normalization.cc @@ -0,0 +1,42 @@ +#include "rms_normalization.hh" +#include "common.h" +#include "computation/operators/rms_normalization.h" + +namespace refactor::llm { + using Op = RmsNormalization; + + Op::RmsNormalization(decltype(epsilon) epsilon_) + : Operator(), epsilon(epsilon_) {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto epsilon = attributes.getOrInsert("epsilon", {1e-5f}).float_(); + return OpBox(std::make_unique<Op>(epsilon)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "llm::RmsNormalization"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult { + EXPECT_SIZE(2) + + auto const &x = inputs[0]; + auto const &w = inputs[1]; + if (x.rank() < 1 || w.rank() != 1 || x.shape.back() != w.shape.back()) { + return Err(InferError(ERROR_MSG("Input shape not support"))); + } + if (x.dataType != w.dataType) { + return Err(InferError(ERROR_MSG("Input data type not support"))); + } + return Ok(Tensors{Tensor::share(x.dataType, x.shape, extractDependency(inputs))}); + } + + auto Op::lower(TensorRefs) const -> computation::OpBox { + using Op_ = computation::RmsNormalization; + return std::make_unique<Op_>(epsilon); + } + +}// namespace refactor::llm diff --git a/src/08-01llm/src/operators/rms_normalization.hh b/src/08-01llm/src/operators/rms_normalization.hh new file mode 100644 index 000000000..dfd21c866 --- /dev/null +++ b/src/08-01llm/src/operators/rms_normalization.hh @@ -0,0 +1,25 @@ +#ifndef LLM_RMS_NORMALIZATION_HH +#define LLM_RMS_NORMALIZATION_HH + +#include "frontend/operator.h" + +namespace refactor::llm { + using namespace frontend; + + struct RmsNormalization final : public Operator { + float epsilon; + + RmsNormalization(decltype(epsilon)); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::llm + +#endif// LLM_RMS_NORMALIZATION_HH diff --git a/src/08-01llm/test/test_rms_normalization.cpp b/src/08-01llm/test/test_rms_normalization.cpp new file mode 100644 index 000000000..b24d5e032 --- /dev/null +++ b/src/08-01llm/test/test_rms_normalization.cpp @@ -0,0 +1,22 @@ +#include "../src/operators/rms_normalization.hh" +#include "llm/operators.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace llm; + +TEST(infer, RmsNormalization) { + llm::register_(); + auto edges = Edges{ + {Tensor::share(DataType::F32, Shape{DimExpr(7), DimExpr(2), DimExpr(3)}, {}), ""}, + {Tensor::share(DataType::F32, Shape{DimExpr(3)}, {}), ""}, + }; + count_t inputs[]{0, 1}; +
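// Only shape and data-type inference is exercised here; no kernel code runs. +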
auto infered = RmsNormalization(1e-6).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::F32); + ASSERT_EQ(y->shape, (Shape{DimExpr(7), DimExpr(2), DimExpr(3)})); +} diff --git a/src/08communication/CMakeLists.txt b/src/08communication/CMakeLists.txt index e39c47ead..538766710 100644 --- a/src/08communication/CMakeLists.txt +++ b/src/08communication/CMakeLists.txt @@ -11,6 +11,5 @@ file(GLOB_RECURSE COMMUNICATION_TEST test/*.cpp) if(COMMUNICATION_TEST) add_executable(communication_test ${COMMUNICATION_TEST}) add_test(communication_test communication_test) - target_link_libraries(communication_test communication GTest::gtest_main ${BACKWARD_ENABLE}) - add_backward(communication_test) + target_link_libraries(communication_test communication GTest::gtest_main Backward::Object) endif() diff --git a/src/08communication/src/operators/all_gather.cc b/src/08communication/src/operators/all_gather.cc index ea62e3aab..2be27ce99 100644 --- a/src/08communication/src/operators/all_gather.cc +++ b/src/08communication/src/operators/all_gather.cc @@ -8,7 +8,7 @@ namespace refactor::communication { : Operator(), nranks(nranks_) {} auto Op::build(ModelContext const &ctx, std::string_view, Attributes attributes) -> OpBox { - auto nranks = attributes.at("nranks").int_(); + auto nranks = attributes["nranks"].int_(); return OpBox(std::make_unique(nranks)); } auto Op::typeId() -> size_t { diff --git a/src/09python_ffi/CMakeLists.txt b/src/09python_ffi/CMakeLists.txt index a14a9e739..ccce34d37 100644 --- a/src/09python_ffi/CMakeLists.txt +++ b/src/09python_ffi/CMakeLists.txt @@ -7,7 +7,7 @@ add_subdirectory(pybind11) file(GLOB_RECURSE PYFFI_SRC src/*.cc src/*.cpp) pybind11_add_module(python_ffi SHARED ${PYFFI_SRC}) -target_link_libraries(python_ffi PRIVATE onnx communication) +target_link_libraries(python_ffi PRIVATE onnx llm communication) target_include_directories(python_ffi PRIVATE include) # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a diff --git a/src/09python_ffi/src/compiler.cc b/src/09python_ffi/src/compiler.cc index 5ddc61516..bf04053e9 100644 --- a/src/09python_ffi/src/compiler.cc +++ b/src/09python_ffi/src/compiler.cc @@ -19,7 +19,21 @@ namespace refactor::python_ffi { } void - Compiler::setInput(size_t index, int dataType, DimVec dims) { + Compiler::setInput(size_t index, pybind11::array data) { + ASSERT(index < _g.internal().topology.globalInputsCount(), + "Input {} not exist", index); + + Shape shape(data.ndim(), DimExpr(1)); + std::transform(std::execution::unseq, + data.shape(), data.shape() + data.ndim(), shape.begin(), + [](auto const &d) { return DimExpr(d); }); + auto ans = Tensor::share(parseNumpyDType(data.dtype()), std::move(shape), {}); + std::memcpy(ans->malloc(), data.data(), data.nbytes()); + _g.internal().edges[index].tensor = std::move(ans); + } + + void + Compiler::setInputInfo(size_t index, int dataType, DimVec dims) { ASSERT(index < _g.internal().topology.globalInputsCount(), "Input {} not exist", index); diff --git a/src/09python_ffi/src/compiler.h b/src/09python_ffi/src/compiler.h index 5cb5d741a..70657adb7 100644 --- a/src/09python_ffi/src/compiler.h +++ b/src/09python_ffi/src/compiler.h @@ -15,7 +15,8 @@ namespace refactor::python_ffi { public: explicit Compiler(frontend::Graph); void substitute(CStr, int64_t); - void setInput(size_t index, int dataType, DimVec dims); + void 
setInput(size_t index, pybind11::array); + void setInputInfo(size_t index, int dataType, DimVec dims); std::unordered_set fillEdgeInfo(bool calculate); Arc compileOn( Arc device, diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc index 92b83087b..c6a20cb95 100644 --- a/src/09python_ffi/src/executor.cc +++ b/src/09python_ffi/src/executor.cc @@ -26,33 +26,43 @@ namespace refactor::python_ffi { for (auto i : graph.topology.globalInputs()) { auto size = graph.edges[i].tensor->bytesSize(); buffer.resize(size); - if (stream.getData(i, buffer.data(), size)) { + if (stream.copyData(i, buffer.data(), size)) { _stream.setData(i, buffer.data(), size); } } } void Executor::setInput(count_t i, pybind11::array data) { - i = _stream.graph().topology.globalInputs().at(i); + i = _graph.internal().contiguous().topology.globalInputs().at(i); - auto const &name = _stream.graph().edges[i].name; - auto const &edges = _graph.internal().contiguous().edges; - auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor; + auto const &tensor = *_graph.internal().contiguous().edges[i].tensor; ASSERT(tensor.bytesSize() == static_cast(data.nbytes()), "input size mismatch"); _stream.setData(i, data.data(), data.nbytes()); } - auto Executor::getOutput(count_t i) -> pybind11::array { - i = _stream.graph().topology.globalOutputs().at(i); + void Executor::setInputBlob(count_t i, Arc blob) { + i = _graph.internal().contiguous().topology.globalInputs().at(i); - auto const &name = _stream.graph().edges[i].name; - auto const &edges = _graph.internal().contiguous().edges; - auto const &tensor = *std::find_if(edges.begin(), edges.end(), [&](auto const &e) { return e.name == name; })->tensor; + auto const &tensor = *_graph.internal().contiguous().edges[i].tensor; + ASSERT(tensor.bytesSize() == blob->size(), "input size mismatch"); + _stream.setData(i, std::move(blob)); + } + + auto Executor::getOutput(count_t i) const -> pybind11::array { + i = _graph.internal().contiguous().topology.globalOutputs().at(i); + + auto const &tensor = *_graph.internal().contiguous().edges[i].tensor; auto ans = pybind11::array(buildNumpyDType(tensor.dataType), std::move(tensor.shape)); - _stream.getData(i, ans.mutable_data(), ans.nbytes()); + _stream.copyData(i, ans.mutable_data(), ans.nbytes()); return ans; } + auto Executor::getOutputBlob(count_t i) const -> Arc { + i = _graph.internal().contiguous().topology.globalOutputs().at(i); + + return _stream.getData(i); + } + void Executor::run() { _stream.run(); } @@ -76,6 +86,60 @@ namespace refactor::python_ffi { os.write(ptr, size); } + static void writeText(std::ofstream os, char const *ptr, size_t size, + DataType dataType, computation::Shape const &shape) { + if (shape.empty()) { + os << dataType.name() << "<>" << std::endl; + return; + } else { + auto iter = shape.begin(); + os << dataType.name() << '<' << *iter++; + while (iter != shape.end()) { os << 'x' << *iter++; } + os << '>' << std::endl; + }; + +#define CASE(T) \ + case DataType::T: { \ + using T_ = primitive::type; \ + auto ptr_ = reinterpret_cast(ptr); \ + for (auto i : range0_(size / sizeof(T_))) { \ + os << ptr_[i] << '\t'; \ + } \ + } break + + switch (dataType) { + case DataType::U8: { + auto ptr_ = reinterpret_cast(ptr); + for (auto i : range0_(size)) { + os << static_cast(ptr_[i]) << '\t'; + } + } break; + case DataType::I8: { + auto ptr_ = reinterpret_cast(ptr); + for (auto i : range0_(size)) { + os << static_cast(ptr_[i]) << '\t'; + } + } break; 
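+            // The two 8-bit cases above widen each value before streaming so it prints as a number rather than a raw character.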
+ case DataType::Bool: { + auto ptr_ = reinterpret_cast(ptr); + for (auto i : range0_(size)) { + os << (ptr_[i] ? "true " : "false") << '\t'; + } + } break; + CASE(F32); + CASE(U16); + CASE(I16); + CASE(I32); + CASE(I64); + CASE(F64); + CASE(U32); + CASE(U64); + default: + UNREACHABLE(); + break; + } + } + static void writeNpy(std::ofstream os, char const *ptr, size_t size, DataType dataType, computation::Shape const &shape) { std::stringstream ss; @@ -126,7 +190,6 @@ namespace refactor::python_ffi { fs::create_directories(path); ASSERT(fs::is_directory(path), "Failed to create \"{}\"", path.c_str()); - auto const npy = format == "npy"; size_t dataIdx = 0; auto const &graph = _graph.internal().contiguous(); @@ -154,9 +217,12 @@ namespace refactor::python_ffi { auto file = path / fmt::format("data{:06}.{}", dataIdx++, format); fs::remove(file); std::ofstream os(file, std::ios::binary); - if (npy) { + if (format == "npy") { writeNpy(std::move(os), buffer.data(), size, edge.tensor->dataType, edge.tensor->shape); + } else if (format == "text") { + writeText(std::move(os), buffer.data(), size, + edge.tensor->dataType, edge.tensor->shape); } else { writeBin(std::move(os), buffer.data(), size); } diff --git a/src/09python_ffi/src/executor.h b/src/09python_ffi/src/executor.h index 5174cc744..004b9e63d 100644 --- a/src/09python_ffi/src/executor.h +++ b/src/09python_ffi/src/executor.h @@ -15,7 +15,9 @@ namespace refactor::python_ffi { Executor(computation::Graph, runtime::Stream); void dispatch(Arc, std::string allocator); void setInput(count_t, pybind11::array); - auto getOutput(count_t) -> pybind11::array; + void setInputBlob(count_t, Arc); + auto getOutput(count_t) const -> pybind11::array; + auto getOutputBlob(count_t) const -> Arc; void run(); void bench(bool sync); void trace(std::string path, std::string format); diff --git a/src/09python_ffi/src/import.cpp b/src/09python_ffi/src/import.cpp index 801e9ee1d..dda0e660c 100644 --- a/src/09python_ffi/src/import.cpp +++ b/src/09python_ffi/src/import.cpp @@ -56,16 +56,16 @@ namespace refactor::python_ffi { SharedOp makeOp(AttributeMap ctx, Name opType, AttributeMap attrs) { - std::unordered_map attrs_; + Attributes attrs_; for (auto &[name, value] : attrs) { - attrs_.insert({std::move(name), {std::move(value)}}); + attrs_.insert(std::move(name), {std::move(value)}); } std::unordered_map ctx_; for (auto &[name, value] : ctx) { ctx_.insert({std::move(name), {std::move(value)}}); } return std::make_shared(Operator::build( - ctx_, fmt::format("onnx::{}", opType), std::move(attrs_))); + ctx_, std::move(opType), std::move(attrs_))); } Arc diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp index 54b92ee38..48a4ea6ff 100644 --- a/src/09python_ffi/src/main.cpp +++ b/src/09python_ffi/src/main.cpp @@ -1,6 +1,7 @@ #include "communication/operators.h" #include "hardware/device.h" #include "import.h" +#include "llm/operators.h" #include "onnx/operators.h" #include // keep this line to convert stl types @@ -14,6 +15,7 @@ namespace refactor::python_ffi { using namespace frontend; onnx::register_(); + llm::register_(); communication::register_(); // clang-format off @@ -21,6 +23,7 @@ namespace refactor::python_ffi { py::class_ >(m, "Tensor" ); py::class_ >(m, "Operator" ); py::class_ >(m, "Device" ); + py::class_>(m, "Pinned" ); m .def("config_log" , &configLog , return_::automatic ) .def("find_device" , &findDevice , return_::move ) @@ -33,6 +36,7 @@ namespace refactor::python_ffi { py::class_>(m, "Compiler" ) .def("substitute" , 
&Compiler::substitute , return_::automatic ) .def("set_input" , &Compiler::setInput , return_::automatic ) + .def("set_input_info" , &Compiler::setInputInfo , return_::automatic ) .def("check_variables" , &Compiler::fillEdgeInfo , return_::move ) .def("zero_inputs" , &Compiler::zeroInputs , return_::move ) .def("get_tensor" , &Compiler::getTensor , return_::move ) @@ -43,7 +47,9 @@ namespace refactor::python_ffi { py::class_>(m, "Executor" ) .def("dispatch" , &Executor::dispatch , return_::automatic ) .def("set_input" , &Executor::setInput , return_::automatic ) + .def("set_input_blob" , &Executor::setInputBlob , return_::automatic ) .def("get_output" , &Executor::getOutput , return_::move ) + .def("get_output_blob" , &Executor::getOutputBlob , return_::move ) .def("run" , &Executor::run , return_::automatic ) .def("bench" , &Executor::bench , return_::automatic ) .def("trace" , &Executor::trace , return_::automatic ) diff --git a/src/09python_ffi/src/refactor_graph/onnx.py b/src/09python_ffi/src/refactor_graph/onnx.py index be688da7e..7f5875b38 100644 --- a/src/09python_ffi/src/refactor_graph/onnx.py +++ b/src/09python_ffi/src/refactor_graph/onnx.py @@ -63,7 +63,7 @@ def make_compiler(model: ModelProto, external_data_path: str = "") -> Compiler: { node.name: _make_operator( context, - node.op_type, + "onnx::" + node.op_type, _parse_attribute(node, external_data_path), ) for node in model.graph.node