repo-sync-2024-09-20T11:43:39+0800 #860

Merged 1 commit on Sep 20, 2024
6 changes: 3 additions & 3 deletions docs/development/add_protocols.rst
@@ -105,7 +105,7 @@ member function and a member variable of an Object, respectively.
   // register customized kernels
   template <typename KernelT>
   void regKernel() {
-    regKernel(KernelT::kBindName, std::make_unique<KernelT>());
+    regKernel(KernelT::kBindName(), std::make_unique<KernelT>());
   }

   template <typename KernelT>
@@ -116,7 +116,7 @@ member function and a member variable of an Object, respectively.
   // add customized states
   template <typename StateT, typename... Args>
   void addState(Args&&... args) {
-    addState(StateT::kBindName,
+    addState(StateT::kBindName(),
              std::make_unique<StateT>(std::forward<Args>(args)...));
   }
   ...
@@ -205,7 +205,7 @@ As a result, the ABY3 developer can directly register these kernels through the
 class AndPP : public BinaryKernel {
  public:
   // kernel name for dynamic binding
-  static constexpr char kBindName[] = "and_pp";
+  static constexpr const char* kBindName() { return "and_pp"; }

   // define cost model
   ce::CExpr latency() const override { return ce::Const(0); }
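Note: the switch from a static constexpr char kBindName[] data member to a static constexpr const char* kBindName() function is the theme of the whole PR. One plausible motivation (an inference, not stated in the diff): before C++17, a constexpr static array member that is odr-used also needs an out-of-line definition, while a constexpr function never does, and call sites only gain a pair of parentheses. A minimal self-contained sketch of the binding pattern (Kernel, Registry, and this AndPP are illustrative stand-ins, not SPU's real types):

#include <map>
#include <memory>
#include <string>

struct Kernel {
  virtual ~Kernel() = default;
};

struct AndPP : Kernel {
  // kernel name for dynamic binding, as a constexpr function:
  // no out-of-line definition is required, even pre-C++17
  static constexpr const char* kBindName() { return "and_pp"; }
};

class Registry {
 public:
  template <typename KernelT>
  void regKernel() {
    kernels_[KernelT::kBindName()] = std::make_unique<KernelT>();
  }

 private:
  std::map<std::string, std::unique_ptr<Kernel>> kernels_;
};

int main() {
  Registry r;
  r.regKernel<AndPP>();  // stored under the key "and_pp"
  return 0;
}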
9 changes: 5 additions & 4 deletions libspu/core/object.h
@@ -112,7 +112,7 @@ class Object final {

   template <typename KernelT>
   void regKernel() {
-    regKernel(KernelT::kBindName, std::make_unique<KernelT>());
+    regKernel(KernelT::kBindName(), std::make_unique<KernelT>());
   }

   template <typename KernelT, typename OtherKernelT, typename... MoreKernelT>
@@ -137,14 +137,15 @@ class Object final {

   template <typename StateT, typename... Args>
   void addState(Args&&... args) {
-    addState(StateT::kBindName,
+    addState(StateT::kBindName(),
              std::make_unique<StateT>(std::forward<Args>(args)...));
   }

   template <typename StateT>
   StateT* getState() {
-    const auto& itr = states_.find(StateT::kBindName);
-    SPU_ENFORCE(itr != states_.end(), "state={} not found", StateT::kBindName);
+    const auto& itr = states_.find(StateT::kBindName());
+    SPU_ENFORCE(itr != states_.end(), "state={} not found",
+                StateT::kBindName());
     return dynamic_cast<StateT*>(itr->second.get());
   }
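The getState() half of this hunk shows the other side of the registry pattern: states live in a string-keyed map, are looked up by the state type's kBindName(), and are downcast back to the concrete type. A minimal sketch of that lookup, with Object, State, and PrgState as illustrative stand-ins for SPU's real types (a plain exception standing in for SPU_ENFORCE):

#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

struct State {
  virtual ~State() = default;
};

struct PrgState : State {
  static constexpr const char* kBindName() { return "PrgState"; }
};

class Object {
 public:
  template <typename StateT, typename... Args>
  void addState(Args&&... args) {
    states_[StateT::kBindName()] =
        std::make_unique<StateT>(std::forward<Args>(args)...);
  }

  template <typename StateT>
  StateT* getState() {
    const auto itr = states_.find(StateT::kBindName());
    if (itr == states_.end()) {
      throw std::runtime_error(std::string("state not found: ") +
                               StateT::kBindName());
    }
    // downcast to the concrete type registered under this name
    return dynamic_cast<StateT*>(itr->second.get());
  }

 private:
  std::map<std::string, std::unique_ptr<State>> states_;
};

int main() {
  Object obj;
  obj.addState<PrgState>();
  return obj.getState<PrgState>() != nullptr ? 0 : 1;
}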
1 change: 0 additions & 1 deletion libspu/kernel/hal/BUILD.bazel
@@ -232,7 +232,6 @@ spu_cc_library(
     hdrs = ["utils.h"],
     deps = [
         ":constants",
-        ":polymorphic",
         ":ring",
         ":shape_ops",
         "//libspu/core:prelude",
12 changes: 0 additions & 12 deletions libspu/kernel/hal/fxp_approx.cc
@@ -29,18 +29,6 @@
 namespace spu::kernel::hal {
 namespace detail {

-Value EvaluatePolynomial(SPUContext* ctx, const Value& x,
-                         absl::Span<const float> coefficients) {
-  auto poly = constant(ctx, coefficients[0], x.dtype(), x.shape());
-
-  for (size_t i = 1; i < coefficients.size(); ++i) {
-    auto c = constant(ctx, coefficients[i], x.dtype(), x.shape());
-    poly = f_mul(ctx, poly, x);
-    poly = f_add(ctx, poly, c);
-  }
-  return poly;
-}
-
 Value log_minmax_normalized(SPUContext* ctx, const Value& x) {
   static std::array<float, 9> kLogCoefficient{
       0.0, 0.9999964239, -0.4998741238, 0.3317990258, -0.2407338084,
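The deleted helper is a textbook Horner evaluation. A sketch of what it computed, on plain doubles rather than SPU Values:

#include <vector>

// Horner's rule: ((c0*x + c1)*x + c2)*x + ... is inherently sequential,
// since every iteration consumes the previous product. Under MPC that
// means one secret multiplication round per coefficient.
double evaluate_polynomial(double x, const std::vector<double>& coeffs) {
  double poly = coeffs[0];
  for (size_t i = 1; i < coeffs.size(); ++i) {
    poly = poly * x + coeffs[i];
  }
  return poly;
}

That sequential data dependence is presumably why it gives way to the log-round scheme added to polynomial() in fxp_base.cc below.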
39 changes: 23 additions & 16 deletions libspu/kernel/hal/fxp_base.cc
@@ -34,26 +34,34 @@ Value polynomial(SPUContext* ctx, const Value& x,
   SPU_ENFORCE(x.isFxp());
   SPU_ENFORCE(!coeffs.empty());

-  if (coeffs.size() == 1U) {
+  if (coeffs.size() == 1U || x.numel() == 0) {
     return coeffs[0];
   }
-  Value x_pow = constant(ctx, 1.0F, x.dtype(), x.shape());
-  Value res = _mul(ctx, x_pow, coeffs[0]);
+  // Use a parallel circuit to calculate x, x^2, x^3, ..., x^n.
+  // The general log(n) algorithm
+  // algorithm:
+  //  Step 0. x
+  //  Step 1. x, x2
+  //  Step 2. x, x2, x3, x4
+  //  ...
+  std::vector<spu::Value> x_prefix(1, x);
+  size_t degree = coeffs.size() - 1;
+  for (int64_t i = 0; i < Log2Ceil(degree); ++i) {
+    size_t x_size = std::min(x_prefix.size(), degree - x_prefix.size());
+    std::vector<spu::Value> x_pow(x_size, x_prefix.back());
+    // TODO: this can be further optimized to use sign hint
+    vmap(x_prefix.begin(), x_prefix.begin() + x_size, x_pow.begin(),
+         x_pow.end(), std::back_inserter(x_prefix),
+         [ctx, sign_x](const Value& a, const Value& b) {
+           return f_mul(ctx, a, b, sign_x);
+         });
+  }
+
+  Value res = _mul(ctx, constant(ctx, 1.0F, x.dtype(), x.shape()), coeffs[0]);

   const auto fbits = ctx->getFxpBits();
   for (size_t i = 1; i < coeffs.size(); i++) {
-    if ((i & 1) == 0U) {
-      // x^{even order} is always positive
-      x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, SignType::Positive);
-    } else {
-      if (i > 1) {
-        x_pow = _trunc(ctx, _mul(ctx, x_pow, x), fbits, sign_x);
-      } else {
-        // i=1, then save a _trunc
-        x_pow = x;
-      }
-    }
-    res = _add(ctx, res, _mul(ctx, x_pow, coeffs[i]));
+    res = _add(ctx, res, _mul(ctx, x_prefix[i - 1], coeffs[i]));
   }

   return _trunc(ctx, res, fbits, sign_ret).setDtype(x.dtype());
@@ -93,7 +101,6 @@ Value maskNumberOfBits(SPUContext* ctx, const Value& in, size_t nbits) {
 }

 namespace {
-
 Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx,
                                                const Value& b_abs,
                                                const Value& factor) {
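The doubling schedule in the reworked polynomial() can be followed on plain doubles. A minimal sketch, not SPU code: the batched vmap/f_mul round becomes an ordinary loop, and fixed-point truncation does not apply to floats. Each pass multiplies the highest power so far into a batch of already-known powers, so x^1..x^degree are all available after ceil(log2(degree)) passes instead of the degree-1 dependent multiplications of the old running-power loop.

#include <algorithm>
#include <cstdio>
#include <vector>

// Compute x^1, x^2, ..., x^degree in ceil(log2(degree)) batched passes.
// In SPU each pass is one vmap'ed f_mul, i.e. one round of secret
// multiplications; here the batch is just a loop over doubles.
std::vector<double> powers(double x, size_t degree) {
  std::vector<double> x_prefix{x};  // step 0: {x}
  while (x_prefix.size() < degree) {
    // {x..x^k} -> {x..x^k, x^{k+1}..x^{2k}}, capped at degree.
    size_t batch = std::min(x_prefix.size(), degree - x_prefix.size());
    double top = x_prefix.back();
    for (size_t i = 0; i < batch; ++i) {
      x_prefix.push_back(x_prefix[i] * top);
    }
  }
  return x_prefix;
}

int main() {
  for (double v : powers(2.0, 5)) std::printf("%g ", v);  // 2 4 8 16 32
  std::printf("\n");
  return 0;
}

Both variants spend degree-1 multiplications building the powers; the win is purely in depth, since batches of independent multiplications replace a chain of dependent ones.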
2 changes: 1 addition & 1 deletion libspu/kernel/hal/permute.cc
@@ -1159,7 +1159,7 @@ std::vector<spu::Value> permute(SPUContext *ctx,
   for (auto const &input : inputs) {
     auto transposed = hal::transpose(ctx, input, perm);
     auto reshaped = hal::reshape(ctx, transposed, {N, W});
-    inputs2d.push_back(reshaped);
+    inputs2d.push_back(std::move(reshaped));
   }

   // Call permute1d for each dim to permute.
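A routine copy-avoidance fix: reshaped is a local that is dead after the push_back, so moving it into the vector skips a needless copy of the Value's internal handles. The same idiom on a plain std::string stand-in:

#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> out;
  for (int i = 0; i < 3; ++i) {
    std::string s(1024, 'x');     // locally built value, unused after push
    out.push_back(std::move(s));  // steals the buffer instead of copying it
  }
  return 0;
}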
117 changes: 77 additions & 40 deletions libspu/kernel/hal/utils.h
@@ -16,13 +16,29 @@

 #include "libspu/core/context.h"
 #include "libspu/core/value.h"
+#include "libspu/core/vectorize.h"
 #include "libspu/kernel/hal/constants.h"
-#include "libspu/kernel/hal/polymorphic.h"
 #include "libspu/kernel/hal/ring.h"
 #include "libspu/kernel/hal/shape_ops.h"

 namespace spu::kernel::hal {

+//////////////////////////////////////////////////////////////////////////////
+// Shape utils
+//////////////////////////////////////////////////////////////////////////////
+
+/// the squeeze function, i.e., removes dimensions of size 1 from the shape of
+/// a tensor.
+// @param in, the input
+// @param dim, the dimension to be squeezed
+Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
+
+/// the unsqueeze function, i.e., expands a tensor with a length 1 axis
+/// inserted at index axis.
+// @param in, the input
+// @param dim, the dimension to be unsqueezed
+Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
+
 // This is SPU's version of JAX's associative_scan
 // See:
 // https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html
@@ -32,61 +48,82 @@ namespace spu::kernel::hal {
 // for the detailed algorithm explanation
 //
 // fn: an associative binary Function
-// in: a 1-d tensor
+// in: a tensor, scan the last axis
 template <typename Fn>
 spu::Value associative_scan(Fn&& fn, SPUContext* ctx, const Value& in) {
-  SPU_ENFORCE(in.shape().ndim() == 1U, "input should be 1d");
-  const auto numel = in.numel();
-  if (numel < 2) {
+  SPU_ENFORCE(in.shape().ndim() >= 1U, "input should not be scalar");
+  // First reshape to 2D {M, N} tensor, scan each N elements
+  const Shape shape = in.shape();
+  const auto N = shape.back();
+  // in case some empty tensors
+  if (N < 2 || shape.numel() == 0) {
     return in;
   }
+  const auto M = shape.numel() / N;
+  spu::Value in_2d = hal::reshape(ctx, in, {M, N});

   // merge consecutive even/odd index elements
-  auto reduced_elems = fn(ctx, hal::slice(ctx, in, {0}, {numel - 1}, {2}),
-                          hal::slice(ctx, in, {1}, {numel}, {2}));
-  // process half elements recursively and get odd index elements
-  auto odd_elems = associative_scan(fn, ctx, reduced_elems);
+  spu::Value odd_elems;
+  std::vector<spu::Value> odd_vec;
+  std::vector<spu::Value> even_vec;
+  {
+    for (int64_t i = 0; i < M; ++i) {
+      odd_vec.push_back(hal::slice(ctx, in_2d, {i, 0}, {i + 1, N - 1}, {1, 2}));
+      even_vec.push_back(hal::slice(ctx, in_2d, {i, 1}, {i + 1, N}, {1, 2}));
+    }
+    std::vector<spu::Value> reduced_elems_vec;
+    vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(),
+         std::back_inserter(reduced_elems_vec),
+         [&](const spu::Value& odd, const spu::Value& even) {
+           return fn(ctx, odd, even);
+         });
+
+    auto concat_reduced_elems = hal::concatenate(ctx, reduced_elems_vec, 0);
+
+    // process half elements recursively and get odd index elements
+    odd_elems = associative_scan(fn, ctx, concat_reduced_elems);
+  }

   // get even index elements
+  odd_vec.clear();
+  even_vec.clear();
   spu::Value even_elems;
-  if (numel % 2 == 0) {
-    even_elems =
-        fn(ctx, hal::slice(ctx, odd_elems, {0}, {odd_elems.numel() - 1}, {1}),
-           hal::slice(ctx, in, {2}, {numel}, {2}));
-  } else {
-    even_elems = fn(ctx, odd_elems, hal::slice(ctx, in, {2}, {numel}, {2}));
-  }
+  {
+    std::vector<spu::Value> even_elems_vec;
+    for (int64_t i = 0; i < M; ++i) {
+      if (N % 2 == 0) {
+        odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0},
+                                     {i + 1, odd_elems.shape().back() - 1},
+                                     {1, 1}));
+      } else {
+        odd_vec.push_back(hal::slice(ctx, odd_elems, {i, 0},
+                                     {i + 1, odd_elems.shape().back()}, {}));
+      }
+      even_vec.push_back(hal::slice(ctx, in_2d, {i, 2}, {i + 1, N}, {1, 2}));
+    }
+    vmap(odd_vec.begin(), odd_vec.end(), even_vec.begin(), even_vec.end(),
+         std::back_inserter(even_elems_vec),
+         [&](const spu::Value& odd, const spu::Value& even) {
+           return fn(ctx, odd, even);
+         });
+
+    even_elems = hal::concatenate(ctx, even_elems_vec, 0);
+  }
   // concat the 0th element
-  auto final_even_elems =
-      hal::concatenate(ctx, {hal::slice(ctx, in, {0}, {1}), even_elems}, 0);
+  auto final_even_elems = hal::concatenate(
+      ctx, {hal::slice(ctx, in_2d, {0, 0}, {M, 1}), even_elems}, 1);

   // concat even and odd elems interleavely
   auto zero = hal::constant(ctx, 0U, in.dtype(), {1});
-  auto pad_even =
-      hal::pad(ctx, final_even_elems, zero, {0},
-               {final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {1});
-  auto pad_odd =
-      hal::pad(ctx, odd_elems, zero, {1},
-               {final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {1});
+  auto pad_even = hal::pad(
+      ctx, final_even_elems, zero, {0, 0},
+      {0, final_even_elems.numel() == odd_elems.numel() ? 1 : 0}, {0, 1});
+  auto pad_odd = hal::pad(
+      ctx, odd_elems, zero, {0, 1},
+      {0, final_even_elems.numel() == odd_elems.numel() ? 0 : 1}, {0, 1});

   auto ret = hal::_add(ctx, pad_even, pad_odd).setDtype(in.dtype());
-  return ret;
+  return hal::reshape(ctx, ret, in.shape());
 }

-//////////////////////////////////////////////////////////////////////////////
-// Shape utils
-//////////////////////////////////////////////////////////////////////////////
-
-/// the squeeze function, i.e., removes dimensions of size 1 from the shape of a
-/// tensor.
-// @param in, the input
-// @param dim, the dimension to be squeezed
-Value squeeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
-
-/// the unsqueeze function, i.e., expands a tensor with a length 1 axis inserted
-/// at index axis.
-// @param in, the input
-// @param dim, the dimension to be unsqueezed
-Value unsqueeze(SPUContext* ctx, const Value& in, int64_t dim = 0);
-
 } // namespace spu::kernel::hal
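The recursion in the new associative_scan is easier to follow on plain integers. A minimal sketch of the same odd/even scheme, with std::vector<int> and an int operation standing in for spu::Value and fn, and the per-row vmap batching omitted: combine adjacent pairs (one parallel round), scan the half-length sequence recursively to obtain the odd-indexed outputs, then combine those with the remaining inputs to fill the even-indexed outputs.

#include <cstdio>
#include <functional>
#include <vector>

// Inclusive scan via the odd/even (Blelloch-style) recursion used by
// hal::associative_scan, on plain ints with an associative op.
std::vector<int> scan(const std::vector<int>& in,
                      const std::function<int(int, int)>& fn) {
  const size_t n = in.size();
  if (n < 2) return in;

  // 1) merge consecutive even/odd index elements: one parallel round in SPU
  std::vector<int> reduced;
  for (size_t i = 0; i + 1 < n; i += 2) reduced.push_back(fn(in[i], in[i + 1]));

  // 2) recurse on the half-size sequence -> results at odd output indices
  std::vector<int> odd = scan(reduced, fn);

  // 3) even output indices: out[0] = in[0], out[2k] = fn(odd[k-1], in[2k])
  std::vector<int> out(n);
  out[0] = in[0];
  for (size_t i = 1; i < n; ++i) {
    out[i] = (i % 2 == 1) ? odd[i / 2] : fn(odd[i / 2 - 1], in[i]);
  }
  return out;
}

int main() {
  for (int v : scan({1, 2, 3, 4, 5}, std::plus<int>())) std::printf("%d ", v);
  std::printf("\n");  // prints: 1 3 6 10 15
  return 0;
}

This costs O(log n) rounds of batched fn applications, which is why the header now reshapes any input to a {M, N} tensor and scans all M rows in a single vmap per round instead of row by row.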
58 changes: 57 additions & 1 deletion libspu/kernel/hal/utils_test.cc
@@ -24,7 +24,7 @@
 namespace spu::kernel::hal {
 namespace {

-TEST(UtilsTest, associative_scan) {
+TEST(UtilsTest, associative_scan_1d) {
   SPUContext ctx = test::makeSPUContext();

   {
@@ -82,6 +82,62 @@ TEST(UtilsTest, associative_scan) {
   }
 }

+TEST(UtilsTest, associative_scan_2d) {
+  SPUContext ctx = test::makeSPUContext();
+
+  {
+    const xt::xarray<int32_t> x = {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}};
+    const xt::xarray<int32_t> prefix_sum = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}};
+    Value a = test::makeValue(&ctx, x, VIS_SECRET);
+    Value b = associative_scan(hal::add, &ctx, a);
+    auto ret = dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, b));
+    EXPECT_TRUE(prefix_sum == ret) << x << std::endl
+                                   << prefix_sum << std::endl
+                                   << ret;
+  }
+
+  {
+    const xt::xarray<int32_t> x = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}};
+    const xt::xarray<int32_t> prefix_prod = {{1, 2, 6, 24, 120},
+                                             {1, 2, 6, 24, 120}};
+    Value a = test::makeValue(&ctx, x, VIS_SECRET);
+    Value b = associative_scan(hal::mul, &ctx, a);
+    auto ret = dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, b));
+    EXPECT_TRUE(prefix_prod == ret) << x << std::endl
+                                    << prefix_prod << std::endl
+                                    << ret;
+  }
+
+  {
+    const xt::xarray<bool> x = {{true, true, true, false, true, false},
+                                {true, true, true, false, true, false}};
+    const xt::xarray<bool> prefix_and = {
+        {true, true, true, false, false, false},
+        {true, true, true, false, false, false}};
+
+    Value a = test::makeValue(&ctx, x, VIS_SECRET);
+    Value b = associative_scan(hal::bitwise_and, &ctx, a);
+    auto ret = dump_public_as<bool>(&ctx, hal::reveal(&ctx, b));
+    EXPECT_TRUE(prefix_and == ret) << x << std::endl
+                                   << prefix_and << std::endl
+                                   << ret;
+  }
+
+  {
+    const xt::xarray<bool> x = {{true, true, true, false, true, false},
+                                {true, true, true, false, true, false}};
+    const xt::xarray<bool> prefix_or = {{true, true, true, true, true, true},
+                                        {true, true, true, true, true, true}};
+
+    Value a = test::makeValue(&ctx, x, VIS_SECRET);
+    Value b = associative_scan(hal::bitwise_or, &ctx, a);
+    auto ret = dump_public_as<bool>(&ctx, hal::reveal(&ctx, b));
+    EXPECT_TRUE(prefix_or == ret) << x << std::endl
+                                  << prefix_or << std::endl
+                                  << ret;
+  }
+}
+
 TEST(UtilsTest, Squeeze) {
   // GIVEN
   xt::xarray<int32_t> x = xt::ones<int32_t>({2, 1, 2, 1, 2});