16
16
17
17
#include < utility>
18
18
19
+ #include " spdlog/spdlog.h"
20
+
19
21
#include " libspu/core/ndarray_ref.h"
20
22
#include " libspu/core/prelude.h"
21
23
#include " libspu/core/shape.h"
@@ -46,21 +48,27 @@ struct PtBufferView {
46
48
Strides const strides; // Strides in number of elements.
47
49
bool const write_able{false }; // Whether this is a writable buffer
48
50
bool const compacted{false }; // Whether this is a compacted buffer
51
+ bool is_bitset{false }; // Bit data
49
52
50
53
// We have to take a concrete buffer as a view.
51
54
PtBufferView () = delete ;
52
55
53
56
// full constructor
54
57
template <typename Pointer>
55
58
explicit PtBufferView (Pointer ptr, PtType pt_type, Shape in_shape,
56
- Strides in_strides)
59
+ Strides in_strides, bool is_bitset = false )
57
60
: ptr(const_cast <void *>(static_cast <const void *>(ptr))),
58
61
pt_type(pt_type),
59
62
shape(std::move(in_shape)),
60
63
strides(std::move(in_strides)),
61
64
write_able(!std::is_const_v<std::remove_pointer_t <Pointer>>),
62
- compacted(strides == makeCompactStrides(shape)) {
65
+ compacted(strides == makeCompactStrides(shape)),
66
+ is_bitset(is_bitset) {
63
67
static_assert (std::is_pointer_v<Pointer>);
68
+ if (is_bitset) {
69
+ SPU_ENFORCE (pt_type == PT_I1 && compacted,
70
+ " Bitset must be I1 type with compacted data" );
71
+ }
64
72
}
65
73
66
74
// View c++ builtin scalar type as a buffer
@@ -72,7 +80,12 @@ struct PtBufferView {
72
80
strides(),
73
81
compacted(true ) {}
74
82
75
- // FIXME(jint): make it work when T = bool
83
+ explicit PtBufferView (bool const & s)
84
+ : ptr(const_cast <void *>(static_cast <const void *>(&s))),
85
+ pt_type(PT_I1),
86
+ shape(),
87
+ strides() {}
88
+
76
89
template <typename T,
77
90
std::enable_if_t <detail::is_container_like_v<T>, bool > = true >
78
91
/* implicit */ PtBufferView(const T& c) // NOLINT
@@ -104,6 +117,7 @@ struct PtBufferView {
104
117
105
118
template <typename S = uint8_t >
106
119
const S& get (const Index& indices) const {
120
+ SPU_ENFORCE (!is_bitset);
107
121
SPU_ENFORCE (PtTypeToEnum<S>::value == pt_type);
108
122
auto fi = calcFlattenOffset (indices, shape, strides);
109
123
const auto * addr =
@@ -113,6 +127,7 @@ struct PtBufferView {
113
127
114
128
template <typename S = uint8_t >
115
129
const S& get (size_t idx) const {
130
+ SPU_ENFORCE (!is_bitset);
116
131
if (isCompact ()) {
117
132
const auto * addr =
118
133
static_cast <const std::byte*>(ptr) + SizeOf (pt_type) * idx;
@@ -127,13 +142,15 @@ struct PtBufferView {
127
142
void set (const Index& indices, S v) {
128
143
SPU_ENFORCE (write_able);
129
144
SPU_ENFORCE (PtTypeToEnum<S>::value == pt_type);
145
+ SPU_ENFORCE (!is_bitset);
130
146
auto fi = calcFlattenOffset (indices, shape, strides);
131
147
auto * addr = static_cast <std::byte*>(ptr) + SizeOf (pt_type) * fi;
132
148
*reinterpret_cast <S*>(addr) = v;
133
149
}
134
150
135
151
template <typename S = uint8_t >
136
152
void set (size_t idx, S v) {
153
+ SPU_ENFORCE (!is_bitset);
137
154
if (isCompact ()) {
138
155
auto * addr = static_cast <std::byte*>(ptr) + SizeOf (pt_type) * idx;
139
156
*reinterpret_cast <S*>(addr) = v;
@@ -144,6 +161,19 @@ struct PtBufferView {
144
161
}
145
162
146
163
bool isCompact () const { return compacted; }
164
+
165
+ bool isBitSet () const { return is_bitset; }
166
+
167
+ bool getBit (size_t idx) const {
168
+ SPU_ENFORCE (is_bitset);
169
+ auto el_idx = idx / 8 ;
170
+ auto bit_offset = idx % 8 ;
171
+
172
+ uint8_t mask = (1 << bit_offset);
173
+ uint8_t el = static_cast <uint8_t *>(ptr)[el_idx];
174
+
175
+ return (mask & el) != 0 ;
176
+ }
147
177
};
148
178
149
179
std::ostream& operator <<(std::ostream& out, PtBufferView v);
0 commit comments