Skip to content

Commit 5c8c051

Browse files
authored
Repo sync (#618)
1 parent 82a8bd6 commit 5c8c051

18 files changed

+123
-30
lines changed

libspu/core/encoding.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ namespace spu {
6161
FN(PT_U32, DT_U32) \
6262
FN(PT_I64, DT_I64) \
6363
FN(PT_U64, DT_U64) \
64-
FN(PT_BOOL, DT_I1) \
64+
FN(PT_I1, DT_I1) \
6565
FN(PT_F16, DT_F16) \
6666
FN(PT_F32, DT_F32) \
6767
FN(PT_F64, DT_F64)
@@ -75,7 +75,7 @@ namespace spu {
7575
FN(DT_U32, PT_U32) \
7676
FN(DT_I64, PT_I64) \
7777
FN(DT_U64, PT_U64) \
78-
FN(DT_I1, PT_BOOL) \
78+
FN(DT_I1, PT_I1) \
7979
FN(DT_F16, PT_F16) \
8080
FN(DT_F32, PT_F32) \
8181
FN(DT_F64, PT_F64)

libspu/core/encoding_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ TEST(EncodingTypeTest, EncodeDecodeMap) {
3131
EXPECT_EQ(getEncodeType(PT_U32), DT_U32);
3232
EXPECT_EQ(getEncodeType(PT_I64), DT_I64);
3333
EXPECT_EQ(getEncodeType(PT_U64), DT_U64);
34-
EXPECT_EQ(getEncodeType(PT_BOOL), DT_I1);
34+
EXPECT_EQ(getEncodeType(PT_I1), DT_I1);
3535
EXPECT_EQ(getEncodeType(PT_F32), DT_F32);
3636
EXPECT_EQ(getEncodeType(PT_F64), DT_F64);
3737

@@ -43,7 +43,7 @@ TEST(EncodingTypeTest, EncodeDecodeMap) {
4343
EXPECT_EQ(getDecodeType(DT_U32), PT_U32);
4444
EXPECT_EQ(getDecodeType(DT_I64), PT_I64);
4545
EXPECT_EQ(getDecodeType(DT_U64), PT_U64);
46-
EXPECT_EQ(getDecodeType(DT_I1), PT_BOOL);
46+
EXPECT_EQ(getDecodeType(DT_I1), PT_I1);
4747
EXPECT_EQ(getDecodeType(DT_F32), PT_F32);
4848
EXPECT_EQ(getDecodeType(DT_F64), PT_F64);
4949
}

libspu/core/pt_buffer_view.cc

+10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ std::ostream& operator<<(std::ostream& out, PtBufferView v) {
2727
}
2828

2929
NdArrayRef convertToNdArray(PtBufferView bv) {
30+
if (bv.isBitSet()) {
31+
SPU_ENFORCE(bv.isCompact() && bv.pt_type == PT_I1);
32+
auto out = NdArrayRef(I1, bv.shape);
33+
auto* out_ptr = out.data<bool>();
34+
auto num_bits = bv.shape.numel();
35+
for (int64_t idx = 0; idx < num_bits; ++idx) {
36+
out_ptr[idx] = bv.getBit(idx);
37+
}
38+
return out;
39+
}
3040
const auto type = makePtType(bv.pt_type);
3141
auto out = NdArrayRef(type, bv.shape);
3242
return DISPATCH_ALL_PT_TYPES(bv.pt_type, "pt_type", [&]() {

libspu/core/pt_buffer_view.h

+33-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#include <utility>
1818

19+
#include "spdlog/spdlog.h"
20+
1921
#include "libspu/core/ndarray_ref.h"
2022
#include "libspu/core/prelude.h"
2123
#include "libspu/core/shape.h"
@@ -46,21 +48,27 @@ struct PtBufferView {
4648
Strides const strides; // Strides in number of elements.
4749
bool const write_able{false}; // Whether this is a writable buffer
4850
bool const compacted{false}; // Whether this is a compacted buffer
51+
bool is_bitset{false}; // Bit data
4952

5053
// We have to take a concrete buffer as a view.
5154
PtBufferView() = delete;
5255

5356
// full constructor
5457
template <typename Pointer>
5558
explicit PtBufferView(Pointer ptr, PtType pt_type, Shape in_shape,
56-
Strides in_strides)
59+
Strides in_strides, bool is_bitset = false)
5760
: ptr(const_cast<void*>(static_cast<const void*>(ptr))),
5861
pt_type(pt_type),
5962
shape(std::move(in_shape)),
6063
strides(std::move(in_strides)),
6164
write_able(!std::is_const_v<std::remove_pointer_t<Pointer>>),
62-
compacted(strides == makeCompactStrides(shape)) {
65+
compacted(strides == makeCompactStrides(shape)),
66+
is_bitset(is_bitset) {
6367
static_assert(std::is_pointer_v<Pointer>);
68+
if (is_bitset) {
69+
SPU_ENFORCE(pt_type == PT_I1 && compacted,
70+
"Bitset must be I1 type with compacted data");
71+
}
6472
}
6573

6674
// View c++ builtin scalar type as a buffer
@@ -72,7 +80,12 @@ struct PtBufferView {
7280
strides(),
7381
compacted(true) {}
7482

75-
// FIXME(jint): make it work when T = bool
83+
explicit PtBufferView(bool const& s)
84+
: ptr(const_cast<void*>(static_cast<const void*>(&s))),
85+
pt_type(PT_I1),
86+
shape(),
87+
strides() {}
88+
7689
template <typename T,
7790
std::enable_if_t<detail::is_container_like_v<T>, bool> = true>
7891
/* implicit */ PtBufferView(const T& c) // NOLINT
@@ -104,6 +117,7 @@ struct PtBufferView {
104117

105118
template <typename S = uint8_t>
106119
const S& get(const Index& indices) const {
120+
SPU_ENFORCE(!is_bitset);
107121
SPU_ENFORCE(PtTypeToEnum<S>::value == pt_type);
108122
auto fi = calcFlattenOffset(indices, shape, strides);
109123
const auto* addr =
@@ -113,6 +127,7 @@ struct PtBufferView {
113127

114128
template <typename S = uint8_t>
115129
const S& get(size_t idx) const {
130+
SPU_ENFORCE(!is_bitset);
116131
if (isCompact()) {
117132
const auto* addr =
118133
static_cast<const std::byte*>(ptr) + SizeOf(pt_type) * idx;
@@ -127,13 +142,15 @@ struct PtBufferView {
127142
void set(const Index& indices, S v) {
128143
SPU_ENFORCE(write_able);
129144
SPU_ENFORCE(PtTypeToEnum<S>::value == pt_type);
145+
SPU_ENFORCE(!is_bitset);
130146
auto fi = calcFlattenOffset(indices, shape, strides);
131147
auto* addr = static_cast<std::byte*>(ptr) + SizeOf(pt_type) * fi;
132148
*reinterpret_cast<S*>(addr) = v;
133149
}
134150

135151
template <typename S = uint8_t>
136152
void set(size_t idx, S v) {
153+
SPU_ENFORCE(!is_bitset);
137154
if (isCompact()) {
138155
auto* addr = static_cast<std::byte*>(ptr) + SizeOf(pt_type) * idx;
139156
*reinterpret_cast<S*>(addr) = v;
@@ -144,6 +161,19 @@ struct PtBufferView {
144161
}
145162

146163
bool isCompact() const { return compacted; }
164+
165+
bool isBitSet() const { return is_bitset; }
166+
167+
bool getBit(size_t idx) const {
168+
SPU_ENFORCE(is_bitset);
169+
auto el_idx = idx / 8;
170+
auto bit_offset = idx % 8;
171+
172+
uint8_t mask = (1 << bit_offset);
173+
uint8_t el = static_cast<uint8_t*>(ptr)[el_idx];
174+
175+
return (mask & el) != 0;
176+
}
147177
};
148178

149179
std::ostream& operator<<(std::ostream& out, PtBufferView v);

libspu/core/pt_buffer_view_test.cc

+35
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "libspu/core/pt_buffer_view.h"
1616

1717
#include <array>
18+
#include <bitset>
1819

1920
#include "gmock/gmock.h"
2021
#include "gtest/gtest.h"
@@ -39,6 +40,12 @@ TEST(PtBufferView, Scalar) {
3940
EXPECT_TRUE(bv_f32.shape.isScalar());
4041
EXPECT_EQ(bv_f32.shape.numel(), 1);
4142
EXPECT_TRUE(bv_f32.strides.empty());
43+
44+
PtBufferView bv_i1(true);
45+
EXPECT_EQ(bv_i1.pt_type, PT_I1);
46+
EXPECT_TRUE(bv_i1.shape.isScalar());
47+
EXPECT_EQ(bv_i1.shape.numel(), 1);
48+
EXPECT_TRUE(bv_i1.strides.empty());
4249
}
4350

4451
TEST(PtBufferView, Vector) {
@@ -72,4 +79,32 @@ TEST(PtBufferView, ConvertToNdArray) {
7279
EXPECT_FLOAT_EQ((arr.at<float>(2)), 3.0);
7380
}
7481

82+
TEST(PtBufferView, BoolContainer) {
83+
std::array<bool, 3> test = {true, false, true};
84+
PtBufferView bv(test);
85+
86+
EXPECT_EQ(bv.get<bool>(0), true);
87+
EXPECT_EQ(bv.get<bool>(1), false);
88+
EXPECT_EQ(bv.get<bool>(2), true);
89+
}
90+
91+
TEST(PtBufferView, BitSet) {
92+
int16_t test = 2024;
93+
PtBufferView bv(&test, PT_I1, {8 * sizeof(int16_t)}, {1}, true);
94+
95+
EXPECT_EQ(bv.shape.numel(), 16);
96+
97+
std::bitset<16> expected(2024);
98+
for (size_t idx = 0; idx < 16; ++idx) {
99+
EXPECT_EQ(bv.getBit(idx), expected[idx]);
100+
}
101+
102+
auto arr = convertToNdArray(bv);
103+
EXPECT_EQ(arr.shape().numel(), 16);
104+
105+
for (size_t idx = 0; idx < 16; ++idx) {
106+
EXPECT_EQ(arr.at<bool>(idx), expected[idx]) << idx << "\n";
107+
}
108+
}
109+
75110
} // namespace spu

libspu/core/type.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Type F32 = makePtType(PT_F32);
6060
Type F64 = makePtType(PT_F64);
6161
Type I128 = makePtType(PT_I128);
6262
Type U128 = makePtType(PT_U128);
63-
Type BOOL = makePtType(PT_BOOL);
63+
Type I1 = makePtType(PT_I1);
6464
Type CF32 = makePtType(PT_CF32);
6565
Type CF64 = makePtType(PT_CF64);
6666

libspu/core/type.h

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ extern Type U64;
338338
extern Type F16;
339339
extern Type F32;
340340
extern Type F64;
341+
extern Type I1;
341342
extern Type I128;
342343
extern Type U128;
343344
extern Type CF32;

libspu/core/type_util.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype);
8585
FN(PT_U64, uint64_t, U64) \
8686
FN(PT_I128, int128_t, I128) \
8787
FN(PT_U128, uint128_t, U128) \
88-
FN(PT_BOOL, bool, I1)
88+
FN(PT_I1, bool, I1)
8989

9090
#define FOREACH_COMPLEX_PT_TYPES(FN) \
9191
FN(PT_CF32, std::complex<float>, CF32) \
@@ -131,7 +131,7 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype);
131131
#define DISPATCH_INT_PT_TYPES(PT_TYPE, NAME, ...) \
132132
[&] { \
133133
switch (PT_TYPE) { \
134-
__CASE_PT_TYPE(spu::PT_BOOL, NAME, __VA_ARGS__) \
134+
__CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \
135135
__CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
136136
__CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
137137
__CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \
@@ -148,7 +148,7 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype);
148148
#define DISPATCH_ALL_PT_TYPES(PT_TYPE, NAME, ...) \
149149
[&] { \
150150
switch (PT_TYPE) { \
151-
__CASE_PT_TYPE(spu::PT_BOOL, NAME, __VA_ARGS__) \
151+
__CASE_PT_TYPE(spu::PT_I1, NAME, __VA_ARGS__) \
152152
__CASE_PT_TYPE(spu::PT_I8, NAME, __VA_ARGS__) \
153153
__CASE_PT_TYPE(spu::PT_U8, NAME, __VA_ARGS__) \
154154
__CASE_PT_TYPE(spu::PT_I16, NAME, __VA_ARGS__) \

libspu/device/io.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ IoClient::IoClient(size_t world_size, const RuntimeConfig &config)
3333

3434
size_t IoClient::getShareSize(const PtBufferView &bv, Visibility vtype,
3535
int owner_rank) {
36-
if (bv.pt_type == PT_BOOL && vtype == VIS_SECRET &&
36+
if (bv.pt_type == PT_I1 && vtype == VIS_SECRET &&
3737
base_io_->hasBitSecretSupport()) {
3838
return base_io_->getBitSecretShareSize(bv.shape.numel());
3939
} else {
@@ -46,7 +46,7 @@ std::vector<spu::Value> IoClient::makeShares(const PtBufferView &bv,
4646
const size_t fxp_bits = config_.fxp_fraction_bits();
4747
SPU_ENFORCE(fxp_bits != 0, "fxp should never be zero, please check default");
4848

49-
if (bv.pt_type == PT_BOOL && vtype == VIS_SECRET &&
49+
if (bv.pt_type == PT_I1 && vtype == VIS_SECRET &&
5050
base_io_->hasBitSecretSupport()) {
5151
auto shares = base_io_->makeBitSecret(bv);
5252
SPU_ENFORCE(shares.size() == world_size_);

libspu/device/pphlo/pphlo_executor.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ std::pair<spu::PtType, bool> getPtTypeFromMlirType(mlir::Type mlir_ty) {
6767
}
6868
} else if (auto it = express_type.dyn_cast<mlir::IntegerType>()) {
6969
if (it.getWidth() == 1) {
70-
return {spu::PT_BOOL, false};
70+
return {spu::PT_I1, false};
7171
}
7272
// In mlir, isSigned is for si[1-9][0-9]* type, isUnsigned is for
7373
// ui[1-9][0-9]*, i[1-9][0-9]* is signless IntegerType... So here, we only

libspu/kernel/hal/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ spu_cc_test(
255255
deps = [
256256
":type_cast",
257257
"//libspu/kernel:test_util",
258+
"//libspu/mpc/utils:simulate",
258259
],
259260
)
260261

libspu/kernel/hal/constants_test.cc

+10-10
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,16 @@ TEST(ConstantsTest, TensorBroadcast) {
6969
TEST(ConstantsTest, Initializer) {
7070
SPUContext ctx = test::makeSPUContext();
7171

72-
// FIXME: the dtype is determined by the C++ literal type.
73-
// EXPECT_EQ(constant(&ctx, 0, DT_I1).dtype(), DT_I1); // FIXME
74-
// EXPECT_EQ(constant(&ctx, 0, DT_I8).dtype(), DT_I8); // FIXME
75-
// EXPECT_EQ(constant(&ctx, 0, DT_U8).dtype(), DT_U8); // FIXME
76-
// EXPECT_EQ(constant(&ctx, 0, DT_I16).dtype(), DT_I16); // FIXME
77-
// EXPECT_EQ(constant(&ctx, 0, DT_U16).dtype(), DT_U16); // FIXME
78-
EXPECT_EQ(constant(&ctx, 0, DT_I32).dtype(), DT_I32);
79-
// EXPECT_EQ(constant(&ctx, 0, DT_U32).dtype(), DT_U32); // FIXME
80-
// EXPECT_EQ(constant(&ctx, 0, DT_I64).dtype(), DT_I64); // FIXME
81-
// EXPECT_EQ(constant(&ctx, 0, DT_U64).dtype(), DT_U64); // FIXME
72+
EXPECT_EQ(constant(&ctx, true, DT_I1).dtype(), DT_I1);
73+
EXPECT_EQ(constant(&ctx, static_cast<int8_t>(0), DT_I8).dtype(), DT_I8);
74+
EXPECT_EQ(constant(&ctx, static_cast<uint8_t>(0), DT_U8).dtype(), DT_U8);
75+
EXPECT_EQ(constant(&ctx, static_cast<int16_t>(0), DT_I16).dtype(), DT_I16);
76+
EXPECT_EQ(constant(&ctx, static_cast<uint16_t>(0), DT_U16).dtype(), DT_U16);
77+
EXPECT_EQ(constant(&ctx, static_cast<int32_t>(0), DT_I32).dtype(), DT_I32);
78+
EXPECT_EQ(constant(&ctx, static_cast<uint32_t>(0), DT_U32).dtype(), DT_U32);
79+
EXPECT_EQ(constant(&ctx, static_cast<int64_t>(0), DT_I64).dtype(), DT_I64);
80+
EXPECT_EQ(constant(&ctx, static_cast<uint64_t>(0), DT_U64).dtype(), DT_U64);
81+
8282
EXPECT_EQ(constant(&ctx, 0.0F, DT_F32).dtype(), DT_F32);
8383
EXPECT_EQ(constant(&ctx, 0.0, DT_F64).dtype(), DT_F64);
8484
}

libspu/kernel/hal/type_cast_test.cc

+13
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "libspu/kernel/hal/constants.h"
2323
#include "libspu/kernel/test_util.h"
24+
#include "libspu/mpc/utils/simulate.h"
2425

2526
namespace spu::kernel::hal {
2627
namespace {
@@ -86,5 +87,17 @@ TEST(TypeCastTest, fxp2int) {
8687
// TODO: cast to other int than DT_I32
8788
}
8889

90+
TEST(TypeCastTest, boolean) {
91+
mpc::utils::simulate(
92+
3, [&](const std::shared_ptr<yacl::link::Context> &lctx) {
93+
SPUContext sctx = test::makeSPUContext(SEMI2K, FM64, lctx);
94+
Value pa = constant(&sctx, true, DT_I1);
95+
Value sa = seal(&sctx, pa);
96+
EXPECT_EQ(sa.dtype(), DT_I1);
97+
EXPECT_TRUE(sa.storage_type().isa<BShare>());
98+
EXPECT_EQ(sa.storage_type().as<BShare>()->nbits(), 1);
99+
});
100+
}
101+
89102
} // namespace
90103
} // namespace spu::kernel::hal

libspu/kernel/hlo/indexing.cc

-1
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ spu::Value SecretDynamicSliceImpl(SPUContext *ctx, const spu::Value &operand,
253253

254254
if (slice_size[0] >= 1) {
255255
auto pad_value = hal::seal(ctx, hal::constant(ctx, false, mask.dtype()));
256-
pad_value = hal::_cast_type(ctx, pad_value, mask.storage_type());
257256
mask = hal::pad(ctx, mask, pad_value, {slice_size[0]}, {0}, {0});
258257
// FIXME(juhou): we should avoid setting the BShr here
259258
// However mask.storage_type().as<BShare>->nbits() is not 1 after the

libspu/mpc/aby3/io.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ size_t Aby3Io::getBitSecretShareSize(size_t numel) const {
9191

9292
std::vector<NdArrayRef> Aby3Io::makeBitSecret(const PtBufferView& in) const {
9393
PtType in_pt_type = in.pt_type;
94-
SPU_ENFORCE(in_pt_type == PT_BOOL);
94+
SPU_ENFORCE(in_pt_type == PT_I1);
9595

96-
if (in_pt_type == PT_BOOL) {
96+
if (in_pt_type == PT_I1) {
9797
// we assume boolean is stored with byte array.
9898
in_pt_type = PT_U8;
9999
}

libspu/mpc/api.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ Value p2s(SPUContext* ctx, const Value& x) {
144144

145145
TRY_DISPATCH(ctx, x);
146146

147-
return p2a(ctx, x);
147+
if (x.dtype() == DT_I1) {
148+
return p2b(ctx, x);
149+
} else {
150+
return p2a(ctx, x);
151+
}
148152
}
149153

150154
Value p2v(SPUContext* ctx, const Value& x, size_t owner) {

libspu/spu.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ enum PtType {
7979
PT_U64 = 8; // uint64_t
8080
PT_I128 = 9; // int128_t
8181
PT_U128 = 10; // uint128_t
82-
PT_BOOL = 11; // bool
82+
PT_I1 = 11; // bool
8383
//
8484
PT_F16 = 30; // half
8585
PT_F32 = 31; // float

spu/libspu.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class RuntimeWrapper {
402402
FN("float16", PT_F16) \
403403
FN("float32", PT_F32) \
404404
FN("float64", PT_F64) \
405-
FN("bool", PT_BOOL) \
405+
FN("bool", PT_I1) \
406406
FN("complex64", PT_CF32) \
407407
FN("complex128", PT_CF64)
408408

0 commit comments

Comments
 (0)