Skip to content

Commit aa9b281

Browse files
committed
Make Split C-impl return a view
1 parent 5bed8a5 commit aa9b281

File tree

1 file changed

+61
-98
lines changed

1 file changed

+61
-98
lines changed

pytensor/tensor/basic.py

Lines changed: 61 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,8 +2171,6 @@ class Split(COp):
21712171
array([3, 4])
21722172
>>> c
21732173
array([5])
2174-
2175-
TODO: Don't make a copy in C impl
21762174
"""
21772175

21782176
len_splits = None
@@ -2283,29 +2281,7 @@ def R_op(self, inputs, eval_points):
22832281
return self.make_node(eval_points[0], *inputs[1:]).outputs
22842282

22852283
def c_code_cache_version(self):
2286-
return (2,)
2287-
2288-
def c_support_code(self, **kwargs):
2289-
return """
2290-
/* Return 1 if output has the correct shape. */
2291-
int split_output_shape_is_correct (
2292-
PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2293-
) {
2294-
return
2295-
PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2296-
&& memcmp(
2297-
PyArray_DIMS(output),
2298-
PyArray_DIMS(array_to_split),
2299-
axis_to_split * sizeof(npy_intp)
2300-
) == 0
2301-
&& memcmp(
2302-
PyArray_DIMS(output) + axis_to_split + 1,
2303-
PyArray_DIMS(array_to_split) + axis_to_split + 1,
2304-
(PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2305-
) == 0
2306-
&& split_size == PyArray_DIM(output, axis_to_split);
2307-
}
2308-
"""
2284+
return (3,)
23092285

23102286
def c_code(self, node, name, inputs, outputs, sub):
23112287
if self.len_splits == 0:
@@ -2316,109 +2292,96 @@ def c_code(self, node, name, inputs, outputs, sub):
23162292
outputs_pointers = "&" + (", &".join(outputs))
23172293
x, axis, splits = inputs
23182294
fail = sub["fail"]
2319-
x_typenum = np.dtype(node.inputs[0].dtype).num
2320-
x_itemsize = np.dtype(node.inputs[0].dtype).itemsize
2321-
axis_dtype = node.inputs[1].type.dtype_specs()[1]
23222295
splits_dtype = node.inputs[2].type.dtype_specs()[1]
23232296
expected_splits_count = self.len_splits
2297+
ndim = node.inputs[0].type.ndim
2298+
2299+
# Most times axis is constant, inline it
2300+
# This is safe to do because the hash of the c_code includes the constant signature
2301+
if isinstance(node.inputs[1], Constant):
2302+
static_axis = int(node.inputs[1].data)
2303+
static_axis = normalize_axis_index(static_axis, ndim)
2304+
axis_def = f"{static_axis};"
2305+
axis_check = ""
2306+
else:
2307+
axis_dtype = node.inputs[1].type.dtype_specs()[1]
2308+
axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
2309+
axis_check = f"""
2310+
if (axis < 0){{
2311+
axis = ndim + axis;
2312+
}}
2313+
if (axis >= ndim || axis < 0) {{
2314+
PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2315+
{fail}
2316+
}}
2317+
"""
23242318

23252319
return f"""
2326-
int ndim = PyArray_NDIM({x});
2327-
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
2320+
int ndim = {ndim};
2321+
int axis = {axis_def}
23282322
int splits_count = PyArray_DIM({splits}, 0);
2329-
npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2330-
npy_intp* split_dims = NULL;
2331-
PyObject* split_view = NULL;
2332-
npy_intp data_offset;
2333-
int i;
2323+
npy_intp sum_of_splits = 0, current_split_start = 0;
23342324
PyArrayObject** outputs[] = {{{outputs_pointers}}};
2325+
npy_intp split_dims[ndim];
2326+
PyObject* split_view = NULL;
23352327
23362328
/* Check inputs. */
2337-
2338-
if (splits_count != {expected_splits_count}) {{
2339-
PyErr_Format(PyExc_ValueError,
2340-
"Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
2329+
if (PyArray_NDIM({x}) != ndim) {{
2330+
PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
23412331
{fail}
23422332
}}
2343-
2344-
if (axis < 0) {{
2345-
axis += ndim;
2346-
}}
2347-
if (axis < 0 || axis >= ndim) {{
2348-
PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2333+
if (splits_count != {expected_splits_count}) {{
2334+
PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, {expected_splits_count});
23492335
{fail}
23502336
}}
2351-
len_along_axis = PyArray_DIM({x}, axis);
23522337
2353-
for (i = 0; i < splits_count; ++i) {{
2354-
current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
2338+
{axis_check};
2339+
2340+
for (int i = 0; i < splits_count; ++i) {{
2341+
int current_split_length = (npy_intp)(*({splits_dtype}*)PyArray_GETPTR1({splits}, i));
23552342
if (current_split_length < 0) {{
23562343
PyErr_Format(PyExc_ValueError,
23572344
"Split: you try to take a negative number (%ld) of elements.", current_split_length);
23582345
{fail}
23592346
}}
23602347
sum_of_splits += current_split_length;
23612348
}}
2362-
if (sum_of_splits != len_along_axis) {{
2363-
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2364-
{fail}
2365-
}}
2366-
2367-
/* Check outputs. */
2368-
2369-
split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2370-
if (split_dims == NULL) {{
2371-
PyErr_NoMemory();
2349+
if (sum_of_splits != PyArray_DIM({x}, axis)) {{
2350+
PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({x}, axis));
23722351
{fail}
23732352
}}
23742353
2375-
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));
2376-
2377-
for (i = 0; i < splits_count; ++i) {{
2378-
PyArrayObject** output = outputs[i];
2379-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2380-
if (*output == NULL || !split_output_shape_is_correct(*output, {x}, axis, current_split_length)) {{
2381-
Py_XDECREF(*output);
2382-
split_dims[axis] = current_split_length;
2383-
*output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, {x_typenum}, PyArray_IS_F_CONTIGUOUS({x}));
2384-
if (outputs == NULL) {{
2385-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2386-
free(split_dims);
2387-
{fail}
2388-
}}
2389-
}}
2390-
}}
2391-
23922354
/* Compute split. */
2355+
memcpy(split_dims, PyArray_DIMS({x}), ndim * sizeof(npy_intp));
23932356
2394-
for (i = 0; i < splits_count; ++i) {{
2395-
current_split_length = (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2396-
data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2397-
split_dims[axis] = current_split_length;
2398-
split_view = PyArray_New(&PyArray_Type,
2399-
ndim, split_dims,
2400-
{x_typenum},
2401-
PyArray_STRIDES({x}),
2402-
PyArray_BYTES({x}) + data_offset,
2403-
{x_itemsize},
2404-
PyArray_FLAGS({x}),
2405-
NULL);
2406-
if (split_view == NULL) {{
2357+
for (int i = 0; i < splits_count; ++i) {{
2358+
Py_XDECREF(*outputs[i]);
2359+
2360+
// Create view of input
2361+
PyArray_Descr *descr = PyArray_DESCR({x});
2362+
Py_INCREF(descr);
2363+
npy_intp data_offset = PyArray_STRIDE({x}, axis) * current_split_start;
2364+
*outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2365+
descr, // PyArray_NewFromDescr steals this reference
2366+
ndim, split_dims,
2367+
PyArray_STRIDES({x}),
2368+
PyArray_BYTES({x}) + data_offset,
2369+
PyArray_FLAGS({x}) & ~NPY_ARRAY_OWNDATA,
2370+
NULL);
2371+
2372+
if (*outputs[i] == NULL) {{
24072373
PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
24082374
free(split_dims);
24092375
{fail}
24102376
}}
2411-
if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2412-
PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2413-
Py_XDECREF(split_view);
2414-
free(split_dims);
2415-
{fail}
2416-
}}
2417-
Py_XDECREF(split_view);
2418-
current_split_start += current_split_length;
2419-
}}
24202377
2421-
free(split_dims);
2378+
// Set as a view of input
2379+
Py_INCREF((PyObject*){x});
2380+
PyArray_SetBaseObject(*outputs[i], (PyObject*){x});
2381+
2382+
// Update split slice pointer
2383+
current_split_start += (npy_intp) (* ({splits_dtype}*) PyArray_GETPTR1({splits}, i));
2384+
}}
24222385
"""
24232386

24242387

0 commit comments

Comments
 (0)