[Topi] Allow batch_matmul to broadcast along batch dimension. #6616
Conversation
@mbrookhart, @csullivan, @rkimball can you guys take a look at this PR?
Thanks @jwfromm !
One comment regarding memory use: when a vendor library such as rocBLAS is used for batch_matmul and the primitive doesn't support implicit broadcast, we will still see excessive memory use from the folded constants. Can you think of a clean solution for that case? My only idea at the moment is to disable constant folding in that case, but that kind of coupling between optimization passes and supported codegen/runtime primitives isn't great.
I need to make a few fixes after the merge with the dynamic shapes PR, which involved several changes to batch_matmul.
```diff
@@ -3628,7 +3628,6 @@ def verify_roi_align(
     test_clip_min_max_as_inputs()
     test_onehot()
     test_matmul()
-    test_batch_matmul()
```
Just wanted to note that I removed this since `test_batch_matmul` is now run with `tvm.testing.parametrize`, which means it will cause an error when run using python instead of pytest.
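For context, a minimal sketch of the failure mode described above (the decorator and argument names here are illustrative, not the exact ones used in the TVM test suite):

```python
import pytest

# Under pytest, `target` is injected by the parametrize decorator,
# so the test runs once per value in the list.
@pytest.mark.parametrize("target", ["llvm", "cuda"])
def test_batch_matmul(target):
    assert target in ("llvm", "cuda")

if __name__ == "__main__":
    # Calling the parametrized test directly from a plain `python` run
    # fails with "TypeError: test_batch_matmul() missing 1 required
    # positional argument: 'target'", which is why the direct call was
    # removed from the __main__ block.
    test_batch_matmul()
```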
LGTM
@masahi can you take a quick look at this PR?
Thanks @jwfromm @mbrookhart @csullivan @rkimball
…#6616)

* Allow batch_matmul to broadcast along batch dimension.
* Added typerel checking.
* Fix style issue and respond to feedback.
* Fix style.
* More formatting issues :(
* Fix issues after merge.
* Comment update.
* Small tweak.
We found that requiring explicit broadcasting along the batch dimension for `batch_matmul` could cause serious memory issues during constant folding, since it would effectively multiply the size of the weights by the input batch size. This PR allows implicit broadcasting along the batch dimension for `batch_matmul` without increasing compute or memory requirements. This should in fact give pretty significant speedups in cases where we previously applied explicit broadcasting. I also noticed that we had an unused C++ definition of `batch_matmul` and removed it to prevent confusion.
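As a rough illustration of the behavior this enables (a minimal sketch using the Relay Python API, assuming a TVM build that includes this change; the shapes are made up for the example):

```python
import tvm
from tvm import relay

# The input has batch size 4; the weight is a single matrix with batch dim 1.
# batch_matmul contracts the last axis of both operands, so shapes
# (4, 16, 32) x (1, 8, 32) broadcast along the batch dimension to an
# output of (4, 16, 8) without materializing 4 copies of the weight.
x = relay.var("x", shape=(4, 16, 32), dtype="float32")
w = relay.var("w", shape=(1, 8, 32), dtype="float32")
y = relay.nn.batch_matmul(x, w)

mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
mod = relay.transform.InferType()(mod)
print(mod)  # the inferred output type should be Tensor[(4, 16, 8), float32]
```

Before this change, the weight would have had to be explicitly tiled to batch size 4 first, which is exactly the constant-folding blowup described above.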