Skip to content

Commit 7ceb542

Browse files
committed
BUG: Break on errors when performing strided casts.
Closed numpygh-15790.
1 parent 367c5a2 commit 7ceb542

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

numpy/core/src/multiarray/dtype_transfer.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ typedef struct {
312312
NpyAuxData *wrappeddata, *todata, *fromdata;
313313
npy_intp src_itemsize, dst_itemsize;
314314
char *bufferin, *bufferout;
315-
npy_bool init_dest;
315+
npy_bool init_dest, out_needs_api;
316316
} _align_wrap_data;
317317

318318
/* transfer data free function */
@@ -374,6 +374,7 @@ static NpyAuxData *_align_wrap_data_clone(NpyAuxData *data)
374374
}
375375

376376
newdata->init_dest = d->init_dest;
377+
newdata->out_needs_api = d->out_needs_api;
377378

378379
return (NpyAuxData *)newdata;
379380
}
@@ -394,7 +395,7 @@ _strided_to_strided_contig_align_wrap(char *dst, npy_intp dst_stride,
394395
*todata = d->todata,
395396
*fromdata = d->fromdata;
396397
char *bufferin = d->bufferin, *bufferout = d->bufferout;
397-
npy_bool init_dest = d->init_dest;
398+
npy_bool init_dest = d->init_dest, out_needs_api = d->out_needs_api;
398399

399400
for(;;) {
400401
if (N > NPY_LOWLEVEL_BUFFER_BLOCKSIZE) {
@@ -414,6 +415,9 @@ _strided_to_strided_contig_align_wrap(char *dst, npy_intp dst_stride,
414415
N -= NPY_LOWLEVEL_BUFFER_BLOCKSIZE;
415416
src += NPY_LOWLEVEL_BUFFER_BLOCKSIZE*src_stride;
416417
dst += NPY_LOWLEVEL_BUFFER_BLOCKSIZE*dst_stride;
418+
if (out_needs_api && PyErr_Occurred()) {
419+
return;
420+
}
417421
}
418422
else {
419423
tobuffer(bufferin, inner_src_itemsize, src, src_stride, N,
@@ -442,6 +446,7 @@ _strided_to_strided_contig_align_wrap(char *dst, npy_intp dst_stride,
442446
* wrapped - contig to contig transfer function being wrapped
443447
* wrappeddata - data for wrapped
444448
* init_dest - 1 means to memset the dest buffer to 0 before calling wrapped.
449+
* out_needs_api - if NPY_TRUE, check for (and break on) Python API errors.
445450
*
446451
* Returns NPY_SUCCEED or NPY_FAIL.
447452
*/
@@ -452,6 +457,7 @@ wrap_aligned_contig_transfer_function(
452457
PyArray_StridedUnaryOp *frombuffer, NpyAuxData *fromdata,
453458
PyArray_StridedUnaryOp *wrapped, NpyAuxData *wrappeddata,
454459
int init_dest,
460+
int out_needs_api,
455461
PyArray_StridedUnaryOp **out_stransfer,
456462
NpyAuxData **out_transferdata)
457463
{
@@ -485,6 +491,7 @@ wrap_aligned_contig_transfer_function(
485491
data->bufferout = data->bufferin +
486492
NPY_LOWLEVEL_BUFFER_BLOCKSIZE*src_itemsize;
487493
data->init_dest = (npy_bool) init_dest;
494+
data->out_needs_api = (npy_bool) out_needs_api;
488495

489496
/* Set the function and data */
490497
*out_stransfer = &_strided_to_strided_contig_align_wrap;
@@ -1132,6 +1139,7 @@ get_datetime_to_unicode_transfer_function(int aligned,
11321139
frombuffer, fromdata,
11331140
caststransfer, castdata,
11341141
PyDataType_FLAGCHK(str_dtype, NPY_NEEDS_INIT),
1142+
*out_needs_api,
11351143
out_stransfer, out_transferdata) != NPY_SUCCEED) {
11361144
NPY_AUXDATA_FREE(castdata);
11371145
NPY_AUXDATA_FREE(todata);
@@ -1254,6 +1262,7 @@ get_unicode_to_datetime_transfer_function(int aligned,
12541262
frombuffer, fromdata,
12551263
caststransfer, castdata,
12561264
PyDataType_FLAGCHK(dst_dtype, NPY_NEEDS_INIT),
1265+
*out_needs_api,
12571266
out_stransfer, out_transferdata) != NPY_SUCCEED) {
12581267
Py_DECREF(str_dtype);
12591268
NPY_AUXDATA_FREE(castdata);
@@ -1574,6 +1583,7 @@ get_cast_transfer_function(int aligned,
15741583
frombuffer, fromdata,
15751584
caststransfer, castdata,
15761585
PyDataType_FLAGCHK(dst_dtype, NPY_NEEDS_INIT),
1586+
*out_needs_api,
15771587
out_stransfer, out_transferdata) != NPY_SUCCEED) {
15781588
NPY_AUXDATA_FREE(castdata);
15791589
NPY_AUXDATA_FREE(todata);

0 commit comments

Comments
 (0)