Skip to content

Commit c956086

Browse files
committed
Do not return Constants in shape Op
1 parent 5e04917 commit c956086

File tree

6 files changed

+13
-19
lines changed

6 files changed

+13
-19
lines changed

pytensor/tensor/shape.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
146146
if not isinstance(x, Variable):
147147
x = at.as_tensor_variable(x)
148148

149-
x_type = x.type
150-
151-
if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
152-
res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
153-
else:
154-
res = _shape(x)
155-
156-
return res
149+
return _shape(x)
157150

158151

159152
@_get_vector_length.register(Shape)

pytensor/tensor/var.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
1313
from pytensor.graph.utils import MetaType
1414
from pytensor.scalar import ComplexError, IntegerDivisionError
15-
from pytensor.tensor import _get_vector_length, as_tensor_variable
15+
from pytensor.tensor import _get_vector_length
1616
from pytensor.tensor.exceptions import AdvancedIndexingError
1717
from pytensor.tensor.type import TensorType
1818
from pytensor.tensor.type_other import NoneConst
@@ -259,9 +259,6 @@ def transpose(self, *axes):
259259

260260
@property
261261
def shape(self):
262-
if not any(s is None for s in self.type.shape):
263-
return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
264-
265262
return at.shape(self)
266263

267264
@property

tests/scan/test_rewriting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,13 +1477,12 @@ def test_while_scan_taps_and_map(self):
14771477
f(x0=0, seq=test_seq, n_steps=0)
14781478

14791479
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
1480-
# If a MissingInputError is raised, it means the rewrite failed
14811480
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
14821481
_, _, ys_trace, len_zs = scan_node.inputs
14831482
debug_fn = pytensor.function(
1484-
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
1483+
[x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
14851484
)
1486-
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
1485+
stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
14871486
assert stored_ys_steps == 2
14881487
assert stored_zs_steps == 1
14891488

tests/tensor/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3486,7 +3486,7 @@ def test_vector(self):
34863486
def test_scalar(self):
34873487
x = scalar()
34883488
y = np.array(7, dtype=config.floatX)
3489-
assert y.size == function([], x.size)()
3489+
assert y.size == function([x], x.size)(y)
34903490

34913491
def test_shared(self):
34923492
# NB: we also test higher order tensors at the same time.

tests/tensor/test_shape.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def test_fixed_shapes(self):
397397
shape = as_tensor_variable([2])
398398
y = specify_shape(x, shape)
399399
assert y.type.shape == (2,)
400-
assert y.shape.equals(shape)
401400

402401
def test_fixed_partial_shapes(self):
403402
x = TensorType("floatX", (None, None))("x")

tests/tensor/test_var.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytensor
88
import tests.unittest_tools as utt
9+
from pytensor.compile import DeepCopyOp
910
from pytensor.compile.mode import get_default_mode
1011
from pytensor.graph.basic import Constant, equal_computations
1112
from pytensor.tensor import get_vector_length
@@ -245,8 +246,13 @@ def test__getitem__newaxis(x, indices, new_order):
245246

246247
def test_fixed_shape_variable_basic():
247248
x = TensorVariable(TensorType("int64", shape=(4,)), None)
248-
assert isinstance(x.shape, Constant)
249-
assert np.array_equal(x.shape.data, (4,))
249+
assert x.type.shape == (4,)
250+
251+
shape_fn = pytensor.function([x], x.shape)
252+
opt_shape = shape_fn.maker.fgraph.outputs[0]
253+
assert isinstance(opt_shape.owner.op, DeepCopyOp)
254+
assert isinstance(opt_shape.owner.inputs[0], Constant)
255+
assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))
250256

251257
x = TensorConstant(
252258
TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])

0 commit comments

Comments
 (0)