Description
While working on PR #15137, I found that the FuseOps pass fuses operations in a non-intuitive way: sometimes ops are not fused together and remain individual PrimFuncs. On this branch I prepared a test script which reproduces this situation.
We have a base function which contains one operation:
fn (%x0: Tensor[(10, 20), float32]) {
multiply(%x0, 2f);
}
In the test we create 5 such operations and compute their sum (base_func0 + base_func1 + ... + base_func4). The resulting Relay function looks as follows:
fn (%x0: Tensor[(10, 20), float32], %x1: Tensor[(10, 20), float32], %x2: Tensor[(10, 20), float32], %x3: Tensor[(10, 20), float32], %x4: Tensor[(10, 20), float32]) {
%0 = multiply(%x0, 2f);
%1 = multiply(%x1, 2f);
%2 = add(%0, %1);
%3 = multiply(%x2, 2f);
%4 = add(%2, %3);
%5 = multiply(%x3, 2f);
%6 = add(%4, %5);
%7 = multiply(%x4, 2f);
add(%6, %7)
}
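For readers who want to vary the number of base functions, the listing above can be regenerated with a small text-only sketch. It does not use TVM; the helper name `build_sum_of_base_funcs` and its parameters are hypothetical and only mirror the shape of the printed function.

```python
def build_sum_of_base_funcs(num_base_funcs, shape=(10, 20), dtype="float32"):
    """Return Relay-like text for base_func0 + ... + base_func{N-1}.

    Hypothetical helper: it only formats text shaped like the listing above
    and does not construct a real TVM expression.
    """
    ty = f"Tensor[({', '.join(str(d) for d in shape)}), {dtype}]"
    params = ", ".join(f"%x{i}: {ty}" for i in range(num_base_funcs))

    lines = []
    next_id = 0   # next free %N identifier
    acc = None    # identifier holding the running sum
    for i in range(num_base_funcs):
        mul = f"%{next_id}"
        lines.append(f"{mul} = multiply(%x{i}, 2f);")
        next_id += 1
        if acc is None:
            acc = mul  # the first base function starts the sum
        else:
            new_acc = f"%{next_id}"
            lines.append(f"{new_acc} = add({acc}, {mul});")
            next_id += 1
            acc = new_acc

    # The last binding becomes the function's result expression.
    last = lines.pop()
    lines.append(last.split(" = ", 1)[1].rstrip(";"))
    return "fn (" + params + ") {\n" + "\n".join(lines) + "\n}"


print(build_sum_of_base_funcs(5))
```

With `num_base_funcs=5` this prints 5 multiply and 4 add operations in the same order as the function above.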
We want to limit the fusing depth so that each PrimFunc contains at most two base functions. In this case max_fused_ops = (base_function_ops + 1) * number_of_fused_base_func, where base_function_ops = 1 is the number of operations in the base function and number_of_fused_base_func = 2 is the maximum number of base functions in one PrimFunc.
We add one to base_function_ops because fusing N base functions into one function needs N-1 add operations to sum them, plus one more add operation to fold in the previous result. That is N add operations in total, i.e. one extra operation per base function.
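As a quick check of the arithmetic, the formula can be evaluated for the values used in this issue. The function name below is only illustrative and is not part of TVM.

```python
def compute_max_fused_ops(base_function_ops, number_of_fused_base_func):
    # Each base function contributes its own ops plus one add
    # (either to join a sibling or to fold in the previous result).
    return (base_function_ops + 1) * number_of_fused_base_func


max_fused_ops = compute_max_fused_ops(base_function_ops=1, number_of_fused_base_func=2)
print(max_fused_ops)  # 4

# For comparison, the whole test function has 5 multiplies and 4 adds.
total_ops = 5 * 1 + (5 - 1)
print(total_ops)  # 9
```

So with a limit of 4 operations per PrimFunc, the 9 operations of the test function cannot fit in fewer than 3 PrimFuncs.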
Expected behavior
After the fusing algorithm runs, I expected to see code that fuses several base functions into one PrimFunc. E.g. in the code below, 5 base functions were fused into 3 PrimFuncs. The first and the second PrimFuncs together contain 4 base functions (two each), and the last PrimFunc combines the result for those 4 base functions with the last one.
fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
%6 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p12: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%4 = multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
%5 = multiply(%p12, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
add(%4, %5) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%7 = %6(%x0, %x1) /* ty=Tensor[(10, 20), float32] */;
%8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%1 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
%2 = add(%p11, %1) /* ty=Tensor[(10, 20), float32] */;
%3 = multiply(%p2, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
add(%2, %3) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%9 = %8(%x01, %7, %x2) /* ty=Tensor[(10, 20), float32] */;
%10 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
add(%p1, %0) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%10(%x02, %9) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
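The grouping in the expected output above can be reproduced with a simplified model that walks the base functions in order and closes a group when the next one would exceed max_fused_ops. This is only a sketch of the expectation described in this issue, not the algorithm used by FuseOps; `expected_groups` is a hypothetical name.

```python
def expected_groups(num_base_funcs, base_function_ops, max_fused_ops):
    """Return a list of (base_funcs_in_group, ops_in_group) tuples.

    Simplified model of the expected result only: base functions are taken
    in order and a new group starts when the op budget would be exceeded.
    """
    groups = []
    funcs_in_group = 0
    ops_in_group = 0
    for index in range(num_base_funcs):
        # Every base function except the very first needs one add
        # to join the running sum.
        cost = base_function_ops + (1 if index > 0 else 0)
        if funcs_in_group > 0 and ops_in_group + cost > max_fused_ops:
            groups.append((funcs_in_group, ops_in_group))
            funcs_in_group, ops_in_group = 0, 0
        funcs_in_group += 1
        ops_in_group += cost
    if funcs_in_group:
        groups.append((funcs_in_group, ops_in_group))
    return groups


print(expected_groups(num_base_funcs=5, base_function_ops=1, max_fused_ops=4))
# [(2, 3), (2, 4), (1, 2)]
```

The three tuples correspond to the three PrimFuncs above: two base functions with 3 ops, two base functions with 4 ops, and one base function with 2 ops.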
Actual behavior
With the current FuseOps pass implementation, I see different code. There are 5 PrimFuncs, and each of them contains exactly one base function. I suppose this is incorrect behavior; please correct me if I'm wrong.
fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x3: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x4: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
%4 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%5 = fn (%p03: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
multiply(%p03, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%6 = %4(%x1) /* ty=Tensor[(10, 20), float32] */;
%7 = %5(%x2) /* ty=Tensor[(10, 20), float32] */;
%8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p21: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%2 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
%3 = add(%2, %p11) /* ty=Tensor[(10, 20), float32] */;
add(%3, %p21) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%9 = fn (%p04: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
multiply(%p04, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%10 = %8(%x0, %6, %7) /* ty=Tensor[(10, 20), float32] */;
%11 = %9(%x4) /* ty=Tensor[(10, 20), float32] */;
%12 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
%0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
%1 = add(%p1, %0) /* ty=Tensor[(10, 20), float32] */;
add(%1, %p2) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
%12(%x3, %10, %11) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
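One way to compare the expected and actual listings is to count how many multiply calls (i.e. base functions) sit inside each function marked Primitive=1. The sketch below works on the printed text only and assumes one statement per line, as in the listings in this issue; `multiplies_per_primfunc` is a hypothetical helper and the sample string is a shortened, hand-written stand-in for real output.

```python
def multiplies_per_primfunc(relay_text):
    """Count multiply calls inside each Primitive=1 function of printed Relay text.

    Assumes one statement per line and that a primitive function body
    ends at the first line starting with '}'.
    """
    counts = []
    current = None  # multiply count of the primitive function being scanned
    for raw_line in relay_text.splitlines():
        line = raw_line.strip()
        if "Primitive=1" in line:
            current = 0
        elif current is not None and line.startswith("}"):
            counts.append(current)
            current = None
        elif current is not None:
            current += line.count("multiply(")
    return counts


# Shortened, hand-written sample (not real FuseOps output).
sample = """fn (%x0, %x1) {
%2 = fn (%p0, %p1, Primitive=1) {
%0 = multiply(%p0, 2f);
%1 = multiply(%p1, 2f);
add(%0, %1)
};
%2(%x0, %x1)
}"""

print(multiplies_per_primfunc(sample))  # [2]
```

Applied to the listings in this issue, the expected output should give counts of 2, 2 and 1, while the actual output gives 1 for each of the 5 PrimFuncs.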
Environment
Linux, TVM mainline
Steps to reproduce
You can use the test from this commit: echuraev@88f2d4b
Triage
- needs-triage
- flow:relay