When compiling a vectorized graph that contains an `OpFromGraph` in NUMBA mode, shape inference fails if the shape of the output depends on Blockwise operations. Example:
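```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.replace import vectorize_graph

# Core (unbatched) input and a batched version of it
X = pt.dmatrix("X", shape=(None, None))
X_batched = pt.tensor("X", shape=(None, None, None))

# Wrap the Q output of a QR decomposition in an OpFromGraph
Q, R = pt.linalg.qr(X)
results = OpFromGraph(
    inputs=[X],
    outputs=[Q],
)(X)

# Vectorizing turns the OpFromGraph into a Blockwise over the batch dimension
z_vec = vectorize_graph(results, {X: X_batched})
fn = pytensor.function(
    [X_batched],
    [z_vec],
    mode="NUMBA",
)
```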
This falls back to object mode because the check in `introduce_explicit_core_shape_blockwise` rejects any Blockwise in the shape graph. It should only be looking for loops: specifically, whether the Blockwise Op being rewritten is itself present in the shape graph.
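To make the distinction concrete, here is a toy sketch in plain Python (no PyTensor needed). The `Node` class and both check functions are hypothetical stand-ins, not PyTensor's API or the actual rewrite code: `strict_check_allows` mirrors the behaviour described above (reject if any Blockwise appears in the shape graph), while `loop_check_allows` only rejects when the node being rewritten is an ancestor of its own shape graph.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing/equality
class Node:
    """Hypothetical stand-in for an Apply node (toy model, not PyTensor)."""
    op: str
    inputs: list[Node] = field(default_factory=list)
    is_blockwise: bool = False


def ancestors(root: Node) -> set[Node]:
    """Return every node reachable from `root` through inputs, including root."""
    seen: set[Node] = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(node.inputs)
    return seen


def strict_check_allows(shape_graph: Node) -> bool:
    # Strict behaviour: bail out if ANY Blockwise appears in the shape graph
    return not any(n.is_blockwise for n in ancestors(shape_graph))


def loop_check_allows(shape_graph: Node, rewritten: Node) -> bool:
    # Proposed behaviour: bail out only if the node being rewritten is itself
    # part of its shape graph, i.e. the rewrite would introduce a cycle
    return rewritten not in ancestors(shape_graph)


x = Node("X")
other = Node("Blockwise{Other}", [x], is_blockwise=True)
target = Node("Blockwise{Target}", [x], is_blockwise=True)

shape_via_other = Node("Shape", [other])  # depends on a different Blockwise
shape_via_self = Node("Shape", [target])  # depends on the node being rewritten

print(strict_check_allows(shape_via_other))         # False
print(loop_check_allows(shape_via_other, target))   # True
print(loop_check_allows(shape_via_self, target))    # False
```

Under the strict check, a shape graph that merely passes through some other Blockwise is rejected; the loop check only rejects the genuinely circular case.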
For the graph above, the resulting shape graph is:
```
Shape [id A] <Vector(int64, shape=(3,))>
 └─ Blockwise{OpFromGraph{inline=False}, (i00,i01)->(o00,o01)} [id B] <Tensor3(float64, shape=(?, ?, ?))>
    └─ X [id C] <Tensor3(float64, shape=(?, ?, ?))>
```
This might also be something unique to the QR Op. If so, I'll edit the issue, but I still think the check in this rewrite is too strict.