20 changes: 20 additions & 0 deletions src/04kernel/include/kernel/collectors/hard_sigmoid.h
@@ -0,0 +1,20 @@
#ifndef KERNEL_HARD_SIGMOID_H
#define KERNEL_HARD_SIGMOID_H

#include "../collector.h"

namespace refactor::kernel {

struct HardSigmoidCollector final : public InfoCollector {
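// HardSigmoid attributes: y = max(0, min(1, alpha * x + beta)); ONNX defaults are alpha = 0.2, beta = 0.5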
float alpha, beta;

constexpr HardSigmoidCollector(decltype(_target) target, float alpha_, float beta_) noexcept
: InfoCollector(target), alpha(alpha_), beta(beta_) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};
}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_H

29 changes: 29 additions & 0 deletions src/04kernel/src/collectors/hard_sigmoid.cc
@@ -0,0 +1,29 @@
#include "kernel/collectors/hard_sigmoid.h"
#include "../kernels/hard_sigmoid/cpu_kernel.hh"
#include "../kernels/hard_sigmoid/cuda_kernel.hh"

namespace refactor::kernel {

std::vector<KernelBox>
HardSigmoidCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &a = inputs[0];

std::vector<KernelBox> ans;
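// each backend's build() returns nullptr when it cannot serve this tensor, so only viable kernels are collected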
switch (_target) {
case decltype(_target)::Cpu:
if (auto ptr = HardSigmoidCpu::build(alpha, beta, a); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
case decltype(_target)::Nvidia:
if (auto ptr = HardSigmoidCuda::build(alpha, beta, a); ptr) {
ans.emplace_back(std::move(ptr));
}
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel
54 changes: 54 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.cc
@@ -0,0 +1,54 @@
#include "cpu_kernel.hh"
#include <execution>

namespace refactor::kernel {
using K = HardSigmoidCpu;
using DT = DataType;

K::HardSigmoidCpu(float alpha_, float beta_, DT dataType_, size_t size_) noexcept
: Kernel(), alpha(alpha_), beta(beta_), dataType(dataType_), size(size_) {}

auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox {
if (!a.dataType.isCpuNumberic()) {
return nullptr;
}
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize());
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing HardSigmoid using CPU";
}

template<class T>
static Routine lowerTyped(float alpha_, float beta_, size_t size) {
using namespace runtime;

return [=](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
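// element-wise HardSigmoid: y = clamp(alpha * x + beta, 0, 1)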
auto x = reinterpret_cast<T const *>(inputs[0]);
auto y = reinterpret_cast<T *>(outputs[0]);
std::for_each_n(std::execution::par_unseq,
natural_t(0), size,
[&](auto i) {
y[i] = std::clamp(alpha_ * x[i] + beta_, static_cast<T>(0), static_cast<T>(1));
});
};
}

auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
switch (dataType) {
case DT::F32:
return lowerTyped<float>(alpha, beta, size);
case DT::F64:
return lowerTyped<double>(alpha, beta, size);
default:
UNREACHABLE();
}
}
}// namespace refactor::kernel

27 changes: 27 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cpu_kernel.hh
@@ -0,0 +1,27 @@
#ifndef KERNEL_HARD_SIGMOID_CPU_KERNEL_HH
#define KERNEL_HARD_SIGMOID_CPU_KERNEL_HH

#include "kernel/collectors/hard_sigmoid.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct HardSigmoidCpu final : public Kernel {
float alpha, beta;
DataType dataType;
size_t size;

explicit HardSigmoidCpu(float, float, DataType, size_t) noexcept;

static KernelBox build(float, float, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_CPU_KERNEL_HH

88 changes: 88 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.cc
@@ -0,0 +1,88 @@
#include "cuda_kernel.hh"

#ifdef USE_CUDA
#include "../../generator/nvrtc_repo.h"
#include "kernel/cuda/threads_distributer.cuh"
#include <cuda_runtime.h>
#endif

namespace refactor::kernel {
using K = HardSigmoidCuda;
using DT = DataType;

K::HardSigmoidCuda(float alpha_, float beta_, DT dt_, size_t size_) noexcept
: Kernel(), alpha(alpha_), beta(beta_), dataType(dt_), size(size_) {}

auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif
return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize());
}

auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto K::kernelTypeId() const noexcept -> size_t {
return typeId();
}
auto K::description() const noexcept -> std::string_view {
return "Performing hardsigmoid operation on Nvidia GPU";
}

#ifdef USE_CUDA
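// NVRTC source template: {0:} is the element type, {1:} the per-element HardSigmoid expression;
// the generated kernel applies fn over the buffer with a grid-stride loop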
constexpr static const char *TEMPLATE = R"~(
__device__ __forceinline__ static {0:} fn({0:} x) {{
return {1:};
}}

extern "C" __global__ void kernel(
{0:} *__restrict__ y,
{0:} const *__restrict__ x,
size_t n
) {{
for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
step = blockDim.x * gridDim.x;
tid < n;
tid += step)
y[tid] = fn(x[tid]);
}}
)~";
auto K::lower(Resources &res) const -> RoutineWorkspace {
using namespace runtime;

std::string op = "";
switch (dataType) {
case DT::F32:
op = fmt::format("fmaxf(0.f, fminf(1.f, fmaf({}, x, {})))", alpha, beta);
break;
case DT::F64:
op = fmt::format("fmax(0.0, fmin(1.0, fma({}, x, {})))",
static_cast<double>(alpha), static_cast<double>(beta));
break;
case DT::FP16:
op = fmt::format("__hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, (__float2half({}) * x + __float2half({}))))",
alpha, beta);
break;
default:
UNREACHABLE();
}
auto name = fmt::format("hardsigmoid_{}_{}_{}", dataType.name(), alpha, beta);
auto code = fmt::format(TEMPLATE, nvrtc::dataType(dataType), op);
return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
params = cuda::ThreadsDistributer()(size)](
Resources &, void *, void const *const *inputs, void *const *outputs) {
auto y = outputs[0];
auto x = inputs[0];
auto n = params.n;
void *args[]{&y, &x, &n};
h->launch(params.gridSize, 1, 1,
params.blockSize, 1, 1,
0, args);
};
}
#endif

}// namespace refactor::kernel

28 changes: 28 additions & 0 deletions src/04kernel/src/kernels/hard_sigmoid/cuda_kernel.hh
@@ -0,0 +1,28 @@
#ifndef KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH
#define KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH

#include "kernel/collectors/hard_sigmoid.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct HardSigmoidCuda final : public Kernel {
float alpha, beta;
DataType dataType;
size_t size;

explicit HardSigmoidCuda(float, float, DataType, size_t) noexcept;

static KernelBox build(float, float, Tensor const &) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_HARD_SIGMOID_CUDA_KERNEL_HH
31 changes: 31 additions & 0 deletions src/04kernel/test/kernels/hard_sigmoid/test_cpu.cpp
@@ -0,0 +1,31 @@
#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;

TEST(kernel, HardSigmoidCpu) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5});
float alpha = 0.2f, beta = 0.5f;
auto kernel = HardSigmoidCpu::build(alpha, beta, *dataTensor);
ASSERT_TRUE(kernel);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
// put input data
std::vector<float> result(dataTensor->elementsSize());
for (auto i : range0_(result.size())) { result[i] = i; }
// inference
{
void const *inputs[]{result.data()};
void *outputs[]{result.data()};
routine(res, nullptr, inputs, outputs);
}
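// expected: y[i] = clamp(0.2 * i + 0.5, 0, 1), so only the first three elements stay below the upper bound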
std::vector<float> output = {0.5, 0.7, 0.9, 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.};
// check
for (auto i : range0_(result.size())) {
EXPECT_FLOAT_EQ(output[i], result[i]);
}
}
49 changes: 49 additions & 0 deletions src/04kernel/test/kernels/hard_sigmoid/test_cuda.cpp
@@ -0,0 +1,49 @@
#ifdef USE_CUDA

#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh"
#include "../../../src/kernels/hard_sigmoid/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;
using namespace hardware;

TEST(kernel, HardSigmoidCuda) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5});
float alpha = 0.2f, beta = 0.5f;
auto kernel = HardSigmoidCuda::build(alpha, beta, *dataTensor);
auto kCpu = HardSigmoidCpu::build(alpha, beta, *dataTensor);
ASSERT_TRUE(kernel && kCpu);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine,
rCpu = kCpu->lower(res).routine;
// malloc
auto &dev = *device::init(Device::Type::Nvidia, 0, "");
auto gpuMem = dev.malloc(dataTensor->bytesSize());
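// a single device buffer is used in place: it is filled with the input and then overwritten by the kernel's output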
// put input data
std::vector<float> data(dataTensor->elementsSize());
for (auto i : range0_(data.size())) { data[i] = i; }
gpuMem->copyFromHost(data.data(), dataTensor->bytesSize());
// inference
{
void const *inputs[]{*gpuMem};
void *outputs[]{*gpuMem};
routine(res, nullptr, inputs, outputs);
}
{
void const *inputs[]{data.data()};
void *outputs[]{data.data()};
rCpu(res, nullptr, inputs, outputs);
}
// take output data
std::vector<float> result(dataTensor->elementsSize());
gpuMem->copyToHost(result.data(), dataTensor->bytesSize());
// check
for (auto i : range0_(data.size())) {
EXPECT_FLOAT_EQ(data[i], result[i]);
}
}

#endif
23 changes: 23 additions & 0 deletions src/05computation/include/computation/operators/hard_sigmoid.h
@@ -0,0 +1,23 @@
#ifndef COMPUTATION_HARD_SIGMOID_H
#define COMPUTATION_HARD_SIGMOID_H

#include "../operator.h"

namespace refactor::computation {

struct HardSigmoid final : public Operator {
float alpha, beta;

constexpr HardSigmoid(float alpha_, float beta_) noexcept
: Operator(), alpha(alpha_), beta(beta_) {}

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const noexcept final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_HARD_SIGMOID_H
23 changes: 23 additions & 0 deletions src/05computation/src/operators/hard_sigmoid.cc
@@ -0,0 +1,23 @@
#include "computation/operators/hard_sigmoid.h"
#include "kernel/collectors/hard_sigmoid.h"

namespace refactor::computation {
using Op = HardSigmoid;

auto Op::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "HardSigmoid"; }

auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
using Collector_ = kernel::HardSigmoidCollector;
return std::make_unique<Collector_>(target, alpha, beta);
}
auto Op::serialize() const noexcept -> std::string {
return fmt::format("{}()", name());
}

}// namespace refactor::computation

2 changes: 2 additions & 0 deletions src/07onnx/src/operators.cpp
@@ -17,6 +17,7 @@
#include "operators/gather_elements.hh"
#include "operators/gemm.hh"
#include "operators/global_pool.hh"
#include "operators/hard_sigmoid.hh"
#include "operators/mat_mul.hh"
#include "operators/mat_mul_integer.hh"
#include "operators/pool.hh"
@@ -124,6 +125,7 @@ namespace refactor::onnx {
REGISTER(Transpose , Transpose );
REGISTER(Unsqueeze , Unsqueeze );
REGISTER(Where , Where );
REGISTER(HardSigmoid , HardSigmoid );
#undef REGISTER
// clang-format on
}