Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/collectors/simple_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace refactor::kernel {
Erf,
Neg,
Not,
HardSwish,
};

std::string_view unaryName(SimpleUnaryType type);
Expand Down
1 change: 1 addition & 0 deletions src/04kernel/src/collectors/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace refactor::kernel {
CASE(Erf);
CASE(Neg);
CASE(Not);
CASE(HardSwish);
default:
UNREACHABLE();
}
Expand Down
8 changes: 4 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ namespace refactor::kernel {
switch (dataType.internal) {
CASE_DT(std::fmod(a, b), F32);
CASE_DT(a % b, U8);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
CASE_DT(static_cast<int8_t>(std::fmod(a, b)), I8);
CASE_DT(a % b, U16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
CASE_DT(static_cast<int16_t>(std::fmod(a, b)), I16);
CASE_DT(static_cast<int32_t>(std::fmod(a, b)), I32);
CASE_DT(static_cast<int64_t>(std::fmod(a, b)), I64);
CASE_DT(std::fmod(a, b), F64);
CASE_DT(a % b, U32);
CASE_DT(a % b, U64);
Expand Down
12 changes: 8 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,18 @@ extern "C" __global__ void kernel(
case SimpleBinaryType::Fmod:
switch (dt) {
case DataType::U8:
case DataType::I8:
case DataType::U16:
case DataType::U32:
case DataType::U64:
return "a % b";
case DataType::I8:
return "static_cast<char>(fmodf(a, b))";
case DataType::I16:
return "static_cast<short>(fmodf(a, b))";
case DataType::I32:
return "static_cast<int>(fmodf(a, b))";
case DataType::I64:
case DataType::U32:
case DataType::U64:
return "a % b < 0 ? (a % b + b) : (a % b)";
return "static_cast<long long>(fmodf(a, b))";
case DataType::F32:
return "fmodf(a, b)";
case DataType::FP16:
Expand Down
14 changes: 14 additions & 0 deletions src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace refactor::kernel {
Op::Tanh,
Op::Neg,
Op::Erf,
Op::HardSwish,
};
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
? std::make_unique<K>(op, a.dataType, a.elementsSize())
Expand Down Expand Up @@ -49,6 +50,12 @@ namespace refactor::kernel {
using M = std::conditional_t<sizeof(T) <= 4, float, double>;
return static_cast<T>(std::tanh(static_cast<M>(x)));
}
template<class T> auto hardswishFun(T x) noexcept -> T {
    // HardSwish(x) = x * clip(x / 6 + 1/2, 0, 1), per the ONNX HardSwish
    // operator. 6.f and .5f are exactly representable, so the float
    // literals lose no precision even when T is double.
    auto gate = x / 6.f + .5f;
    if (gate <= 0) { return 0; }// fully off: x <= -3
    if (1 <= gate) { return x; }// fully on:  x >= +3
    return x * gate;            // linear ramp in between
}
auto copyForUnsigned(size_t n) noexcept -> Routine {
return [n](runtime::Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
std::memcpy(outputs[0], inputs[0], n);
Expand Down Expand Up @@ -171,6 +178,13 @@ namespace refactor::kernel {
default:
UNREACHABLE();
}
case Op::HardSwish:
switch (dataType) {
CASE(hardswishFun, F32);
CASE(hardswishFun, F64);
default:
UNREACHABLE();
}
default:
UNREACHABLE();
}
Expand Down
7 changes: 6 additions & 1 deletion src/04kernel/src/kernels/simple_unary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace refactor::kernel {
static const std::unordered_set<Op>
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
Op::Sigmoid, Op::Tanh, Op::Neg,
Op::Erf};
Op::Erf, Op::HardSwish};
#ifndef USE_CUDA
return nullptr;
#endif
Expand Down Expand Up @@ -154,6 +154,11 @@ extern "C" __global__ void kernel(
{__(Op::Erf, DT::I64 ), "erf(static_cast<double>(x))"},
{__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},
{__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},

{__(Op::HardSwish, DT::F32 ), "x * fmaxf(0.f, fminf(1.f, fmaf(1.f/6.f, x, 0.5f)))"},
{__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"},
{__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"},

};
// clang-format on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(kernel, BinaryCpu) {
testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; });
testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; });
testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return a % b < 0 ? (a % b + b) : (a % b); });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return static_cast<int32_t>(std::fmod(a, b)); });
testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); });
}

Expand Down
29 changes: 28 additions & 1 deletion src/04kernel/test/kernels/simple_unary/test_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using namespace refactor;
using namespace kernel;

using VecFloat = std::vector<float>;

static void testOp(SimpleUnaryType opType, float check(float)) {
// build routine
auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50});
Expand All @@ -12,7 +14,7 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
// put input data
std::vector<float> data(dataTensor->elementsSize());
VecFloat data(dataTensor->elementsSize());
for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; }
auto result = data;
// inference
Expand All @@ -27,9 +29,34 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
}
}

// Runs `opType` on a fixed 2x3 F32 tensor filled with 0,1,2,... and
// checks the in-place result element-wise against the expected values
// in `data` (which must hold one entry per tensor element).
static void testOpWithData(SimpleUnaryType opType, const VecFloat &data) {
    // build routine
    auto tensor = Tensor::share(DataType::F32, Shape{2, 3});
    auto kernel = SimpleUnaryCpu::build(opType, *tensor);
    ASSERT_TRUE(kernel);
    auto res = runtime::Resources();
    auto routine = kernel->lower(res).routine;
    // fill the buffer with the index sequence 0, 1, 2, ...
    VecFloat buffer(tensor->elementsSize());
    for (auto i : range0_(buffer.size())) { buffer[i] = static_cast<float>(i); }
    // inference — the kernel runs in place, so one buffer serves as
    // both input and output
    {
        void const *inputs[]{buffer.data()};
        void *outputs[]{buffer.data()};
        routine(res, nullptr, inputs, outputs);
    }
    // check against the caller-supplied expectations
    for (auto i : range0_(buffer.size())) {
        EXPECT_NEAR(data[i], buffer[i], 1e-5);
    }
}

TEST(kernel, SimpleUnaryCpu) {
    // Ops whose reference is a <cmath> function: compare element-wise
    // against the stdlib result over a large index-derived input.
    testOp(SimpleUnaryType::Abs, std::abs);
    testOp(SimpleUnaryType::Sqrt, std::sqrt);
    testOp(SimpleUnaryType::Tanh, std::tanh);
    testOp(SimpleUnaryType::Erf, std::erf);
    // HardSwish has no direct <cmath> counterpart, so compare against a
    // hand-computed table for inputs 0..5 on a 2x3 tensor:
    // x * clip(x/6 + 1/2, 0, 1).
    testOpWithData(SimpleUnaryType::HardSwish,
                   VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000});
}
1 change: 1 addition & 0 deletions src/04kernel/test/kernels/simple_unary/test_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ TEST(kernel, SimpleUnaryCuda) {
testOp(SimpleUnaryType::Sigmoid);
testOp(SimpleUnaryType::Tanh);
testOp(SimpleUnaryType::Erf);
testOp(SimpleUnaryType::HardSwish);
}

#endif
6 changes: 6 additions & 0 deletions src/05computation/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ namespace refactor::computation {
static uint8_t ID = 19;
return reinterpret_cast<size_t>(&ID);
}
case SimpleUnaryType::HardSwish: {
static uint8_t ID = 20;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -128,6 +132,8 @@ namespace refactor::computation {
return "Neg";
case SimpleUnaryType::Not:
return "Not";
case SimpleUnaryType::HardSwish:
return "HardSwish";
default:
UNREACHABLE();
}
Expand Down
2 changes: 2 additions & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace refactor::onnx {
REGISTER(And , SimpleBinary );
REGISTER(Or , SimpleBinary );
REGISTER(Xor , SimpleBinary );
REGISTER(Mod , SimpleBinary );
REGISTER(Abs , SimpleUnary );
REGISTER(Acos , SimpleUnary );
REGISTER(Acosh , SimpleUnary );
Expand All @@ -116,6 +117,7 @@ namespace refactor::onnx {
REGISTER(Not , SimpleUnary );
REGISTER(Neg , SimpleUnary );
REGISTER(Identity , SimpleUnary );
REGISTER(HardSwish , SimpleUnary );
REGISTER(Slice , Slice );
REGISTER(Softmax , Softmax );
REGISTER(Split , Split );
Expand Down
9 changes: 8 additions & 1 deletion src/07onnx/src/operators/simple_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace refactor::onnx {
opType == "onnx::Not" ? Ty::Not :
opType == "onnx::Neg" ? Ty::Neg :
opType == "onnx::Identity"? Ty::Identity:
opType == "onnx::HardSwish" ? Ty::HardSwish :
UNREACHABLEX(Ty, "Unsupported unary operator: {}", opType);
// clang-format on

Expand Down Expand Up @@ -129,6 +130,10 @@ namespace refactor::onnx {
static uint8_t ID = 21;
return reinterpret_cast<size_t>(&ID);
}
case Ty::HardSwish: {
static uint8_t ID = 22;
return reinterpret_cast<size_t>(&ID);
}
default:
UNREACHABLE();
}
Expand Down Expand Up @@ -159,6 +164,7 @@ namespace refactor::onnx {
case Ty::Not : return "onnx::Not";
case Ty::Neg : return "onnx::Neg";
case Ty::Identity : return "onnx::Identity";
case Ty::HardSwish : return "onnx::HardSwish";
default: UNREACHABLE();
}
// clang-format on
Expand Down Expand Up @@ -187,7 +193,7 @@ namespace refactor::onnx {
Ty::Atan, Ty::Atanh,
Ty::Cos, Ty::Cosh,
Ty::Sin, Ty::Sinh,
Ty::Tan},
Ty::Tan, Ty::HardSwish},
{Ty::Tanh, Ty::Sqrt, Ty::Sigmoid, Ty::Log},
{Ty::Neg},
{Ty::Identity}};
Expand Down Expand Up @@ -287,6 +293,7 @@ namespace refactor::onnx {
case Ty::Not : type_ = Ty_::Not ; break;
case Ty::Neg : type_ = Ty_::Neg ; break;
case Ty::Identity : return std::make_unique<computation::Identity>();
case Ty::HardSwish : type_ = Ty_::HardSwish ; break;
default: UNREACHABLE();
}
// clang-format on
Expand Down
17 changes: 9 additions & 8 deletions src/07onnx/src/operators/simple_unary.hh
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ namespace refactor::onnx {
Atanh,
Cos,
Cosh,
Sin,
Sinh,
Tan,
Tanh,
Relu,
Sqrt,
Sigmoid,
Erf,
HardSwish,
Identity,
Log,
Not,
Neg,
Identity,
Relu,
Sin,
Sinh,
Sqrt,
Sigmoid,
Tan,
Tanh,
};

struct SimpleUnary final : public Operator {
Expand Down
14 changes: 14 additions & 0 deletions src/07onnx/test/test_simple_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,18 @@ TEST(infer, SimpleUnary) {
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
{
// HardSwish Test
auto edges = Edges{
{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3)}, {}), ""},
};
count_t inputs[]{0};
auto infered = SimpleUnary(SimpleUnaryType::HardSwish).infer(TensorRefs(edges, inputs), {true});
ASSERT_TRUE(infered.isOk());
auto outputs = std::move(infered.unwrap());
ASSERT_EQ(outputs.size(), 1);
auto y = std::move(outputs[0]);
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(3)}));
}
}