
Commit 82a8bd6

Repo sync (#616)
1 parent 7c7f863 commit 82a8bd6

File tree

6 files changed: +108 -141 lines changed


libspu/kernel/hal/fxp_cleartext.cc

+6
@@ -145,4 +145,10 @@ Value f_erf_p(SPUContext* ctx, const Value& in) {
   return applyFloatingPointFn(ctx, in, [](float x) { return std::erf(x); });
 }
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y) {
+  SPU_TRACE_HAL_DISP(ctx, x, y);
+  return applyFloatingPointFn(ctx, x, y,
+                              [](float a, float b) { return std::pow(a, b); });
+}
+
 }  // namespace spu::kernel::hal
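
Note: the new f_pow_p follows the same pattern as the other cleartext kernels, lifting a plain C++ lambda element-wise over public floating-point values via applyFloatingPointFn. As a rough illustration of the semantics it delegates to, here is a minimal sketch using plain std::pow; the pow_elementwise helper is hypothetical and only stands in for the SPU plumbing.

#include <cmath>
#include <vector>

// Hedged sketch: the element-wise behaviour f_pow_p delegates to.
// `pow_elementwise` is an illustrative stand-in, not part of the SPU codebase.
std::vector<float> pow_elementwise(const std::vector<float>& x,
                                   const std::vector<float>& y) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::pow(x[i], y[i]);  // same lambda as in f_pow_p
  }
  return out;
}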

libspu/kernel/hal/fxp_cleartext.h

+2
@@ -40,4 +40,6 @@ Value f_cosine_p(SPUContext* ctx, const Value& in);
 
 Value f_erf_p(SPUContext* ctx, const Value& in);
 
+Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y);
+
 }  // namespace spu::kernel::hal

libspu/kernel/hal/polymorphic.cc

+25-3
@@ -19,6 +19,7 @@
 #include "libspu/core/trace.h"
 #include "libspu/kernel/hal/fxp_approx.h"
 #include "libspu/kernel/hal/fxp_base.h"
+#include "libspu/kernel/hal/fxp_cleartext.h"
 #include "libspu/kernel/hal/integer.h"
 #include "libspu/kernel/hal/ring.h"  // for fast fxp x int
 #include "libspu/kernel/hal/type_cast.h"
@@ -329,15 +330,36 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
 Value power(SPUContext* ctx, const Value& x, const Value& y) {
   SPU_TRACE_HAL_DISP(ctx, x, y);
 
-  if (x.isInt() && y.isInt()) {
+  if (x.isInt() || y.isInt()) {
     auto x_f = dtype_cast(ctx, x, DT_F32);
     auto y_f = dtype_cast(ctx, y, DT_F32);
     auto ret = power(ctx, x_f, y_f);
-    return dtype_cast(ctx, ret, x.dtype());
+    return ret;
+  }
+  if (x.isPublic() && y.isPublic()) {
+    return f_pow_p(ctx, x, y);
   }
 
+  auto msb = _msb(ctx, x);
+  auto msb_a = _prefer_a(ctx, msb);
+  auto x_abs = _mux(ctx, msb_a, _negate(ctx, x), x).setDtype(x.dtype());
+
+  // If x=0 is public, then log(x) gives -inf and the wrong output would be
+  // produced after multiplying by y. So we force x to be secret; computing
+  // log(x) then yields a small negative number, so exp(y*log(x)) = 0.
+  auto x_s = x.isPublic() ? hal::seal(ctx, x_abs) : x_abs;
   // x^y = e^(y*ln(x))
-  return exp(ctx, mul(ctx, y, log(ctx, x)));
+  // The precision is highly dependent on the precision of exp and log, so we
+  // choose the most precise methods here.
+  auto val = detail::exp_pade(ctx, mul(ctx, y, detail::log_minmax(ctx, x_s)));
+
+  // The final sign is decided by both the sign of x and the parity of y,
+  // e.g. when x<0 and y is odd, (-2)^3 = -8.
+  auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()),
+                  _constant(ctx, 1, y.shape()));
+  auto sign = _and(ctx, msb, odd);
+
+  return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype());
 }
 
 Value idiv(SPUContext* ctx, const Value& x, const Value& y) {
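
Note: for secret operands the rewritten power() evaluates |x|^y as exp_pade(y * log_minmax(|x|)) and then restores the sign from the MSB of x and the parity of y (the right shift by getFxpBits() extracts y's integer part from the fixed-point encoding). A cleartext sketch of that sign rule follows, assuming plain floats and an integer exponent; the names are illustrative rather than the SPU API, and the x = 0 case that the kernel handles by sealing a public x is ignored here.

#include <cmath>

// Hedged cleartext sketch of the sign logic used by the new power():
// compute |x|^y via exp/log, then negate iff x < 0 and y is odd.
// Assumes x != 0 (the real kernel treats that case via the sealing trick).
float signed_pow(float x, int y) {
  bool msb = (x < 0);                                  // sign bit of x
  bool odd = (y & 1) != 0;                             // parity of y
  float val = std::exp(y * std::log(std::fabs(x)));    // |x|^y
  return (msb && odd) ? -val : val;                    // e.g. (-2)^3 = -8
}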

libspu/kernel/hal/polymorphic_test.cc

+31-15
@@ -406,26 +406,42 @@ TYPED_TEST(MathTest, Pow) {
   using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
   using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
   using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
-  using RES_DT = typename std::tuple_element<4, TypeParam>::type;
+  // using RES_DT = typename std::tuple_element<4, TypeParam>::type;
 
-  if constexpr (!std::is_same_v<LHS_DT, RHS_DT> ||
-                !std::is_same_v<LHS_VT, RHS_VT> || std::is_integral_v<RHS_DT>) {
-    return;
+  // GIVEN
+  xt::xarray<LHS_DT> x;
+  xt::xarray<RHS_DT> y;
+  {
+    // random test
+    x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
+    y = test::xt_random<RHS_DT>({5, 6}, -2, 2);
+
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
+
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
   }
 
-  // GIVEN
-  const xt::xarray<LHS_DT> x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
-  const xt::xarray<RHS_DT> y = test::xt_random<RHS_DT>({5, 6}, 0, 2);
+  {
+    // some fixed corner case
+    x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
+    y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
 
-  // WHAT
-  auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);
+    // WHAT
+    auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);
 
-  // THEN
-  auto expected = xt::pow(x, y);
-  EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
-                                                    << y << std::endl
-                                                    << expected << std::endl
-                                                    << z << std::endl;
+    // THEN
+    auto expected = xt::pow(x, y);
+    EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
+                                                      << y << std::endl
+                                                      << expected << std::endl
+                                                      << z << std::endl;
+  }
 }
 
 using MathUnaryTestTypes = ::testing::Types<
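
Note: the fixed corner cases exercise negative bases, zero bases, and negative exponents. The reference values come from xt::pow, which follows std::pow semantics (e.g. 0^0 = 1, (-3)^(-3) ≈ -0.037, (-3)^3 = -27). A quick standalone check of those expectations, assuming nothing beyond the C++ standard library:

#include <cmath>
#include <cstdio>

int main() {
  // Same value pairs as the fixed corner case in the test.
  const float xs[] = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
  const float ys[] = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};
  for (int i = 0; i < 10; ++i) {
    // e.g. (-1)^1 = -1, (-3)^-3 ~= -0.037, (-3)^3 = -27, 0^0 = 1, 5^5 = 3125
    std::printf("pow(%g, %g) = %g\n", xs[i], ys[i], std::pow(xs[i], ys[i]));
  }
  return 0;
}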

libspu/mpc/cheetah/state.cc

+41-122
@@ -28,7 +28,7 @@ namespace spu::mpc::cheetah {
 namespace {
 // Return num_workers for the given size of jobs
 size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
-  constexpr size_t kMinWorkSize = 5000;
+  constexpr size_t kMinWorkSize = 2048;
   if (njobs == 0) {
     return 0;
   }
@@ -139,86 +139,44 @@ std::array<NdArrayRef, 3> CheetahMulState::TakeCachedBeaver(FieldType field,
 
 NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
                                OTUnaryFunc func) {
-  Shape shape = x.shape();
+  const Shape& shape = x.shape();
+  SPU_ENFORCE(shape.numel() > 0);
   // (lazy) init OT
   int64_t numel = x.numel();
   int64_t nworker = InitOTState(ctx, numel);
   int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);
 
-  int64_t slicing_dim = -1;
-  int64_t slice_numel = 1;
-  for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
-    slice_numel *= shape[dim];
-    if (slice_numel > workload) {
-      slice_numel /= shape[dim];
-      slicing_dim = dim;
-      break;
-    }
-  }
-
-  // get the slice num in the left outer dimensions
-  int64_t num_slice = 1;
-  for (int64_t dim = 0; dim < slicing_dim; dim++) {
-    num_slice *= shape[dim];
-  }
-
-  int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
-  if (slice_stride == 1) {
-    return func(x, ctx->getState<CheetahOTState>()->get(0));
-  }
-
-  int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
-                          ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);
-
-  // initialize slice indices
-  Index start_indices(shape.size());
-  Index end_indices(shape.begin(), shape.end());
-  end_indices[slicing_dim] = slice_stride;
-  for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-    end_indices[dim] = 1;
+  if (shape.ndim() != 1) {
+    // TiledDispatchOTFunc over flatten input
+    return TiledDispatchOTFunc(ctx, x.reshape({numel}), func)
+        .reshape(x.shape());
   }
 
-  SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
-  nworker = num_slice * num_slice_dim;
-
   std::vector<NdArrayRef> outs(nworker);
   std::vector<std::future<void>> futures;
 
-  Index sidx = start_indices;
-  Index eidx = end_indices;
-  for (int64_t wi = 0; wi < nworker; ++wi) {
-    auto slice_input = x.slice(sidx, eidx, {});
+  int64_t slice_end = 0;
+  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
+    int64_t slice_bgn = wi * workload;
+    slice_end = std::min(numel, slice_bgn + workload);
+    auto slice_input = x.slice({slice_bgn}, {slice_end}, {});
     futures.emplace_back(std::async(
         [&](int64_t idx, const NdArrayRef& input) {
          auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
          outs[idx] = func(input, ot_instance);
        },
        wi, slice_input));
-
-    // update indices
-    if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
-      // carray out
-      sidx[slicing_dim] = 0;
-      eidx[slicing_dim] = slice_stride;
-      for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-        sidx[dim] = (sidx[dim] + 1) % shape[dim];
-        eidx[dim] = eidx[dim] % shape[dim] + 1;
-        if (eidx[dim] != 1) {
-          break;
-        }
-      }
-    } else {
-      sidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
-    }
   }
 
+  auto slice_input = x.slice({slice_end}, {numel}, {1});
+  auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
+  outs[nworker - 1] = func(slice_input, ot_instance);
+
   for (auto&& f : futures) {
     f.get();
   }
 
-  NdArrayRef out(x.eltype(), x.shape());
+  NdArrayRef out(outs[0].eltype(), x.shape());
   int64_t offset = 0;
 
   for (auto& out_slice : outs) {
@@ -232,89 +190,50 @@ NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
 
 NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
                                const NdArrayRef& y, OTBinaryFunc func) {
-  Shape shape = x.shape();
-  SPU_ENFORCE_EQ(x.shape(), y.shape());
+  const Shape& shape = x.shape();
+  SPU_ENFORCE(shape.numel() > 0);
+  SPU_ENFORCE_EQ(shape, y.shape());
   // (lazy) init OT
   int64_t numel = x.numel();
   int64_t nworker = InitOTState(ctx, numel);
   int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);
 
-  int64_t slicing_dim = -1;
-  int64_t slice_numel = 1;
-  for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
-    slice_numel *= shape[dim];
-    if (slice_numel > workload) {
-      slice_numel /= shape[dim];
-      slicing_dim = dim;
-      break;
-    }
+  if (shape.ndim() != 1) {
+    // TiledDispatchOTFunc over flatten input
+    return TiledDispatchOTFunc(ctx, x.reshape({numel}), y.reshape({numel}),
+                               func)
+        .reshape(x.shape());
   }
 
-  // get the slice num in the left outer dimensions
-  int64_t num_slice = 1;
-  for (int64_t dim = 0; dim < slicing_dim; dim++) {
-    num_slice *= shape[dim];
-  }
-
-  int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
-  if (slice_stride == 1) {
-    return func(x, y, ctx->getState<CheetahOTState>()->get(0));
-  }
-
-  int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
-                          ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);
-
-  // initialize slice indices
-  Index start_indices(shape.size());
-  Index end_indices(shape.begin(), shape.end());
-  end_indices[slicing_dim] = slice_stride;
-  for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-    end_indices[dim] = 1;
-  }
-
-  SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
-  nworker = num_slice * num_slice_dim;
-
   std::vector<NdArrayRef> outs(nworker);
   std::vector<std::future<void>> futures;
 
-  Index sidx = start_indices;
-  Index eidx = end_indices;
-  for (int64_t wi = 0; wi < nworker; ++wi) {
-    auto x_slice = x.slice(sidx, eidx, {});
-    auto y_slice = y.slice(sidx, eidx, {});
-
+  int64_t slice_end = 0;
+  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
+    int64_t slice_bgn = wi * workload;
+    slice_end = std::min(numel, slice_bgn + workload);
+    auto x_slice = x.slice({slice_bgn}, {slice_end}, {1});
+    auto y_slice = y.slice({slice_bgn}, {slice_end}, {1});
     futures.emplace_back(std::async(
-        [&](int64_t idx, const NdArrayRef& input0, const NdArrayRef& input1) {
+        [&](int64_t idx, const NdArrayRef& inp0, const NdArrayRef& inp1) {
          auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
-          outs[idx] = func(input0, input1, ot_instance);
+          outs[idx] = func(inp0, inp1, ot_instance);
         },
         wi, x_slice, y_slice));
-
-    // update indices
-    if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
-      // carray out
-      sidx[slicing_dim] = 0;
-      eidx[slicing_dim] = slice_stride;
-      for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
-        sidx[dim] = (sidx[dim] + 1) % shape[dim];
-        eidx[dim] = eidx[dim] % shape[dim] + 1;
-        if (eidx[dim] != 1) {
-          break;
-        }
-      }
-    } else {
-      sidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] += slice_stride;
-      eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
-    }
   }
+
+  auto x_slice = x.slice({slice_end}, {numel}, {});
+  auto y_slice = y.slice({slice_end}, {numel}, {});
+  auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
+  outs[nworker - 1] = func(x_slice, y_slice, ot_instance);
+
   for (auto&& f : futures) {
     f.get();
   }
 
-  NdArrayRef out(x.eltype(), x.shape());
+  NdArrayRef out(outs[0].eltype(), x.shape());
   int64_t offset = 0;
+
   for (auto& out_slice : outs) {
     std::memcpy(out.data<std::byte>() + offset, out_slice.data(),
                 out_slice.numel() * out.elsize());
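
Note: the rewritten TiledDispatchOTFunc drops the multi-dimensional slice bookkeeping. Any non-1-D input is first flattened, then each worker gets a contiguous range of roughly CeilDiv(numel, nworker) elements, with the tail range evaluated on the calling thread while the std::async jobs run. A minimal sketch of just that 1-D partitioning, under the assumption nworker >= 1; the MakeChunks helper is illustrative and not part of the SPU codebase.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Hedged sketch of the 1-D chunking used by the new TiledDispatchOTFunc.
// Returns [begin, end) ranges; the last range is what the caller thread runs.
std::vector<std::pair<int64_t, int64_t>> MakeChunks(int64_t numel,
                                                    int64_t nworker) {
  std::vector<std::pair<int64_t, int64_t>> chunks;
  if (nworker <= 0 || numel <= 0) {
    return chunks;
  }
  int64_t workload = (numel + nworker - 1) / nworker;  // CeilDiv
  int64_t slice_end = 0;
  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
    int64_t slice_bgn = wi * workload;
    slice_end = std::min(numel, slice_bgn + workload);
    chunks.emplace_back(slice_bgn, slice_end);
  }
  chunks.emplace_back(slice_end, numel);  // tail chunk, run synchronously
  return chunks;
}

Together with the lower kMinWorkSize (2048) and the doubled kMaxOTParallel cap (48 in state.h), medium-sized tensors can be spread over more OT instances; the exact worker-count formula remains in InitOTState and is not reproduced here.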

libspu/mpc/cheetah/state.h

+3-1
@@ -25,6 +25,8 @@
 #include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
 #include "libspu/mpc/cheetah/rlwe/utils.h"
 
+#include "libspu/spu.pb.h"
+
 namespace spu::mpc::cheetah {
 
 using OTUnaryFunc = std::function<NdArrayRef(
@@ -101,7 +103,7 @@ class CheetahOTState : public State {
 
   mutable std::mutex lock_;
 
-  static constexpr size_t kMaxOTParallel = 24;
+  static constexpr size_t kMaxOTParallel = 48;
 
   size_t maximum_instances_ = 0;
   std::vector<ProtPtr> basic_ot_prot_;
