This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-358] support dense weight & sparse grad for adam, sgd and sgd_momentum (#10664)

* + support for dense weight with sparse grad for adam & sgd mom

* fix test

* sgd passes

* fix typo

* support adam

* update doc
eric-haibin-lin authored May 1, 2018
1 parent 147c83a commit 9f8f042
Showing 5 changed files with 301 additions and 167 deletions.
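
The net effect of the commit: SGD (with and without momentum) and Adam now accept a dense weight together with a ``row_sparse`` gradient, and the ``lazy_update`` flag is forwarded to the underlying operators. A minimal usage sketch, assuming an MXNet build that includes this commit; shapes and values are illustrative, not taken from the change:

```python
import mxnet as mx

# Dense ('default' storage) weight, row_sparse gradient touching rows 0 and 2.
weight = mx.nd.ones((4, 2))
grad = mx.nd.sparse.row_sparse_array(
    (mx.nd.ones((2, 2)) * 0.1, mx.nd.array([0, 2])), shape=(4, 2))

opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lazy_update=True)
state = opt.create_state(0, weight)   # momentum buffer; dense because weight is dense
opt.update(0, weight, grad, state)    # with lazy_update, only rows 0 and 2 change
print(weight.asnumpy())
```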
10 changes: 5 additions & 5 deletions python/mxnet/optimizer.py
@@ -434,7 +434,7 @@ def _get_wd(self, index):
 class SGD(Optimizer):
     """The SGD optimizer with momentum and weight decay.
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
         for row in grad.indices:
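
For readers skimming the diff, here is a hedged sketch of what the documented lazy update means, written against plain NumPy arrays and ignoring momentum, gradient rescaling, and clipping for brevity (function and variable names are illustrative):

```python
import numpy as np

def lazy_sgd_update(weight, grad_data, grad_indices, lr, wd):
    """Update only the rows listed in grad_indices; all other rows are left
    untouched, so weight decay is likewise applied only to those rows."""
    for i, row in enumerate(grad_indices):
        weight[row] -= lr * (grad_data[i] + wd * weight[row])
    return weight

w = np.ones((4, 2))
lazy_sgd_update(w, grad_data=np.full((2, 2), 0.1), grad_indices=[0, 2], lr=0.1, wd=0.0)
```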
@@ -494,8 +494,8 @@ def create_state_multi_precision(self, index, weight):

     def create_state(self, index, weight):
         momentum = None
-        stype = weight.stype if self.lazy_update else 'default'
         if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
             momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
         return momentum

@@ -515,7 +515,7 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False):
         if not multi_precision:
             if state is not None:
                 sgd_mom_update(weight, grad, state, out=weight,
-                               lr=lr, wd=wd, **kwargs)
+                               lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
             else:
                 sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
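
The only functional change in this hunk is that ``lazy_update`` is now forwarded to ``sgd_mom_update``. A sketch of calling the fused operator directly with a dense weight and a ``row_sparse`` gradient, assuming a build that includes this commit (values are illustrative):

```python
import mxnet as mx

weight = mx.nd.ones((4, 2))
mom = mx.nd.zeros((4, 2))                       # dense momentum state
grad = mx.nd.sparse.row_sparse_array(
    (mx.nd.ones((1, 2)) * 0.5, mx.nd.array([1])), shape=(4, 2))

# lazy_update=True: rows absent from grad.indices keep their old values.
mx.nd.sgd_mom_update(weight, grad, mom, out=weight,
                     lr=0.1, momentum=0.9, wd=0.0, lazy_update=True)
print(weight.asnumpy())
```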
@@ -986,7 +986,7 @@ class Adam(Optimizer):
     This class implements the optimizer described in *Adam: A Method for
     Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
-    If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
+    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
     **lazy updates** are applied by::
         for row in grad.indices:
@@ -1059,7 +1059,7 @@ def update(self, index, weight, grad, state):

         mean, var = state
         adam_update(weight, grad, mean, var, out=weight,
-                    lr=lr, wd=wd, **kwargs)
+                    lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
 
 @register
 class AdaGrad(Optimizer):
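
As with SGD, ``adam_update`` now receives ``lazy_update``. A hedged NumPy sketch of the lazy Adam step the docstring describes, omitting rescaling, clipping, and the bias correction that the Python wrapper folds into the learning rate (names are illustrative):

```python
import numpy as np

def lazy_adam_update(weight, m, v, grad_data, grad_indices,
                     lr, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """Mean/variance state and weight are touched only for rows in grad_indices;
    every other row keeps stale statistics until it next receives a gradient."""
    for i, row in enumerate(grad_indices):
        g = grad_data[i] + wd * weight[row]
        m[row] = beta1 * m[row] + (1 - beta1) * g
        v[row] = beta2 * v[row] + (1 - beta2) * g * g
        weight[row] -= lr * m[row] / (np.sqrt(v[row]) + eps)
    return weight, m, v
```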
18 changes: 11 additions & 7 deletions src/operator/operator_common.h
@@ -471,14 +471,18 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = std::move(param);
 }
 
-#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \
-  { \
-    CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \
-      << " for RowSparse " << param << " is only implemented for " \
-      << "RowSparse " << param << " with all rows containing non-zeros. " \
-      << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \
-      << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \
-  }
+inline void CheckAllRowsPresent(const NDArray& arr, const std::string& func,
+                                const std::string& param) {
+  if (arr.storage_type() == kRowSparseStorage) {
+    CHECK(arr.storage_shape()[0] == arr.shape()[0]) << func
+      << " for RowSparse " << param << " is only implemented for "
+      << "RowSparse " << param << " with all rows containing non-zeros. "
+      << "Expects " << param << ".data.shape[0] (" << arr.storage_shape()[0]
+      << ") == " << param << ".shape[0] (" << arr.shape()[0] << ").";
+  } else {
+    CHECK(arr.storage_type() == kDefaultStorage);
+  }
+}
 
 inline void LogUnimplementedOp(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
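
The macro is replaced by ``CheckAllRowsPresent``, which additionally accepts dense storage, so a dense weight now passes the check that previously assumed a RowSparse argument. A Python analogue of the condition it enforces, for illustration only (the helper below is hypothetical, not MXNet API):

```python
import mxnet as mx

def all_rows_present(arr):
    """Mirror of the C++ check: dense arrays pass; a row_sparse array passes
    only if it stores a value for every row (data.shape[0] == shape[0])."""
    if arr.stype == 'row_sparse':
        return arr.data.shape[0] == arr.shape[0]
    return arr.stype == 'default'

dense = mx.nd.ones((3, 2))
partial = mx.nd.sparse.row_sparse_array(
    (mx.nd.ones((1, 2)), mx.nd.array([0])), shape=(3, 2))
print(all_rows_present(dense))    # True
print(all_rows_present(partial))  # False: rows 1 and 2 hold no stored values
```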
