Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo Sync #631

Merged
merged 1 commit into from
Mar 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
repo-sync-2024-03-29T10:01:45+0800
anakinxc committed Mar 29, 2024
commit b0f637e941bbd0932e2bcac0e7223cd248a9f99a
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@

- [Feature] Add minimax approximation for log
- [Feature] Support jax.lax.top_k
- [Feature] Support round to nearest even
- [Improvement] Default log approximation to minmax
- [Improvement] Improve median performance

1 change: 1 addition & 0 deletions libspu/compiler/passes/hlo_legalize_to_pphlo.cc
Original file line number Diff line number Diff line change
@@ -1331,6 +1331,7 @@ struct HloLegalizeToPPHlo
HloToPPHloOpConverter<stablehlo::ReturnOp>,
HloToPPHloOpConverter<stablehlo::RngOp>,
HloToPPHloOpConverter<stablehlo::RoundOp>,
HloToPPHloOpConverter<stablehlo::RoundNearestEvenOp>,
HloToPPHloOpConverter<stablehlo::RsqrtOp>,
HloToPPHloOpConverter<stablehlo::SineOp>,
HloToPPHloOpConverter<stablehlo::SelectOp>,
1 change: 1 addition & 0 deletions libspu/compiler/passes/map_stablehlo_to_pphlo_op.h
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ MAP_HLO_TO_PPHLO(RemOp)
MAP_HLO_TO_PPHLO(ReshapeOp)
MAP_HLO_TO_PPHLO(ReverseOp)
MAP_HLO_TO_PPHLO(RoundOp)
MAP_HLO_TO_PPHLO(RoundNearestEvenOp)
MAP_HLO_TO_PPHLO(RngOp)
MAP_HLO_TO_PPHLO(SelectOp)
MAP_HLO_TO_PPHLO(ShiftLeftOp)
5 changes: 3 additions & 2 deletions libspu/device/api.cc
Original file line number Diff line number Diff line change
@@ -219,8 +219,9 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name,
}

// print link statistics
SPDLOG_INFO("Link details: total send bytes {}, send actions {}",
comm_stats.send_bytes, comm_stats.send_actions);
SPDLOG_INFO(
"Link details: total send bytes {}, recv bytes {}, send actions {}",
comm_stats.send_bytes, comm_stats.recv_bytes, comm_stats.send_actions);
}

void SPUErrorHandler(void *use_data, const char *reason, bool gen_crash_diag) {
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_executor.cc
Original file line number Diff line number Diff line change
@@ -264,6 +264,7 @@ STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not)
STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt)
STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt)
STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ)
STANDARD_UNARY_OP_EXEC_IMPL(RoundNearestEvenOp, Round_RNTE)
STANDARD_UNARY_OP_EXEC_IMPL(SineOp, Sine)
STANDARD_UNARY_OP_EXEC_IMPL(CosineOp, Cosine)

1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.cc
Original file line number Diff line number Diff line change
@@ -268,6 +268,7 @@ UNARY_VERIFIER(ExpOp, evalExponentialOp)
UNARY_VERIFIER(RsqrtOp, evalRsqrtOp)
UNARY_VERIFIER(SqrtOp, evalSqrtOp)
UNARY_VERIFIER(RoundOp, evalRoundOp)
UNARY_VERIFIER(RoundNearestEvenOp, evalRoundNearestEvenOp)
UNARY_VERIFIER(SignOp, evalSignOp)
UNARY_VERIFIER(Log1pOp, evalLog1pOp)
UNARY_VERIFIER(Expm1Op, evalExpm1Op)
1 change: 1 addition & 0 deletions libspu/device/pphlo/pphlo_verifier.h
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ class PPHloVerifier {
VERIFY_DECL(SignOp)
VERIFY_DECL(SqrtOp)
VERIFY_DECL(RoundOp)
VERIFY_DECL(RoundNearestEvenOp)

// Simple binary
VERIFY_DECL(AddOp)
14 changes: 14 additions & 0 deletions libspu/dialect/pphlo/ops.td
Original file line number Diff line number Diff line change
@@ -274,6 +274,20 @@ def PPHLO_RoundOp
}];
}

def PPHLO_RoundNearestEvenOp: PPHLO_UnaryElementwiseOpWithTypeInfer<"round_nearest_even",
[SameOperandsAndResultType], PPHLO_FpTensor> {
let summary = "RoundNearestEven operation";
let description = [{
Performs element-wise rounding towards the nearest integer, breaking ties
towards the even integer, on the `operand` tensor and produces a `result`
tensor.

Ref:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even
```
}];
}

def PPHLO_RsqrtOp
: PPHLO_UnaryElementwiseOpWithTypeInfer<"rsqrt", [SameOperandsAndResultType], PPHLO_FpTensor> {
let summary = "Reciprocal of square-root operator";
58 changes: 53 additions & 5 deletions libspu/kernel/hal/polymorphic.cc
Original file line number Diff line number Diff line change
@@ -330,11 +330,59 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
Value power(SPUContext* ctx, const Value& x, const Value& y) {
SPU_TRACE_HAL_DISP(ctx, x, y);

if (x.isInt() || y.isInt()) {
auto x_f = dtype_cast(ctx, x, DT_F32);
auto y_f = dtype_cast(ctx, y, DT_F32);
auto ret = power(ctx, x_f, y_f);
return ret;
if (x.isInt()) {
// ref:
// https://github.com/openxla/stablehlo/blob/main/stablehlo/reference/Element.cpp#L912
// Although there are some "strange" semantics in stablehlo, we still follow
// them yet:
// 1. when x is int, then the return value must be int type.
// 2. if x is int, then y must be int
// 3. if x is int and y<0, then
// a. when |x|!=1, then always return 0;
// b. when |x|=1, then y=|y|;
//
// However, for jax.numpy.power, it behaves differently:
// 1. if any x or y is float, then both x and y will be upcast to float.
// 2. if both x and y are int, then y must be non-negative.
SPU_ENFORCE(y.isInt(), "when base is int, then y must be int.");
auto k0 = _constant(ctx, 0, x.shape());
auto k1 = _constant(ctx, 1, x.shape());
const auto bit_width = SizeOf(ctx->getField()) * 8;

auto y_b = _prefer_b(ctx, y);
auto msb_y = _rshift(ctx, y_b, bit_width - 1);
auto x_abs1 = _equal(ctx, abs(ctx, x), k1);

auto ret = _constant(ctx, 1, x.shape());
// To compute ret = x^y,
// although y has `bit_width` bits, we only consider `y_bits` bits here.
// The reason are two folds (recall that both x and y are int):
// 1. if |x|>1, then `ret` will OVERFLOW/UNDERFLOW if y>63 (e.g. FM64),
// which means the valid bits of y can't exceed `log(bit_width - 1)` .
// 2. if |x|=1:
// a). x=1, then we always get `ret`=1;
// b). x=-1, then the sign of `ret` is decided on the LSB of y;
// So we can "truncate" y to `y_bits` bits safely.
const size_t y_bits = Log2Ceil(bit_width - 1);

auto base = x;
// TODO: do this in parallel
// To compute x^y, it is necessary to compute all x^(2^idx), we use base
// (init as `x`) to store it, update base to base*base till last
// iteration, and multiply all these numbers according to y_{idx}.
// e.g. y=0101, then ret = (x) * (1) * (x^(2^2)) * (1) = x^5
for (size_t idx = 0; idx < y_bits; idx++) {
// x^(2^idx) * y_{idx}
auto cur_pow = _mux(ctx, _and(ctx, _rshift(ctx, y_b, idx), k1), base, k1);
ret = _mul(ctx, cur_pow, ret);
if (idx < y_bits - 1) {
base = _mul(ctx, base, base);
}
}

// when x=-1 and y<0, we can still get a correct result
return _mux(ctx, _and(ctx, msb_y, _not(ctx, x_abs1)), k0, ret)
.setDtype(x.dtype());
}
if (x.isPublic() && y.isPublic()) {
return f_pow_p(ctx, x, y);
21 changes: 14 additions & 7 deletions libspu/kernel/hal/polymorphic_test.cc
Original file line number Diff line number Diff line change
@@ -406,18 +406,22 @@ TYPED_TEST(MathTest, Pow) {
using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
// using RES_DT = typename std::tuple_element<4, TypeParam>::type;
using RES_DT = typename std::tuple_element<4, TypeParam>::type;

if constexpr (!std::is_same_v<LHS_DT, RHS_DT>) {
return;
}

// GIVEN
xt::xarray<LHS_DT> x;
xt::xarray<RHS_DT> y;
{
// random test
x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
y = test::xt_random<RHS_DT>({5, 6}, -2, 2);
y = test::xt_random<RHS_DT>({5, 6}, 0, 2);

// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
auto z = test::evalBinaryOp<RHS_DT>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
@@ -429,14 +433,17 @@ TYPED_TEST(MathTest, Pow) {

{
// some fixed corner case
x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
x = {-1, -1, -1, -1, -3, 1, -3, 0, 1, 1, 5, 0, 3, 2, -2};
y = {1, 0, -3, -4, -3, -3, 3, 0, 0, 2, 5, 2, -3, -1, -1};

// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
// when x is int and x=-3, y=-3, we should get 0.
// when x is int and x=3, y=-3, we should get 0.
xt::xarray<RES_DT> expected = xt::pow(x, y);

EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
<< y << std::endl
<< expected << std::endl
40 changes: 40 additions & 0 deletions libspu/kernel/hlo/basic_unary.cc
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
#include "libspu/kernel/hal/complex.h"
#include "libspu/kernel/hal/constants.h"
#include "libspu/kernel/hal/polymorphic.h"
#include "libspu/kernel/hal/ring.h"
#include "libspu/kernel/hal/type_cast.h"

namespace spu::kernel::hlo {
@@ -101,4 +102,43 @@ spu::Value Round_AFZ(SPUContext *ctx, const spu::Value &in) {
return hal::dtype_cast(ctx, hal::dtype_cast(ctx, round, DT_I64), in.dtype());
}

spu::Value Round_RNTE(SPUContext *ctx, const spu::Value &in) {
// RNTE: Round to nearest, ties to even
// let x^' = *****a.b##### be origin fxp number
// x = *****a.bc ( c = reduce_or(#####) ), y = *****a
// then ret = y + comp (comp = 0 or 1), where
// 1) if b=0, then comp=0
// 2) if b=1, c=1, then comp=1
// 3) if b=1, c=0, a=1, then comp=1
// 4) if b=1, c=0, a=0, then comp=0
// so comp = b && (c || a)
SPU_ENFORCE(!in.isComplex());
SPU_ENFORCE(in.isFxp(), "Round only supports fxp");
const auto fxp_bits = ctx->getFxpBits();
const auto k1 = hal::_constant(ctx, 1U, in.shape());

auto x_prime = hal::_prefer_b(ctx, in);
auto y = hal::floor(ctx, x_prime);

auto a = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits), k1);
auto b = hal::_and(ctx, hal::_rshift(ctx, x_prime, fxp_bits - 1), k1);

std::vector<Value> cs;
cs.reserve(fxp_bits - 1);
for (size_t idx = 0; idx < fxp_bits - 1; idx++) {
auto x_ = hal::_and(ctx, hal::_rshift(ctx, x_prime, idx), k1);
cs.push_back(std::move(x_));
}
auto c = vreduce(cs.begin(), cs.end(), [&](const Value &a, const Value &b) {
return hal::_or(ctx, a, b);
});
auto comp = hal::_and(ctx, b, hal::_or(ctx, c, a));
// set nbits to improve b2a
if (comp.storage_type().isa<BShare>()) {
const_cast<Type &>(comp.storage_type()).as<BShare>()->setNbits(1);
}

return hal::add(ctx, y, comp.setDtype(DT_I64)).setDtype(in.dtype());
}

} // namespace spu::kernel::hlo
1 change: 1 addition & 0 deletions libspu/kernel/hlo/basic_unary.h
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ SIMPLE_UNARY_KERNEL_DECL(Sign)
SIMPLE_UNARY_KERNEL_DECL(Round_AFZ)
SIMPLE_UNARY_KERNEL_DECL(Real)
SIMPLE_UNARY_KERNEL_DECL(Imag)
SIMPLE_UNARY_KERNEL_DECL(Round_RNTE)

#undef SIMPLE_UNARY_KERNEL_DECL

1 change: 1 addition & 0 deletions libspu/kernel/hlo/basic_unary_test.cc
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ UNARY_EMPTY_TEST(Rsqrt)
UNARY_EMPTY_TEST(Sqrt)
UNARY_EMPTY_TEST(Sign)
UNARY_EMPTY_TEST(Round_AFZ)
UNARY_EMPTY_TEST(Round_RNTE)

INSTANTIATE_TEST_SUITE_P(
UnaryTestInstances, UnaryTest,
1 change: 1 addition & 0 deletions libspu/mpc/cheetah/arith/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ spu_cc_library(
hdrs = ["matmat_prot.h"],
deps = [
":arith_comm",
"//libspu/mpc/cheetah/rlwe:lwe",
],
)

Loading