[Bug][FuseOps] Ops don't fuse together and remain individual PrimFuncs #15358

@echuraev

Description

While working on PR #15137, I found that the FuseOps pass fuses operations in a non-intuitive way: sometimes ops don't fuse together and remain individual PrimFuncs. On this branch I prepared a test script that reproduces this situation.

We have a base function which contains one operation:

fn (%x0: Tensor[(10, 20), float32]) {
  multiply(%x0, 2f)
}

In the test, we create 5 such base functions and compute their sum (base_func0 + base_func1 + ... + base_func4). The resulting Relay function looks as follows:

fn (%x0: Tensor[(10, 20), float32], %x1: Tensor[(10, 20), float32], %x2: Tensor[(10, 20), float32], %x3: Tensor[(10, 20), float32], %x4: Tensor[(10, 20), float32]) {
  %0 = multiply(%x0, 2f);
  %1 = multiply(%x1, 2f);
  %2 = add(%0, %1);
  %3 = multiply(%x2, 2f);
  %4 = add(%2, %3);
  %5 = multiply(%x3, 2f);
  %6 = add(%4, %5);
  %7 = multiply(%x4, 2f);
  add(%6, %7)
}
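The chain has a regular structure. Each base function contributes one multiply, and every base function after the first is joined to the running sum by one add. The minimal plain-Python sketch below makes that structure explicit. The helper `sum_chain_relay_text` is hypothetical. It only formats text in the style of the Relay printer and does not use the TVM API.

```python
def sum_chain_relay_text(n: int) -> str:
    """Return Relay-style text for base_func0 + ... + base_func{n-1}.

    Hypothetical helper: pure string formatting, no TVM involved.
    """
    if n < 1:
        raise ValueError("need at least one base function")
    params = ", ".join(f"%x{i}: Tensor[(10, 20), float32]" for i in range(n))
    body = []   # (ssa_name, expression) pairs in evaluation order
    acc = None  # SSA name that holds the running sum so far
    for i in range(n):
        prod = f"%{len(body)}"
        body.append((prod, f"multiply(%x{i}, 2f)"))  # the base function itself
        if acc is None:
            acc = prod  # the first product starts the sum
        else:
            new_acc = f"%{len(body)}"
            body.append((new_acc, f"add({acc}, {prod})"))  # join to running sum
            acc = new_acc
    lines = [f"fn ({params}) {{"]
    for name, expr in body[:-1]:
        lines.append(f"  {name} = {expr};")
    lines.append(f"  {body[-1][1]}")  # last expression is the return value
    lines.append("}")
    return "\n".join(lines)


print(sum_chain_relay_text(5))
```

For `n = 5` this prints the function above. For `n = 1` it prints the base function.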

We want to limit the fusion depth so that each PrimFunc contains at most two base functions. In this case, max_fused_ops = (base_function_ops + 1) * number_of_fused_base_func = (1 + 1) * 2 = 4. Here, base_function_ops = 1 is the number of operations in a base function, and number_of_fused_base_func = 2 is the maximum number of base functions in one PrimFunc.

We add one to base_function_ops because fusing N base functions into one PrimFunc requires N-1 add operations to combine them, plus one more add to accumulate the result of the previous PrimFunc. That is N extra add operations in total, or one per base function.
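The arithmetic above can be written out as a small Python sketch. The function names are hypothetical and chosen to mirror the variables in the formula. `ops_in_group` counts the ops in a PrimFunc that holds `n` base functions, with and without the running sum from a previous PrimFunc as an input.

```python
def max_fused_ops(base_function_ops: int, number_of_fused_base_func: int) -> int:
    # Each base function brings its own ops plus one add that joins it to the sum.
    return (base_function_ops + 1) * number_of_fused_base_func


def ops_in_group(base_function_ops: int, n: int, has_previous_result: bool) -> int:
    # n base functions, n - 1 adds between them, and one more add when the
    # running sum produced by the previous PrimFunc is an input.
    return base_function_ops * n + (n - 1) + (1 if has_previous_result else 0)


print(max_fused_ops(1, 2))           # 4
print(ops_in_group(1, 2, True))      # 4: two multiplies + one add + one add for the previous sum
print(ops_in_group(1, 2, False))     # 3: the first group has no previous sum to add
```

The budget of 4 is exactly enough for a "middle" PrimFunc that fuses two base functions and also consumes the previous result.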

Expected behavior

After fusion, I expected several base functions to be fused into each PrimFunc. For example, in the code below, the 5 base functions are fused into 3 PrimFuncs. The first and second PrimFuncs contain two base functions each (four in total). The last PrimFunc combines the result of those four base functions with the fifth one.

fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %6 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p12: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %4 = multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %5 = multiply(%p12, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%4, %5) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %7 = %6(%x0, %x1) /* ty=Tensor[(10, 20), float32] */;
  %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %1 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %2 = add(%p11, %1) /* ty=Tensor[(10, 20), float32] */;
    %3 = multiply(%p2, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%2, %3) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %9 = %8(%x01, %7, %x2) /* ty=Tensor[(10, 20), float32] */;
  %10 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%p1, %0) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %10(%x02, %9) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
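The expected result can be modelled with a plain-Python sketch. It assumes base functions are packed greedily, in chain order, into groups of at most `per_group`. This is a model of the expected output, not the FuseOps algorithm, and `expected_grouping` is a hypothetical helper. For each group it reports how many base functions it holds and how many ops the PrimFunc would contain.

```python
def expected_grouping(num_base_funcs: int, base_function_ops: int, per_group: int):
    """Return (base_funcs, ops) for each PrimFunc when the sum chain is split
    into consecutive groups of at most `per_group` base functions."""
    groups = []
    remaining = num_base_funcs
    first = True
    while remaining > 0:
        n = min(per_group, remaining)
        # n base functions, n - 1 adds between them, plus one add for the
        # running sum coming from the previous PrimFunc (absent for the first).
        ops = base_function_ops * n + (n - 1) + (0 if first else 1)
        groups.append((n, ops))
        remaining -= n
        first = False
    return groups


print(expected_grouping(5, 1, 2))  # [(2, 3), (2, 4), (1, 2)]
```

The op counts 3, 4 and 2 match the three PrimFuncs above (%6, %8 and %10), and none exceeds max_fused_ops = 4.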

Actual behavior

With the current FuseOps implementation, however, the result is different. There are 5 PrimFuncs, and each of them contains only one base function. I suppose this is incorrect behavior; please correct me if I'm wrong.

fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x3: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x4: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %4 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %5 = fn (%p03: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p03, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %6 = %4(%x1) /* ty=Tensor[(10, 20), float32] */;
  %7 = %5(%x2) /* ty=Tensor[(10, 20), float32] */;
  %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p21: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %2 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %3 = add(%2, %p11) /* ty=Tensor[(10, 20), float32] */;
    add(%3, %p21) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %9 = fn (%p04: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p04, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %10 = %8(%x0, %6, %7) /* ty=Tensor[(10, 20), float32] */;
  %11 = %9(%x4) /* ty=Tensor[(10, 20), float32] */;
  %12 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %1 = add(%p1, %0) /* ty=Tensor[(10, 20), float32] */;
    add(%1, %p2) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %12(%x3, %10, %11) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
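The small Python tally below makes the comparison concrete. It lists the op sequences of the five PrimFuncs above, transcribed by hand from the printed IR. This is a sketch over that transcription, not a TVM API call, and `summarize` is a hypothetical helper. It shows that each PrimFunc holds exactly one base function (one multiply) and that none of them reaches the four-op budget computed from the formula.

```python
# Op sequences of the primitive functions in the actual output,
# transcribed by hand from the printed IR (%4, %5, %8, %9, %12).
ACTUAL_PRIMFUNCS = {
    "%4": ["multiply"],
    "%5": ["multiply"],
    "%8": ["multiply", "add", "add"],
    "%9": ["multiply"],
    "%12": ["multiply", "add", "add"],
}

MAX_FUSED_OPS = 4  # (base_function_ops + 1) * number_of_fused_base_func = (1 + 1) * 2


def summarize(primfuncs, max_ops):
    """Return (name, op_count, base_funcs, within_budget) per PrimFunc."""
    rows = []
    for name, ops in primfuncs.items():
        base_funcs = ops.count("multiply")  # each base function is one multiply
        rows.append((name, len(ops), base_funcs, len(ops) <= max_ops))
    return rows


for row in summarize(ACTUAL_PRIMFUNCS, MAX_FUSED_OPS):
    print(row)
```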

Environment

Linux, TVM mainline

Steps to reproduce

You can use the test from this commit: echuraev@88f2d4b

Triage

  • needs-triage
  • flow:relay