@@ -2171,8 +2171,6 @@ class Split(COp):
2171
2171
array([3, 4])
2172
2172
>>> c
2173
2173
array([5])
2174
-
2175
- TODO: Don't make a copy in C impl
2176
2174
"""
2177
2175
2178
2176
len_splits = None
@@ -2283,29 +2281,7 @@ def R_op(self, inputs, eval_points):
2283
2281
return self .make_node (eval_points [0 ], * inputs [1 :]).outputs
2284
2282
2285
2283
def c_code_cache_version (self ):
2286
- return (2 ,)
2287
-
2288
- def c_support_code (self , ** kwargs ):
2289
- return """
2290
- /* Return 1 if output has the correct shape. */
2291
- int split_output_shape_is_correct (
2292
- PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2293
- ) {
2294
- return
2295
- PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2296
- && memcmp(
2297
- PyArray_DIMS(output),
2298
- PyArray_DIMS(array_to_split),
2299
- axis_to_split * sizeof(npy_intp)
2300
- ) == 0
2301
- && memcmp(
2302
- PyArray_DIMS(output) + axis_to_split + 1,
2303
- PyArray_DIMS(array_to_split) + axis_to_split + 1,
2304
- (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2305
- ) == 0
2306
- && split_size == PyArray_DIM(output, axis_to_split);
2307
- }
2308
- """
2284
+ return (3 ,)
2309
2285
2310
2286
def c_code (self , node , name , inputs , outputs , sub ):
2311
2287
if self .len_splits == 0 :
@@ -2316,109 +2292,96 @@ def c_code(self, node, name, inputs, outputs, sub):
2316
2292
outputs_pointers = "&" + (", &" .join (outputs ))
2317
2293
x , axis , splits = inputs
2318
2294
fail = sub ["fail" ]
2319
- x_typenum = np .dtype (node .inputs [0 ].dtype ).num
2320
- x_itemsize = np .dtype (node .inputs [0 ].dtype ).itemsize
2321
- axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2322
2295
splits_dtype = node .inputs [2 ].type .dtype_specs ()[1 ]
2323
2296
expected_splits_count = self .len_splits
2297
+ ndim = node .inputs [0 ].type .ndim
2298
+
2299
+ # Most times axis is constant, inline it
2300
+ # This is safe to do because the hash of the c_code includes the constant signature
2301
+ if isinstance (node .inputs [1 ], Constant ):
2302
+ static_axis = int (node .inputs [1 ].data )
2303
+ static_axis = normalize_axis_index (static_axis , ndim )
2304
+ axis_def = f"{ static_axis } ;"
2305
+ axis_check = ""
2306
+ else :
2307
+ axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2308
+ axis_def = f"(({ axis_dtype } *)PyArray_DATA({ axis } ))[0];"
2309
+ axis_check = f"""
2310
+ if (axis < 0){{
2311
+ axis = ndim + axis;
2312
+ }}
2313
+ if (axis >= ndim || axis < 0) {{
2314
+ PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2315
+ { fail }
2316
+ }}
2317
+ """
2324
2318
2325
2319
return f"""
2326
- int ndim = PyArray_NDIM( { x } ) ;
2327
- int axis = (int)(*( { axis_dtype } *)PyArray_GETPTR1( { axis } , 0));
2320
+ int ndim = { ndim } ;
2321
+ int axis = { axis_def }
2328
2322
int splits_count = PyArray_DIM({ splits } , 0);
2329
- npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2330
- npy_intp* split_dims = NULL;
2331
- PyObject* split_view = NULL;
2332
- npy_intp data_offset;
2333
- int i;
2323
+ npy_intp sum_of_splits = 0, current_split_start = 0;
2334
2324
PyArrayObject** outputs[] = {{{ outputs_pointers } }};
2325
+ npy_intp split_dims[ndim];
2326
+ PyObject* split_view = NULL;
2335
2327
2336
2328
/* Check inputs. */
2337
-
2338
- if (splits_count != { expected_splits_count } ) {{
2339
- PyErr_Format(PyExc_ValueError,
2340
- "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
2329
+ if (PyArray_NDIM({ x } ) != ndim) {{
2330
+ PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
2341
2331
{ fail }
2342
2332
}}
2343
-
2344
- if (axis < 0) {{
2345
- axis += ndim;
2346
- }}
2347
- if (axis < 0 || axis >= ndim) {{
2348
- PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2333
+ if (splits_count != { expected_splits_count } ) {{
2334
+ PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
2349
2335
{ fail }
2350
2336
}}
2351
- len_along_axis = PyArray_DIM({ x } , axis);
2352
2337
2353
- for (i = 0; i < splits_count; ++i) {{
2354
- current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2338
+ { axis_check } ;
2339
+
2340
+ for (int i = 0; i < splits_count; ++i) {{
2341
+ int current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2355
2342
if (current_split_length < 0) {{
2356
2343
PyErr_Format(PyExc_ValueError,
2357
2344
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
2358
2345
{ fail }
2359
2346
}}
2360
2347
sum_of_splits += current_split_length;
2361
2348
}}
2362
- if (sum_of_splits != len_along_axis) {{
2363
- PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2364
- { fail }
2365
- }}
2366
-
2367
- /* Check outputs. */
2368
-
2369
- split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2370
- if (split_dims == NULL) {{
2371
- PyErr_NoMemory();
2349
+ if (sum_of_splits != PyArray_DIM({ x } , axis)) {{
2350
+ PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({ x } , axis));
2372
2351
{ fail }
2373
2352
}}
2374
2353
2375
- memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
2376
-
2377
- for (i = 0; i < splits_count; ++i) {{
2378
- PyArrayObject** output = outputs[i];
2379
- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2380
- if (*output == NULL || !split_output_shape_is_correct(*output, { x } , axis, current_split_length)) {{
2381
- Py_XDECREF(*output);
2382
- split_dims[axis] = current_split_length;
2383
- *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, { x_typenum } , PyArray_IS_F_CONTIGUOUS({ x } ));
2384
- if (outputs == NULL) {{
2385
- PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2386
- free(split_dims);
2387
- { fail }
2388
- }}
2389
- }}
2390
- }}
2391
-
2392
2354
/* Compute split. */
2355
+ memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
2393
2356
2394
- for (i = 0; i < splits_count; ++i) {{
2395
- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2396
- data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2397
- split_dims[axis] = current_split_length;
2398
- split_view = PyArray_New(&PyArray_Type,
2399
- ndim, split_dims,
2400
- { x_typenum } ,
2401
- PyArray_STRIDES({ x } ),
2402
- PyArray_BYTES({ x } ) + data_offset,
2403
- { x_itemsize } ,
2404
- PyArray_FLAGS({ x } ),
2405
- NULL);
2406
- if (split_view == NULL) {{
2357
+ for (int i = 0; i < splits_count; ++i) {{
2358
+ Py_XDECREF(*outputs[i]);
2359
+
2360
+ // Create view of input
2361
+ PyArray_Descr *descr = PyArray_DESCR({ x } );
2362
+ Py_INCREF(descr);
2363
+ npy_intp data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2364
+ *outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2365
+ descr, // PyArray_NewFromDescr steals this reference
2366
+ ndim, split_dims,
2367
+ PyArray_STRIDES({ x } ),
2368
+ PyArray_BYTES({ x } ) + data_offset,
2369
+ PyArray_FLAGS({ x } ) & ~NPY_ARRAY_OWNDATA,
2370
+ NULL);
2371
+
2372
+ if (*outputs[i] == NULL) {{
2407
2373
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
2408
2374
free(split_dims);
2409
2375
{ fail }
2410
2376
}}
2411
- if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2412
- PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2413
- Py_XDECREF(split_view);
2414
- free(split_dims);
2415
- { fail }
2416
- }}
2417
- Py_XDECREF(split_view);
2418
- current_split_start += current_split_length;
2419
- }}
2420
2377
2421
- free(split_dims);
2378
+ // Set as a view of input
2379
+ Py_INCREF((PyObject*){ x } );
2380
+ PyArray_SetBaseObject(*outputs[i], (PyObject*){ x } );
2381
+
2382
+ // Update split slice pointer
2383
+ current_split_start += (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2384
+ }}
2422
2385
"""
2423
2386
2424
2387
0 commit comments