Aggressive elementwise fusion #39

Closed · Tracked by #29
jafioti opened this issue Mar 6, 2024 · 5 comments

Comments

jafioti (Owner) commented Mar 6, 2024

Currently the elementwise fusion is very conservative in what it fuses. It can be a lot more aggressive by:

  • Fusing constants into kernels
  • Fusing across shape changes and contiguous ops (by stacking views inside the kernel)
  • Handling intermediate elementwise outputs that get used multiple times by downstream computation (see the sketch after this list):
b = exp(a); // Intermediate
c = b + sin(b);
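
A minimal sketch of that last case (my own illustration, not luminal's actual codegen): the intermediate is computed once per element and held in a register inside a single fused loop, instead of being materialized to memory between two kernels.

fn fused_intermediate_reuse(a: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        // b is evaluated once per element and kept in a register,
        // then reused by both consumers inside the same kernel.
        let b = a[i].exp();
        out[i] = b + b.sin();
    }
}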

It should be possible to fuse this test down to a single kernel:

#[test]
fn test_fusion() {
    let mut cx = Graph::new();
    let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
    let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
    let c = cx.constant(2.123);
    let d = cx.named_tensor::<R1<10>>("d").set(random_vec(10)).keep();
    let mut out = ((a.exp2() - b.sin()).relu() + c.expand() / d).retrieve();

    cx.execute();
    let unopt_out = out.data();
    out.drop();

    cx.compile(<(GenericCompiler, MetalCompiler<f16>)>::default(), &mut out);
    cx.execute();

    assert_close(&out.data(), &unopt_out);
}
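
For reference, a rough sketch (my own illustration, not generated output) of the single elementwise pass such a fused kernel boils down to, with the constant 2.123 baked directly into the expression:

fn fused_test_kernel(a: &[f32], b: &[f32], d: &[f32], out: &mut [f32]) {
    // Assumes all slices have length 10, matching the R1<10> tensors above.
    for i in 0..out.len() {
        let x = (a[i].exp2() - b[i].sin()).max(0.0); // relu
        out[i] = x + 2.123 / d[i]; // constant fused in, no intermediate buffers
    }
}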
jafioti (Owner) commented Mar 11, 2024

As of 3e956f91e77c0e134c51553f1528a03e1acffa02, we can now fuse across arbitrary reshapes / contiguous ops. The other half of this is still to come: supporting common subexpressions internal to the kernel.
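
A toy illustration of what "stacking views inside the kernel" means here (assumptions mine, not the actual generated code): rather than materializing a contiguous copy after a permute, the view's index mapping is composed into the load, so the movement op and the elementwise op run in one pass.

fn fused_across_permute(a: &[f32], rows: usize, cols: usize, out: &mut [f32]) {
    // Logical op: out = exp2(permute(a)), where a is (rows, cols) row-major
    // and out is (cols, rows). The permute never touches memory on its own;
    // it is folded into the index expression of the load.
    for j in 0..cols {
        for i in 0..rows {
            out[j * rows + i] = a[i * cols + j].exp2();
        }
    }
}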

jafioti (Owner) commented Mar 12, 2024

As of 780951c8289693095eb56b4653d4b9353cf0b083, the fusion is good enough to remove the custom rope kernel! There's a slight performance disadvantage (17.8 tps custom vs 17.0 tps automatic), mainly due to not yet fusing in the subexpressions, but this approach demonstrates powerful kernel generation.

jafioti (Owner) commented Apr 4, 2024

What happens if we do tensor.slice(..1).cos().pad([0, 1])? The sliced-out indexes get passed through cos as 0, and since cos(0) = 1, we get 1s in the sliced-out positions, which are then padded back in.

Instead, we need to insert indexing and validity expressions between each and every component of the equation, not just stack them at the beginning.
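
A minimal sketch of the difference, using plain Rust in place of the generated kernel (the length-2 shape and the function names are my own illustration):

fn single_valid_check_at_start(a: &[f32]) -> Vec<f32> {
    // Valid check only on the initial load: the sliced-out index reads 0.0,
    // then cos(0.0) = 1.0 leaks into the padded position.
    (0..2)
        .map(|i| {
            let x = if i < 1 { a[i] } else { 0.0 }; // slice(..1)
            x.cos()
        })
        .collect()
}

fn valid_check_per_subexpression(a: &[f32]) -> Vec<f32> {
    // Valid check wrapped around the whole subexpression: the padded
    // position stays 0.0 instead of becoming cos(0.0) = 1.0.
    (0..2)
        .map(|i| if i < 1 { a[i].cos() } else { 0.0 })
        .collect()
}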

jafioti (Owner) commented Apr 7, 2024

As of 589148707f5c53f479703b709d663257a21f3760, this is no longer a problem: I have rewritten the entire fusion to use subexpressions and do validity checks properly. Each subexpression now gets its own valid check, rather than a single check at the beginning. As a side benefit, the kernels are much easier to read!

The only remaining issue is that Mistral still outputs gibberish and runs slowly for some reason when fusion is on.

jafioti (Owner) commented Apr 9, 2024

OK, Mistral is finally fixed. Current fusion is slower than before (~15 tps vs 17 tps), but unlike before it is correct. The slowdown is almost certainly due to the huge index and valid expressions, which will be addressed in #47.

jafioti closed this as completed Apr 9, 2024