@@ -28,7 +28,7 @@ namespace spu::mpc::cheetah {
28
28
namespace {
29
29
// Return num_workers for the given size of jobs
30
30
size_t InitOTState (KernelEvalContext* ctx, size_t njobs) {
31
- constexpr size_t kMinWorkSize = 5000 ;
31
+ constexpr size_t kMinWorkSize = 2048 ;
32
32
if (njobs == 0 ) {
33
33
return 0 ;
34
34
}
@@ -139,86 +139,44 @@ std::array<NdArrayRef, 3> CheetahMulState::TakeCachedBeaver(FieldType field,
139
139
140
140
NdArrayRef TiledDispatchOTFunc (KernelEvalContext* ctx, const NdArrayRef& x,
141
141
OTUnaryFunc func) {
142
- Shape shape = x.shape ();
142
+ const Shape& shape = x.shape ();
143
+ SPU_ENFORCE (shape.numel () > 0 );
143
144
// (lazy) init OT
144
145
int64_t numel = x.numel ();
145
146
int64_t nworker = InitOTState (ctx, numel);
146
147
int64_t workload = nworker == 0 ? 0 : CeilDiv (numel, nworker);
147
148
148
- int64_t slicing_dim = -1 ;
149
- int64_t slice_numel = 1 ;
150
- for (int64_t dim = shape.size () - 1 ; dim >= 0 ; dim--) {
151
- slice_numel *= shape[dim];
152
- if (slice_numel > workload) {
153
- slice_numel /= shape[dim];
154
- slicing_dim = dim;
155
- break ;
156
- }
157
- }
158
-
159
- // get the slice num in the left outer dimensions
160
- int64_t num_slice = 1 ;
161
- for (int64_t dim = 0 ; dim < slicing_dim; dim++) {
162
- num_slice *= shape[dim];
163
- }
164
-
165
- int64_t slice_stride = (workload + slice_numel - 1 ) / slice_numel;
166
- if (slice_stride == 1 ) {
167
- return func (x, ctx->getState <CheetahOTState>()->get (0 ));
168
- }
169
-
170
- int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
171
- ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0 );
172
-
173
- // initialize slice indices
174
- Index start_indices (shape.size ());
175
- Index end_indices (shape.begin (), shape.end ());
176
- end_indices[slicing_dim] = slice_stride;
177
- for (int64_t dim = slicing_dim - 1 ; dim >= 0 ; dim--) {
178
- end_indices[dim] = 1 ;
149
+ if (shape.ndim () != 1 ) {
150
+ // TiledDispatchOTFunc over flatten input
151
+ return TiledDispatchOTFunc (ctx, x.reshape ({numel}), func)
152
+ .reshape (x.shape ());
179
153
}
180
154
181
- SPU_ENFORCE_LE (num_slice * num_slice_dim, nworker);
182
- nworker = num_slice * num_slice_dim;
183
-
184
155
std::vector<NdArrayRef> outs (nworker);
185
156
std::vector<std::future<void >> futures;
186
157
187
- Index sidx = start_indices;
188
- Index eidx = end_indices;
189
- for (int64_t wi = 0 ; wi < nworker; ++wi) {
190
- auto slice_input = x.slice (sidx, eidx, {});
158
+ int64_t slice_end = 0 ;
159
+ for (int64_t wi = 0 ; wi + 1 < nworker; ++wi) {
160
+ int64_t slice_bgn = wi * workload;
161
+ slice_end = std::min (numel, slice_bgn + workload);
162
+ auto slice_input = x.slice ({slice_bgn}, {slice_end}, {});
191
163
futures.emplace_back (std::async (
192
164
[&](int64_t idx, const NdArrayRef& input) {
193
165
auto ot_instance = ctx->getState <CheetahOTState>()->get (idx);
194
166
outs[idx] = func (input, ot_instance);
195
167
},
196
168
wi, slice_input));
197
-
198
- // update indices
199
- if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
200
- // carray out
201
- sidx[slicing_dim] = 0 ;
202
- eidx[slicing_dim] = slice_stride;
203
- for (int64_t dim = slicing_dim - 1 ; dim >= 0 ; dim--) {
204
- sidx[dim] = (sidx[dim] + 1 ) % shape[dim];
205
- eidx[dim] = eidx[dim] % shape[dim] + 1 ;
206
- if (eidx[dim] != 1 ) {
207
- break ;
208
- }
209
- }
210
- } else {
211
- sidx[slicing_dim] += slice_stride;
212
- eidx[slicing_dim] += slice_stride;
213
- eidx[slicing_dim] = std::min (shape[slicing_dim], eidx[slicing_dim]);
214
- }
215
169
}
216
170
171
+ auto slice_input = x.slice ({slice_end}, {numel}, {1 });
172
+ auto ot_instance = ctx->getState <CheetahOTState>()->get (nworker - 1 );
173
+ outs[nworker - 1 ] = func (slice_input, ot_instance);
174
+
217
175
for (auto && f : futures) {
218
176
f.get ();
219
177
}
220
178
221
- NdArrayRef out (x .eltype (), x.shape ());
179
+ NdArrayRef out (outs[ 0 ] .eltype (), x.shape ());
222
180
int64_t offset = 0 ;
223
181
224
182
for (auto & out_slice : outs) {
@@ -232,89 +190,50 @@ NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
232
190
233
191
NdArrayRef TiledDispatchOTFunc (KernelEvalContext* ctx, const NdArrayRef& x,
234
192
const NdArrayRef& y, OTBinaryFunc func) {
235
- Shape shape = x.shape ();
236
- SPU_ENFORCE_EQ (x.shape (), y.shape ());
193
+ const Shape& shape = x.shape ();
194
+ SPU_ENFORCE (shape.numel () > 0 );
195
+ SPU_ENFORCE_EQ (shape, y.shape ());
237
196
// (lazy) init OT
238
197
int64_t numel = x.numel ();
239
198
int64_t nworker = InitOTState (ctx, numel);
240
199
int64_t workload = nworker == 0 ? 0 : CeilDiv (numel, nworker);
241
200
242
- int64_t slicing_dim = -1 ;
243
- int64_t slice_numel = 1 ;
244
- for (int64_t dim = shape.size () - 1 ; dim >= 0 ; dim--) {
245
- slice_numel *= shape[dim];
246
- if (slice_numel > workload) {
247
- slice_numel /= shape[dim];
248
- slicing_dim = dim;
249
- break ;
250
- }
201
+ if (shape.ndim () != 1 ) {
202
+ // TiledDispatchOTFunc over flatten input
203
+ return TiledDispatchOTFunc (ctx, x.reshape ({numel}), y.reshape ({numel}),
204
+ func)
205
+ .reshape (x.shape ());
251
206
}
252
207
253
- // get the slice num in the left outer dimensions
254
- int64_t num_slice = 1 ;
255
- for (int64_t dim = 0 ; dim < slicing_dim; dim++) {
256
- num_slice *= shape[dim];
257
- }
258
-
259
- int64_t slice_stride = (workload + slice_numel - 1 ) / slice_numel;
260
- if (slice_stride == 1 ) {
261
- return func (x, y, ctx->getState <CheetahOTState>()->get (0 ));
262
- }
263
-
264
- int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
265
- ((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0 );
266
-
267
- // initialize slice indices
268
- Index start_indices (shape.size ());
269
- Index end_indices (shape.begin (), shape.end ());
270
- end_indices[slicing_dim] = slice_stride;
271
- for (int64_t dim = slicing_dim - 1 ; dim >= 0 ; dim--) {
272
- end_indices[dim] = 1 ;
273
- }
274
-
275
- SPU_ENFORCE_LE (num_slice * num_slice_dim, nworker);
276
- nworker = num_slice * num_slice_dim;
277
-
278
208
std::vector<NdArrayRef> outs (nworker);
279
209
std::vector<std::future<void >> futures;
280
210
281
- Index sidx = start_indices ;
282
- Index eidx = end_indices;
283
- for ( int64_t wi = 0 ; wi < nworker; ++wi) {
284
- auto x_slice = x. slice (sidx, eidx, {} );
285
- auto y_slice = y .slice (sidx, eidx , {});
286
-
211
+ int64_t slice_end = 0 ;
212
+ for ( int64_t wi = 0 ; wi + 1 < nworker; ++wi) {
213
+ int64_t slice_bgn = wi * workload;
214
+ slice_end = std::min (numel, slice_bgn + workload );
215
+ auto x_slice = x .slice ({slice_bgn}, {slice_end} , {1 });
216
+ auto y_slice = y. slice ({slice_bgn}, {slice_end}, { 1 });
287
217
futures.emplace_back (std::async (
288
- [&](int64_t idx, const NdArrayRef& input0 , const NdArrayRef& input1 ) {
218
+ [&](int64_t idx, const NdArrayRef& inp0 , const NdArrayRef& inp1 ) {
289
219
auto ot_instance = ctx->getState <CheetahOTState>()->get (idx);
290
- outs[idx] = func (input0, input1 , ot_instance);
220
+ outs[idx] = func (inp0, inp1 , ot_instance);
291
221
},
292
222
wi, x_slice, y_slice));
293
-
294
- // update indices
295
- if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
296
- // carray out
297
- sidx[slicing_dim] = 0 ;
298
- eidx[slicing_dim] = slice_stride;
299
- for (int64_t dim = slicing_dim - 1 ; dim >= 0 ; dim--) {
300
- sidx[dim] = (sidx[dim] + 1 ) % shape[dim];
301
- eidx[dim] = eidx[dim] % shape[dim] + 1 ;
302
- if (eidx[dim] != 1 ) {
303
- break ;
304
- }
305
- }
306
- } else {
307
- sidx[slicing_dim] += slice_stride;
308
- eidx[slicing_dim] += slice_stride;
309
- eidx[slicing_dim] = std::min (shape[slicing_dim], eidx[slicing_dim]);
310
- }
311
223
}
224
+
225
+ auto x_slice = x.slice ({slice_end}, {numel}, {});
226
+ auto y_slice = y.slice ({slice_end}, {numel}, {});
227
+ auto ot_instance = ctx->getState <CheetahOTState>()->get (nworker - 1 );
228
+ outs[nworker - 1 ] = func (x_slice, y_slice, ot_instance);
229
+
312
230
for (auto && f : futures) {
313
231
f.get ();
314
232
}
315
233
316
- NdArrayRef out (x .eltype (), x.shape ());
234
+ NdArrayRef out (outs[ 0 ] .eltype (), x.shape ());
317
235
int64_t offset = 0 ;
236
+
318
237
for (auto & out_slice : outs) {
319
238
std::memcpy (out.data <std::byte>() + offset, out_slice.data (),
320
239
out_slice.numel () * out.elsize ());
0 commit comments