Support for axis parameter in linalg.gemm #10864
Conversation
One check failed on one node and this is something completely unrelated :-(
Please rebase and try again.
Is this behavior for axis common for other frameworks?
Other frameworks apparently do not offer this additional flexibility. You would have to go through explicit transpose/swap-axis-like functions.
The build system is currently unstable and produces random failures :-(
Is the axis parameter similar to http://deeplearning.net/software/theano/library/tensor/basic.html#theano.tensor.tensordot ?
Could you make more clarifications in the docs?
@@ -53,13 +54,17 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
    DMLC_DECLARE_FIELD(beta)
      .set_default(1.0)
      .describe("Scalar factor multiplied with C.");
    DMLC_DECLARE_FIELD(axis)
      .set_default(-2)
      .describe("Axis corresponding to the matrix rows.");
This is a little confusing. Is it the rows of the resulting matrix?
I changed the operator description and added an example. Let me know if this clarifies things.
It's not similar to Theano's tensordot.
I see. I'm not familiar with the context, but I want to make sure we are not creating new conventions unless we have to.
Does this solve a different use case from theano's tensordot formulation?
It does solve a different use case. linalg.gemm deals with a batch of gemm operations. The extension in this PR relaxes the constraint that the batch coordinates must be the leading dimensions (by allowing the coordinate associated with the matrix rows to sit at a different axis). Batching and tensordot are two different concepts; that is why Theano also has a batched_tensordot() operator. While a single matrix-matrix product can be formulated either as a gemm or as a tensordot, a batch of gemms cannot be formulated as a single tensordot. Suppose we have a batch-gemm on shapes A = (I, M, K) and B = (I, K, N) where the first coordinate is the batch coordinate. Then batch-gemm returns shape (I, M, N), while tensordot(A, B, [[2], [1]]) returns shape (I, M, I, N), i.e. it is a completely different computation.
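To make the shape argument concrete, here is a small numpy sketch (not part of the PR; numpy's tensordot, which Theano's mirrors, places the remaining axes of A before those of B):

```python
# Sketch: a batched gemm keeps one batch axis, while a plain tensordot over
# the contraction axis produces the batch index twice.
import numpy as np

I, M, K, N = 4, 3, 5, 2
A = np.random.rand(I, M, K)
B = np.random.rand(I, K, N)

batched = np.matmul(A, B)                  # shape (I, M, N): one gemm per batch index
dot = np.tensordot(A, B, axes=([2], [1]))  # shape (I, M, I, N): all-pairs contraction

print(batched.shape, dot.shape)
# The batched result is only the "diagonal" of the tensordot result:
assert np.allclose(batched, dot[np.arange(I), :, np.arange(I), :])
```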
Ok. Thanks for the explanation.
Description
This PR adds an optional axis parameter to the linalg.gemm/linalg.gemm2 operators that specifies the axis indexing the matrix rows. The default is axis = -2, which is the previous behavior of these operators (matrices are encoded by the last two dimensions).
The rationale behind this PR is that some important use cases (for example, the attention mechanism in the transformer model for neural machine translation) naturally give rise to batched matrix-matrix multiplies where the matrix rows sit on a non-standard axis. Such computations can always be performed by an explicit swap-axis/transpose operator followed by a batch-dot, but this adds significant computational overhead. Since the underlying BLAS libraries can natively handle non-consecutive matrix representations without a performance impact, it is useful to leverage this and expose a higher-level switch so that the additional transpositions can be omitted. A minimal usage sketch follows the description below.
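A minimal sketch of the intended usage, assuming the NDArray entry points mx.nd.linalg.gemm2 and mx.nd.swapaxes and my reading of the axis semantics from the description above (rows at axis, columns at the last axis, remaining axes batched); it is not taken from the PR's own tests:

```python
# Hedged sketch: compare the new axis parameter with the explicit
# swapaxes + batched gemm path it is meant to make unnecessary.
import mxnet as mx

B1, M, B2, K, N = 2, 5, 3, 4, 6

# Matrix rows (M resp. K) sit at axis 1 instead of the default -2; matrix
# columns stay at the last axis; axes 0 and 2 act as batch axes.
A = mx.nd.random.uniform(shape=(B1, M, B2, K))
B = mx.nd.random.uniform(shape=(B1, K, B2, N))

# Proposed path: tell the operator where the row axis is.
C_axis = mx.nd.linalg.gemm2(A, B, axis=1)  # expected shape (B1, M, B2, N)

# Previous path: transpose into the default layout, multiply, transpose back.
C_swap = mx.nd.swapaxes(
    mx.nd.linalg.gemm2(mx.nd.swapaxes(A, dim1=1, dim2=2),
                       mx.nd.swapaxes(B, dim1=1, dim2=2)),
    dim1=1, dim2=2)

print(mx.nd.max(mx.nd.abs(C_axis - C_swap)))  # ~0 up to rounding
```

The point of the axis variant is not the shorter code but that the BLAS call can operate on the strided layout directly, avoiding the two extra transposition passes over the data.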
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.