Skip to content

Commit c36f731

Browse files
committed
Do not return Constants in shape Op
1 parent db673f0 commit c36f731

File tree

7 files changed

+18
-23
lines changed

7 files changed

+18
-23
lines changed

pytensor/tensor/shape.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
146146
if not isinstance(x, Variable):
147147
x = at.as_tensor_variable(x)
148148

149-
x_type = x.type
150-
151-
if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
152-
res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
153-
else:
154-
res = _shape(x)
155-
156-
return res
149+
return _shape(x)
157150

158151

159152
@_get_vector_length.register(Shape)

pytensor/tensor/var.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
1313
from pytensor.graph.utils import MetaType
1414
from pytensor.scalar import ComplexError, IntegerDivisionError
15-
from pytensor.tensor import _get_vector_length, as_tensor_variable
15+
from pytensor.tensor import _get_vector_length
1616
from pytensor.tensor.exceptions import AdvancedIndexingError
1717
from pytensor.tensor.type import TensorType
1818
from pytensor.tensor.type_other import NoneConst
@@ -259,9 +259,6 @@ def transpose(self, *axes):
259259

260260
@property
261261
def shape(self):
262-
if not any(s is None for s in self.type.shape):
263-
return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
264-
265262
return at.shape(self)
266263

267264
@property

tests/scan/test_rewriting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,13 +1477,12 @@ def test_while_scan_taps_and_map(self):
14771477
f(x0=0, seq=test_seq, n_steps=0)
14781478

14791479
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
1480-
# If a MissingInputError is raised, it means the rewrite failed
14811480
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
14821481
_, _, ys_trace, len_zs = scan_node.inputs
14831482
debug_fn = pytensor.function(
1484-
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
1483+
[x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
14851484
)
1486-
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
1485+
stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
14871486
assert stored_ys_steps == 2
14881487
assert stored_zs_steps == 1
14891488

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3486,7 +3486,7 @@ def test_vector(self):
34863486
def test_scalar(self):
34873487
x = scalar()
34883488
y = np.array(7, dtype=config.floatX)
3489-
assert y.size == function([], x.size)()
3489+
assert y.size == function([x], x.size)(y)
34903490

34913491
def test_shared(self):
34923492
# NB: we also test higher order tensors at the same time.

tests/tensor/test_nlinalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pytensor
88
from pytensor import function
99
from pytensor.configdefaults import config
10-
from pytensor.graph.basic import Constant
1110
from pytensor.tensor.math import _allclose
1211
from pytensor.tensor.nlinalg import (
1312
SVD,
@@ -274,9 +273,7 @@ def test_det_grad():
274273

275274
def test_det_shape():
276275
x = matrix()
277-
det_shape = det(x).shape
278-
assert isinstance(det_shape, Constant)
279-
assert tuple(det_shape.data) == ()
276+
assert det(x).type.shape == ()
280277

281278

282279
def test_slogdet():

tests/tensor/test_shape.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.tensor.rewriting.shape import ShapeFeature
1717
from pytensor.tensor.shape import (
1818
Reshape,
19+
Shape,
1920
Shape_i,
2021
SpecifyShape,
2122
Unbroadcast,
@@ -397,7 +398,7 @@ def test_fixed_shapes(self):
397398
shape = as_tensor_variable([2])
398399
y = specify_shape(x, shape)
399400
assert y.type.shape == (2,)
400-
assert y.shape.equals(shape)
401+
assert isinstance(y.shape.owner.op, Shape)
401402

402403
def test_fixed_partial_shapes(self):
403404
x = TensorType("floatX", (None, None))("x")

tests/tensor/test_var.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
import pytensor
88
import tests.unittest_tools as utt
9+
from pytensor.compile import DeepCopyOp
910
from pytensor.compile.mode import get_default_mode
1011
from pytensor.graph.basic import Constant, equal_computations
1112
from pytensor.tensor import get_vector_length
1213
from pytensor.tensor.basic import constant
1314
from pytensor.tensor.elemwise import DimShuffle
1415
from pytensor.tensor.math import dot, eq
16+
from pytensor.tensor.shape import Shape
1517
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
1618
from pytensor.tensor.type import (
1719
TensorType,
@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order):
245247

246248
def test_fixed_shape_variable_basic():
247249
x = TensorVariable(TensorType("int64", shape=(4,)), None)
248-
assert isinstance(x.shape, Constant)
249-
assert np.array_equal(x.shape.data, (4,))
250+
assert x.type.shape == (4,)
251+
assert isinstance(x.shape.owner.op, Shape)
252+
253+
shape_fn = pytensor.function([x], x.shape)
254+
opt_shape = shape_fn.maker.fgraph.outputs[0]
255+
assert isinstance(opt_shape.owner.op, DeepCopyOp)
256+
assert isinstance(opt_shape.owner.inputs[0], Constant)
257+
assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))
250258

251259
x = TensorConstant(
252260
TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])

0 commit comments

Comments
 (0)