Skip to content

Implement C code for ExtractDiagonal and ARange #1392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 89 additions & 35 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3207,13 +3207,14 @@
return A_replicated.reshape(tiled_shape)


class ARange(Op):
class ARange(COp):
"""Create an array containing evenly spaced values within a given interval.

Parameters and behaviour are the same as numpy.arange().

"""

# TODO: ARange should work with scalars as inputs, not arrays
Copy link
Preview

Copilot AI May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider addressing the TODO by supporting scalar inputs for ARange, which would simplify the interface and align behavior with numpy.arange.

Copilot uses AI. Check for mistakes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would just require some type checking in the perform method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_node as well

__props__ = ("dtype",)

def __init__(self, dtype):
Expand Down Expand Up @@ -3293,13 +3294,30 @@
)
]

def perform(self, node, inp, out_):
start, stop, step = inp
(out,) = out_
start = start.item()
stop = stop.item()
step = step.item()
out[0] = np.arange(start, stop, step, dtype=self.dtype)
def perform(self, node, inputs, output_storage):
    """Evaluate the range in Python and store it in the output cell.

    ``inputs`` is unpacked into exactly three values (start, stop, step).
    Each is converted with ``.item()`` before being handed to
    ``np.arange`` together with ``self.dtype``. The resulting array is
    written to ``output_storage[0][0]``.
    """
    first, last, stride = inputs
    values = np.arange(
        first.item(),
        last.item(),
        stride.item(),
        dtype=self.dtype,
    )
    output_storage[0][0] = values

def c_code(self, node, nodename, input_names, output_names, sub):
    """Build the C source that creates the output with ``PyArray_Arange``.

    The generated code reads the first element of each of the three
    inputs' data buffers into a C ``double``, releases whatever the output
    variable currently references, and assigns it the result of
    ``PyArray_Arange`` called with the NumPy type number of ``self.dtype``.
    If that result is NULL, the ``sub["fail"]`` snippet is executed.
    """
    start_var, stop_var, step_var = input_names
    (result_var,) = output_names
    dtype_num = np.dtype(self.dtype).num
    failure = sub["fail"]
    # The leading and trailing empty entries reproduce the newline that
    # opens and closes the template.
    c_lines = [
        "",
        f"double start = ((dtype_{start_var}*)PyArray_DATA({start_var}))[0];",
        f"double stop = ((dtype_{stop_var}*)PyArray_DATA({stop_var}))[0];",
        f"double step = ((dtype_{step_var}*)PyArray_DATA({step_var}))[0];",
        '//printf("start: %f, stop: %f, step: %f\\n", start, stop, step);',
        f"Py_XDECREF({result_var});",
        f"{result_var} = (PyArrayObject*) PyArray_Arange(start, stop, step, {dtype_num});",
        f"if (!{result_var}) {{",
        f"{failure}",
        "}",
        "",
    ]
    return "\n".join(c_lines)

def c_code_cache_version(self):
    """Return the version tuple attached to this Op's generated C code."""
    version = (0,)
    return version

def connection_pattern(self, node):
    """Return one single-entry row per input for the single output.

    The first and third rows are ``[True]`` and the second is ``[False]``.
    NOTE(review): presumably the rows follow the (start, stop, step) input
    order used by ``perform`` -- confirm against ``make_node``.
    """
    first_row, second_row, third_row = [True], [False], [True]
    return [first_row, second_row, third_row]
Expand Down Expand Up @@ -3685,8 +3703,7 @@
)


# TODO: optimization to insert ExtractDiag with view=True
class ExtractDiag(Op):
class ExtractDiag(COp):
"""
Return specified diagonals.

Expand Down Expand Up @@ -3742,7 +3759,7 @@

__props__ = ("offset", "axis1", "axis2", "view")

def __init__(self, offset=0, axis1=0, axis2=1, view=False):
def __init__(self, offset=0, axis1=0, axis2=1, view=True):
self.view = view
if self.view:
self.view_map = {0: [0]}
Expand All @@ -3765,24 +3782,74 @@
if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)

out_shape = [
st_dim
for i, st_dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
] + [None]
if (dim1 := x.type.shape[self.axis1]) is not None and (
dim2 := x.type.shape[self.axis2]
) is not None:
offset = self.offset
if offset > 0:
diag_size = int(np.clip(dim2 - offset, 0, dim1))
elif offset < 0:
diag_size = int(np.clip(dim1 + offset, 0, dim2))
else:
diag_size = int(np.minimum(dim1, dim2))
else:
diag_size = None

out_shape = (
*(
dim
for i, dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
),
diag_size,
)

return Apply(
self,
[x],
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
[x.type.clone(dtype=x.dtype, shape=out_shape)()],
)

def perform(self, node, inputs, outputs):
def perform(self, node, inputs, output_storage):
(x,) = inputs
(z,) = outputs
z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
if not self.view:
z[0] = z[0].copy()
out = x.diagonal(self.offset, self.axis1, self.axis2)
if self.view:
try:
out.flags.writeable = True
Copy link
Member

@jessegrabowski jessegrabowski May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the tests don't cover it, I'm curious in what situation(s) we won't be able to do this in place (I assume setting this writeable flag is related to inplacing, but maybe I'm wrong).

Copy link
Member Author

@ricardoV94 ricardoV94 May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the base array is not writeable we can't set the diagonal view as writeable

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import numpy as np
x = np.zeros(10)
x.flags.writeable = False
y = x[2:]
y.flags.writeable = True  # ValueError: cannot set WRITEABLE flag to True of this array

except ValueError:

Check warning on line 3819 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3819

Added line #L3819 was not covered by tests
# We can't make this array writable
out = out.copy()

Check warning on line 3821 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3821

Added line #L3821 was not covered by tests
else:
out = out.copy()

Check warning on line 3823 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3823

Added line #L3823 was not covered by tests
output_storage[0][0] = out

def c_code(self, node, nodename, input_names, output_names, sub):
    """Build the C source that extracts the diagonal with ``PyArray_Diagonal``.

    The generated code releases whatever the output variable currently
    references and assigns it the result of ``PyArray_Diagonal`` called
    with this Op's ``offset``, ``axis1`` and ``axis2``. When ``self.view``
    is true and the input is writeable, the result is flagged writeable
    and kept as is; otherwise it is replaced by a copy made with
    ``PyArray_Copy``. A NULL result from either call runs ``sub["fail"]``.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {fail}  // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        Py_DECREF({out_name});
        // Rebind first so the failure path never sees the released pointer
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {fail}  // Error already set by Numpy
        }}
    }}
    """

def c_code_cache_version(self):
    """Return the version tuple attached to this Op's generated C code."""
    # Bumped from (0,) because the generated C changed (the output is now
    # rebound before the NULL check in the copy branch).
    return (1,)

def grad(self, inputs, gout):
# Avoid circular import
Expand Down Expand Up @@ -3829,19 +3896,6 @@
out_shape.append(diag_size)
return [tuple(out_shape)]

def __setstate__(self, state):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the purpose of this? I've never seen __setstate__ before

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually pickling compatibility: it is needed when something about the Op changed, or when the Op had state that it shouldn't have.

self.__dict__.update(state)

if self.view:
self.view_map = {0: [0]}

if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1


def extract_diag(x):
warnings.warn(
Expand Down