-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Layout] Unify dense op input layout #8921
Conversation
@masahi Thanks for fixing this issue! Just one random thought out of this issue that may further improve TVM: is the layout annotation imposing too many unnecessary constraints (i.e., asserting semantics of each dimension)? Will it be better to not annotate layout itself, but instead annotate how the layout got transformed? As in this case, previously TVM assumes However, I feel like what |
@lazycal That is an interesting suggestion. Layout annotation in TVM was introduced a long time ago, and I believe this is the simplest solution that works in most cases. Indeed, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @masahi. The fix makes perfect sense to me. |
Fixes the issue reported in https://discuss.tvm.apache.org/t/pytorch-layout-cannot-convert-f-linear-x-f-linear-y-z/10866/
Currently,
dense
op's input layouts areNC
andNK
for data and weight respectively. This causes an issue when the output ofdense
is used as the weight for anotherdense
. We get "Incompatible layouts: NC vs NK" error.There is no need to distinguish the second dim of
data
andweight
, since both must be the same value. So I replacedNK
withNC
.@comaniac @yzhliu