@@ -71,12 +71,17 @@ void simplify_iteration_space_1(int &nd,
71
71
nd = contracted_nd;
72
72
}
73
73
else if (nd == 1 ) {
74
+ offset = 0 ;
74
75
// Populate vectors
75
76
simplified_shape.reserve (nd);
76
77
simplified_shape.push_back (shape[0 ]);
77
78
78
79
simplified_strides.reserve (nd);
79
- simplified_strides.push_back (strides[0 ]);
80
+ simplified_strides.push_back ((strides[0 ] >= 0 ) ? strides[0 ]
81
+ : -strides[0 ]);
82
+ if ((strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
83
+ offset += (shape[0 ] - 1 ) * strides[0 ];
84
+ }
80
85
81
86
assert (simplified_shape.size () == static_cast <size_t >(nd));
82
87
assert (simplified_strides.size () == static_cast <size_t >(nd));
@@ -128,17 +133,27 @@ void simplify_iteration_space(int &nd,
128
133
nd = contracted_nd;
129
134
}
130
135
else if (nd == 1 ) {
136
+ src_offset = 0 ;
137
+ dst_offset = 0 ;
131
138
// Populate vectors
132
139
simplified_shape.reserve (nd);
133
140
simplified_shape.push_back (shape[0 ]);
134
141
assert (simplified_shape.size () == static_cast <size_t >(nd));
135
142
136
143
simplified_src_strides.reserve (nd);
137
- simplified_src_strides.push_back (src_strides[0 ]);
144
+ simplified_src_strides.push_back (
145
+ (src_strides[0 ] >= 0 ) ? src_strides[0 ] : -src_strides[0 ]);
146
+ if ((src_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
147
+ src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
148
+ }
138
149
assert (simplified_src_strides.size () == static_cast <size_t >(nd));
139
150
140
151
simplified_dst_strides.reserve (nd);
141
- simplified_dst_strides.push_back (dst_strides[0 ]);
152
+ simplified_dst_strides.push_back (
153
+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
154
+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
155
+ dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
156
+ }
142
157
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
143
158
}
144
159
}
@@ -202,21 +217,36 @@ void simplify_iteration_space_3(
202
217
nd = contracted_nd;
203
218
}
204
219
else if (nd == 1 ) {
220
+ src1_offset = 0 ;
221
+ src2_offset = 0 ;
222
+ dst_offset = 0 ;
205
223
// Populate vectors
206
224
simplified_shape.reserve (nd);
207
225
simplified_shape.push_back (shape[0 ]);
208
226
assert (simplified_shape.size () == static_cast <size_t >(nd));
209
227
210
228
simplified_src1_strides.reserve (nd);
211
- simplified_src1_strides.push_back (src1_strides[0 ]);
229
+ simplified_src1_strides.push_back (
230
+ (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
231
+ if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
232
+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
233
+ }
212
234
assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
213
235
214
236
simplified_src2_strides.reserve (nd);
215
- simplified_src2_strides.push_back (src2_strides[0 ]);
237
+ simplified_src2_strides.push_back (
238
+ (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
239
+ if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
240
+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
241
+ }
216
242
assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
217
243
218
244
simplified_dst_strides.reserve (nd);
219
- simplified_dst_strides.push_back (dst_strides[0 ]);
245
+ simplified_dst_strides.push_back (
246
+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
247
+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
248
+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
249
+ }
220
250
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
221
251
}
222
252
}
@@ -293,29 +323,129 @@ void simplify_iteration_space_4(
293
323
nd = contracted_nd;
294
324
}
295
325
else if (nd == 1 ) {
326
+ src1_offset = 0 ;
327
+ src2_offset = 0 ;
328
+ src3_offset = 0 ;
329
+ dst_offset = 0 ;
296
330
// Populate vectors
297
331
simplified_shape.reserve (nd);
298
332
simplified_shape.push_back (shape[0 ]);
299
333
assert (simplified_shape.size () == static_cast <size_t >(nd));
300
334
301
335
simplified_src1_strides.reserve (nd);
302
- simplified_src1_strides.push_back (src1_strides[0 ]);
336
+ simplified_src1_strides.push_back (
337
+ (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
338
+ if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
339
+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
340
+ }
303
341
assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
304
342
305
343
simplified_src2_strides.reserve (nd);
306
- simplified_src2_strides.push_back (src2_strides[0 ]);
344
+ simplified_src2_strides.push_back (
345
+ (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
346
+ if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
347
+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
348
+ }
307
349
assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
308
350
309
351
simplified_src3_strides.reserve (nd);
310
- simplified_src3_strides.push_back (src3_strides[0 ]);
352
+ simplified_src3_strides.push_back (
353
+ (src3_strides[0 ] >= 0 ) ? src3_strides[0 ] : -src3_strides[0 ]);
354
+ if ((src3_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
355
+ src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
356
+ }
311
357
assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
312
358
313
359
simplified_dst_strides.reserve (nd);
314
- simplified_dst_strides.push_back (dst_strides[0 ]);
360
+ simplified_dst_strides.push_back (
361
+ (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
362
+ if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
363
+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
364
+ }
315
365
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
316
366
}
317
367
}
318
368
369
+ py::ssize_t _ravel_multi_index_c (std::vector<py::ssize_t > const &mi,
370
+ std::vector<py::ssize_t > const &shape)
371
+ {
372
+ size_t nd = shape.size ();
373
+ if (nd != mi.size ()) {
374
+ throw py::value_error (
375
+ " Multi-index and shape vectors must have the same length." );
376
+ }
377
+
378
+ py::ssize_t flat_index = 0 ;
379
+ py::ssize_t s = 1 ;
380
+ for (size_t i = 0 ; i < nd; ++i) {
381
+ flat_index += mi.at (nd - 1 - i) * s;
382
+ s *= shape.at (nd - 1 - i);
383
+ }
384
+
385
+ return flat_index;
386
+ }
387
+
388
+ py::ssize_t _ravel_multi_index_f (std::vector<py::ssize_t > const &mi,
389
+ std::vector<py::ssize_t > const &shape)
390
+ {
391
+ size_t nd = shape.size ();
392
+ if (nd != mi.size ()) {
393
+ throw py::value_error (
394
+ " Multi-index and shape vectors must have the same length." );
395
+ }
396
+
397
+ py::ssize_t flat_index = 0 ;
398
+ py::ssize_t s = 1 ;
399
+ for (size_t i = 0 ; i < nd; ++i) {
400
+ flat_index += mi.at (i) * s;
401
+ s *= shape.at (i);
402
+ }
403
+
404
+ return flat_index;
405
+ }
406
+
407
+ std::vector<py::ssize_t > _unravel_index_c (py::ssize_t flat_index,
408
+ std::vector<py::ssize_t > const &shape)
409
+ {
410
+ size_t nd = shape.size ();
411
+ std::vector<py::ssize_t > mi;
412
+ mi.resize (nd);
413
+
414
+ py::ssize_t i_ = flat_index;
415
+ for (size_t dim = 0 ; dim + 1 < nd; ++dim) {
416
+ const py::ssize_t si = shape[nd - 1 - dim];
417
+ const py::ssize_t q = i_ / si;
418
+ const py::ssize_t r = (i_ - q * si);
419
+ mi[nd - 1 - dim] = r;
420
+ i_ = q;
421
+ }
422
+ if (nd) {
423
+ mi[0 ] = i_;
424
+ }
425
+ return mi;
426
+ }
427
+
428
+ std::vector<py::ssize_t > _unravel_index_f (py::ssize_t flat_index,
429
+ std::vector<py::ssize_t > const &shape)
430
+ {
431
+ size_t nd = shape.size ();
432
+ std::vector<py::ssize_t > mi;
433
+ mi.resize (nd);
434
+
435
+ py::ssize_t i_ = flat_index;
436
+ for (size_t dim = 0 ; dim + 1 < nd; ++dim) {
437
+ const py::ssize_t si = shape[dim];
438
+ const py::ssize_t q = i_ / si;
439
+ const py::ssize_t r = (i_ - q * si);
440
+ mi[dim] = r;
441
+ i_ = q;
442
+ }
443
+ if (nd) {
444
+ mi[nd - 1 ] = i_;
445
+ }
446
+ return mi;
447
+ }
448
+
319
449
} // namespace py_internal
320
450
} // namespace tensor
321
451
} // namespace dpctl
0 commit comments