-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI] Add transpose_a/b & dynamic shape support for batch matmul #8527
Conversation
Also cc @wyc-ruiker , I remember you mentioned the requirement of this in #8402? Seems you need to add topi schedules yourself after merging this. 😆 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. Thanks.
f79d12b
to
76e2b46
Compare
492ee2f
to
8dac20a
Compare
295777b
to
0649ec8
Compare
@comaniac Thanks! I think I've addressed all of the comments. I have tested that the default topi schedule works well for all 4 input formats, at least won't cause any error. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…ache#8527) * Add basic support for batch matmul transpose * Update * Lint fix & add tf convert support * Update Lint fix * Bug fix for qnn.batch_matmul * Bug fix for tensorflow test * Add grad support for batch_matmul * Lint fix Re-triggle CI Bug fix Re-triggle CI Re-triggle CI Re-triggle CI
…ache#8527) * Add basic support for batch matmul transpose * Update * Lint fix & add tf convert support * Update Lint fix * Bug fix for qnn.batch_matmul * Bug fix for tensorflow test * Add grad support for batch_matmul * Lint fix Re-triggle CI Bug fix Re-triggle CI Re-triggle CI Re-triggle CI
In #8234, a new op
nn.matmul
was added to extendnn.dense
with transpose/non-transpose inputs support.This PR is going to also extend
nn.batch_matmul
to support inputs to be in transpose or non-transposed format.Since the original topi schedule is still for NT format, I set the default format of
nn.batch_matmul
to be NT. And for theqnn.batch_matmul
and some pass likeCombineParallelBatchMatmul
, I just left them to only support NT currently. Guys who are interested in these may continue to fix :)Things TODO in this PR:
cc @comaniac @altanh @tkonolige