
Missed scan rewrites #787

Open
@aseyboldt

Description

There are two issues with the code generated by this snippet:

import pytensor
import pytensor.tensor as pt


def update(x):
    return pt.exp(x) - 5

x_init = pt.vector("x_init", shape=(7,))
x_init_tangent = pt.vector("x_init_tangent", shape=(7,))
seq, updates = pytensor.scan(update, outputs_info=[x_init], n_steps=10)
outputs = seq[-1]
output_tangent = pytensor.Rop(outputs, x_init, eval_points=x_init_tangent)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function(
        [x_init, x_init_tangent],
        [outputs, output_tangent],
        mode=pytensor.compile.mode.get_mode("FAST_RUN"),
    )

pytensor.dprint(func, print_type=True, print_destroy_map=True)
Subtensor{i} [id A] <Vector(float64, shape=(7,))> 13
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.0 [id B] <Matrix(float64, shape=(?, 7))> 12
 │  ├─ 10 [id C] <Scalar(int8, shape=())>
 │  ├─ SetSubtensor{:stop} [id D] <Matrix(float64, shape=(2, 7))> 11
 │  │  ├─ AllocEmpty{dtype='float64'} [id E] <Matrix(float64, shape=(2, 7))> 10
 │  │  │  ├─ 2 [id F] <Scalar(int64, shape=())>
 │  │  │  └─ 7 [id G] <Scalar(int64, shape=())>
 │  │  ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
 │  │  │  ├─ Unbroadcast{0} [id I] <Matrix(float64, shape=(?, 7))> 6
 │  │  │  │  └─ ExpandDims{axis=0} [id J] <Matrix(float64, shape=(1, 7))> 5
 │  │  │  │     └─ x_init [id K] <Vector(float64, shape=(7,))>
 │  │  │  ├─ 1 [id L] <Scalar(int8, shape=())>
 │  │  │  └─ 7 [id M] <Scalar(int8, shape=())>
 │  │  └─ 1 [id N] <int64>
 │  ├─ SetSubtensor{:stop} [id O] <Matrix(float64, shape=(1, 7))> 9
 │  │  ├─ AllocEmpty{dtype='float64'} [id P] <Matrix(float64, shape=(1, 7))> 8
 │  │  │  ├─ 1 [id Q] <Scalar(int64, shape=())>
 │  │  │  └─ 7 [id G] <Scalar(int64, shape=())>
 │  │  ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
 │  │  │  └─ ···
 │  │  └─ 1 [id N] <int64>
 │  └─ SetSubtensor{:stop} [id R] <Matrix(float64, shape=(2, 7))> 4
 │     ├─ AllocEmpty{dtype='float64'} [id S] <Matrix(float64, shape=(2, 7))> 3
 │     │  ├─ 2 [id T] <Scalar(int64, shape=())>
 │     │  └─ 7 [id G] <Scalar(int64, shape=())>
 │     ├─ SpecifyShape [id U] <Matrix(float64, shape=(1, 7))> 2
 │     │  ├─ Unbroadcast{0} [id V] <Matrix(float64, shape=(?, 7))> 1
 │     │  │  └─ ExpandDims{axis=0} [id W] <Matrix(float64, shape=(1, 7))> 0
 │     │  │     └─ x_init_tangent [id X] <Vector(float64, shape=(7,))>
 │     │  ├─ 1 [id L] <Scalar(int8, shape=())>
 │     │  └─ 7 [id M] <Scalar(int8, shape=())>
 │     └─ 1 [id N] <int64>
 └─ 1 [id Y] <uint8>
Subtensor{i} [id Z] <Vector(float64, shape=(7,))> 14
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.2 [id B] <Matrix(float64, shape=(?, 7))> 12
 │  └─ ···
 └─ 1 [id Y] <uint8>

Inner graphs:

Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all} [id B]
 ← Composite{(exp(i0) - 5.0)} [id BA] <Vector(float64, shape=(7,))>
    └─ *0-<Vector(float64, shape=(7,))> [id BB] <Vector(float64, shape=(7,))> -> [id D]
 ← Composite{...}.0 [id BC] <Vector(float64, shape=(7,))>
    ├─ *1-<Vector(float64, shape=(7,))> [id BD] <Vector(float64, shape=(7,))> -> [id O]
    └─ *2-<Vector(float64, shape=(7,))> [id BE] <Vector(float64, shape=(7,))> -> [id R]
 ← Composite{...}.1 [id BC] <Vector(float64, shape=(7,))>
    └─ ···

Composite{(exp(i0) - 5.0)} [id BA]
 ← sub [id BF] <float64> 'o0'
    ├─ exp [id BG] <float64>
    │  └─ i0 [id BH] <float64>
    └─ 5.0 [id BI] <float64>

Composite{...} [id BC]
 ← sub [id BJ] <float64> 'o0'
    ├─ exp [id BK] <float64> 't3'
    │  └─ i0 [id BL] <float64>
    └─ 5.0 [id BM] <float64>
 ← mul [id BN] <float64> 'o1'
    ├─ exp [id BK] <float64> 't3'
    │  └─ ···
    └─ i1 [id BO] <float64>
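The fused `Composite{...} [id BC]` computes the primal step as `o0 = exp(i0) - 5.0` and the tangent as `o1 = exp(i0) * i1`, sharing `exp(i0)` through the `'t3'` node, because the derivative of `exp(x) - 5` is `exp(x)`. A minimal NumPy reference for the two outputs of `func` can help when checking numbers; `step_with_tangent` and `reference` are hypothetical helper names for illustration, not PyTensor API:

```python
import numpy as np


def step_with_tangent(x, t):
    # Mirrors Composite{...} [id BC]: exp(x) is computed once and reused
    # for the primal output (o0) and the tangent output (o1).
    ex = np.exp(x)
    return ex - 5.0, ex * t


def reference(x_init, x_init_tangent, n_steps=10):
    # Forward-mode (JVP) reference for the recurrence x_{k+1} = exp(x_k) - 5.
    # d/dx (exp(x) - 5) = exp(x), so the tangent obeys t_{k+1} = exp(x_k) * t_k.
    x = np.asarray(x_init, dtype=np.float64)
    t = np.asarray(x_init_tangent, dtype=np.float64)
    for _ in range(n_steps):
        x, t = step_with_tangent(x, t)
    # Value and tangent after n_steps, i.e. seq[-1] and its Rop.
    return x, t
```

For example, `reference(np.linspace(-1, 1, 7), np.ones(7))` should agree with `func` on the same inputs up to floating-point rounding. Starting values such as 2.0 overflow to inf within a few steps, so moderate inputs make a better check.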
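  • Intermediate arrays have shape (2, 7) instead of (1, 7), even though only the last step (seq[-1]) is used (this also happens without the Rop).
  • We compute exp(x) - 5 twice per loop iteration: once in Composite{(exp(i0) - 5.0)} [id BA] and again in Composite{...} [id BC].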

cc @ricardoV94
