
[Topi] Allow batch_matmul to broadcast along batch dimension. #6616

Merged (9 commits), Oct 6, 2020

Conversation

jwfromm
Contributor

@jwfromm jwfromm commented Oct 2, 2020

We found that requiring explicit broadcasting along the batch dimension for batch_matmul could cause serious memory issues during constant folding, since it would effectively multiply the size of weights by the input batch size. This PR allows implicit broadcasting along the batch dimension for batch_matmul without increasing compute or memory requirements. This should in fact give pretty significant speedups in cases where we previously applied explicit broadcasting. I also noticed that we had an unused C++ definition of batch_matmul and removed it to prevent confusion.
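To illustrate the semantics described above, here is a minimal NumPy sketch (a hypothetical reference implementation, not TVM's actual compute definition) of a batch_matmul whose second operand may have batch dimension 1 and is broadcast implicitly, without ever materializing a tiled copy of the weights. Following topi's convention, the second operand is treated as transposed, so the contraction runs over the last axis of both inputs:

```python
import numpy as np

def batch_matmul_bcast(x, y):
    """Batched matmul where y's batch dim may be 1 and is broadcast.

    x: (B, M, K), y: (B or 1, N, K). The second operand is treated as
    transposed (topi convention), so we contract over the last axis.
    Illustrative sketch only.
    """
    B, M, K = x.shape
    By, N, Ky = y.shape
    assert Ky == K and By in (B, 1), "incompatible shapes"
    out = np.empty((B, M, N), dtype=x.dtype)
    for b in range(B):
        # Index y with b when batch dims match, else reuse the single
        # slice -- no tiled (B, N, K) weight tensor is ever allocated.
        out[b] = x[b] @ y[b if By == B else 0].T
    return out
```

Because only the indexing changes, the memory footprint of the weights stays at its original size regardless of the input batch size, which is exactly what constant folding previously destroyed by tiling.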

@jwfromm
Contributor Author

jwfromm commented Oct 2, 2020

@mbrookhart, @csullivan, @rkimball can you guys take a look at this PR?

Review threads:
python/tvm/topi/x86/batch_matmul.py (resolved)
src/relay/op/nn/nn.cc (outdated, resolved)
tests/python/topi/python/test_topi_dense.py (outdated, resolved)
Contributor

@csullivan csullivan left a comment


Thanks @jwfromm !

One comment regarding the memory use. To get the best of both worlds when a vendor library is used for batch_matmul, e.g. rocBLAS: if the primitive doesn't support implicit broadcast, we will still see excessive memory use from the folded constants. Can you think of a clean solution for that case? My only idea at the moment is to disable constant folding in that case, but that coupling between optimization passes and supported codegen/runtime primitives isn't great.
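One possible shape of the workaround discussed here (a hedged sketch, not anything from this PR) is to defer the explicit broadcast to call time at the boundary of the vendor primitive, instead of letting constant folding bake a tiled copy of the weights into the module:

```python
import numpy as np

def batch_matmul_for_strict_backend(x, y):
    """Hypothetical adapter for a backend (e.g. a batched-GEMM library)
    that requires both operands to have equal batch dimensions.

    x: (B, M, K), y: (B or 1, N, K), y treated as transposed.
    """
    if y.shape[0] == 1 and x.shape[0] > 1:
        # Expand at call time. broadcast_to returns a zero-copy view;
        # a real BLAS call would still need a contiguous copy here, but
        # it is a transient buffer, not a permanently folded constant.
        y = np.broadcast_to(y, (x.shape[0],) + y.shape[1:])
    return np.matmul(x, np.swapaxes(y, 1, 2))
```

The trade-off is that the tiled buffer exists only for the duration of the call, rather than inflating the serialized parameters, which sidesteps the constant-folding blow-up without coupling the optimization pass to backend capabilities.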

tests/python/topi/python/test_topi_dense.py (outdated, resolved)
@tqchen
Member

tqchen commented Oct 3, 2020

cc @yzhliu @icemelon9 @masahi

@jwfromm
Contributor Author

jwfromm commented Oct 4, 2020

I need to make a few fixes after the merge with the dynamic shapes PR, which involved several changes to batch_matmul.

@@ -3628,7 +3628,6 @@ def verify_roi_align(
     test_clip_min_max_as_inputs()
     test_onehot()
     test_matmul()
-    test_batch_matmul()
Contributor Author

Just wanted to note that I removed this since test_batch_matmul is now run with tvm.testing.parametrize, which means it would raise an error when invoked directly with python instead of through pytest.

Contributor

@mbrookhart mbrookhart left a comment

LGTM

@jwfromm
Contributor Author

jwfromm commented Oct 6, 2020

@masahi can you take a quick look at this PR?

@masahi masahi merged commit 889fac1 into apache:master Oct 6, 2020
@masahi
Member

masahi commented Oct 6, 2020

Thanks @jwfromm @mbrookhart @csullivan @rkimball

TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 13, 2020
…#6616)

* Allow batch_matmul to broadcast along batch dimension.

* Added typerel checking.

* Fix style issue and respond to feedback.

* Fix style.

* More formatting issues :(

* Fix issues after merge.

* Comment update.

* Small tweak.
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 14, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 15, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 15, 2020
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Oct 16, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Oct 19, 2020
@jwfromm jwfromm deleted the broadcast_matmul branch April 12, 2023 15:54