Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support for axis parameter in linalg.gemm #10864

Merged
merged 1 commit into from
May 29, 2018

Conversation

asmushetzel
Copy link
Contributor

Description

This PR adds an optional axis parameter to the linalg.gemm/linalg.gemm2 operators that specifies the axis that indexes the matrix rows. Default is axis = -2 which is the behavior so far for this operators (matrices are encoded by the last two dimensions).
The rationale behind this PR is that in some important uses cases (example is the attention mechanism in the transformer model for neural machine translation) situations requiring a batched matrix-matrix multiply with a non-standard axis for the matrix rows arise naturally. Such computations can be always be performed by an explicit swap-axis/transpose operator followed by a batch-dot, but this adds significant computation overhead. As the underlying blas-libraries are able to deal natively with non-consecutive matrix representations w/out performance impact, it is useful to leverage this and expose a higher level switch such that such additional transpositions can be omitted.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • [ X] Changes are complete (i.e. I finished coding on this PR)
  • [ X] All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • [X ] Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • [ X] To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@asmushetzel
Copy link
Contributor Author

One check failed on one node and this is something completely unrelated :-(

@piiswrong
Copy link
Contributor

please rebase and try again

@piiswrong
Copy link
Contributor

This this behavior for axis common for other frameworks?

@asmushetzel asmushetzel force-pushed the gemm_axis branch 2 times, most recently from b566c4a to 529389e Compare May 11, 2018 12:21
@asmushetzel
Copy link
Contributor Author

Other frameworks apparently do not offer this additional flexibility. You would have to go through explicit transpose/swap-axis like functions.

@asmushetzel
Copy link
Contributor Author

The build system is currently unstable and produces random failures :-(

Copy link
Contributor

@piiswrong piiswrong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the axis parameter similar to http://deeplearning.net/software/theano/library/tensor/basic.html#theano.tensor.tensordot ?

Could you make more clarifications in the docs?

@@ -53,13 +54,17 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
DMLC_DECLARE_FIELD(beta)
.set_default(1.0)
.describe("Scalar factor multiplied with C.");
DMLC_DECLARE_FIELD(axis)
.set_default(-2)
.describe("Axis corresponding to the matrix rows.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little confusing. Is it the rows of the resulting matrix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the operator description and added an example. Let me know if this clarifies things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not similar to the theano's tensordot

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I'm not familiar with the context, but I want to make sure we are not creating new conventions unless we have to.
Does this solve a different use case from theano's tensordot formulation?

@asmushetzel
Copy link
Contributor Author

It does solve a different use case. linalg.gemm deals with a batch of gemm operations. The extension in this PR relaxes the constraint to have only the leading dimension used as a batch coordinate (by allowing that the coordinate associated with the matrix rows is at a different axis).

Batching and tensordot are two different concepts. That is why Theano also has a batched_tensordot() operator. While a single matrix-matrix product can be formulated either as gemm or as a tensordot, a batch of gemms can not be formulated as a single tensordot. Suppose we have a batch-gemm on shapes A= (I, M, K) and B = (I, K, N) where the first coordinate is the batch coordinate, then batch_gemm will return a shape (I, M, N) while tensordot(A, B, [[2], [1]]) will return a shape (I, I, M, N), i.e. it is a completely different computation.
So this PR allows more flexibility about which axis are batch coordinates, but it is not anything that brings in any tensordot-functionalities.

@piiswrong
Copy link
Contributor

Ok. Thanks for the explanation.

@piiswrong piiswrong merged commit 4ac76c8 into apache:master May 29, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants