Commit

standard adam update for sparse tensor (apache#9468)
ZiyueHuang authored and zheng-da committed Jun 28, 2018
1 parent ffa5934 commit 6d4bd4c
Showing 5 changed files with 222 additions and 22 deletions.
31 changes: 18 additions & 13 deletions python/mxnet/optimizer.py
@@ -780,28 +780,28 @@ class Adam(Optimizer):
This class implements the optimizer described in *Adam: A Method for
Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
The optimizer updates the weight by::
rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * m / (sqrt(v) + epsilon)
If the storage types of weight, state and grad are all ``row_sparse``, \
**sparse updates** are applied by::
If the storage types of weight and grad are both ``row_sparse``, and ``lazy_update`` is True, \
**lazy updates** are applied by::
for row in grad.indices:
rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
The sparse update only updates the mean and var for the weights whose row_sparse
The lazy update only updates the mean and var for the weights whose row_sparse
gradient indices appear in the current batch, rather than updating it for all indices.
Compared with the original update, it can provide large improvements in model training
throughput for some applications. However, it provides slightly different semantics than
the original update, and may lead to different empirical results.
Otherwise, **standard updates** are applied by::
rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * m / (sqrt(v) + epsilon)
This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.
@@ -815,19 +815,24 @@ class Adam(Optimizer):
Exponential decay rate for the second moment estimates.
epsilon : float, optional
Small value to avoid division by 0.
lazy_update : bool, optional
Default is True. If True, lazy updates are applied \
if the storage types of weight and grad are both ``row_sparse``.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
**kwargs):
lazy_update=True, **kwargs):
super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lazy_update = lazy_update

def create_state(self, index, weight):
stype = weight.stype if self.lazy_update else 'default'
return (zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=weight.stype), # mean
stype=stype), # mean
zeros(weight.shape, weight.context, dtype=weight.dtype,
stype=weight.stype)) # variance
stype=stype)) # variance

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
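The docstring above defines both update rules. As a rough NumPy rendering of that pseudocode (illustration only, not MXNet code: gradient clipping and MXNet's bias correction of the learning rate are omitted, and indices stands in for grad.indices of a row_sparse gradient):

import numpy as np

def adam_step(weight, m, v, grad_rows, indices, lr=0.001, beta1=0.9,
              beta2=0.999, eps=1e-8, wd=0.0, rescale=1.0, lazy=True):
    # lazy=True touches only the rows listed in indices; lazy=False (the
    # standard update added by this commit) also decays m/v and applies
    # weight decay on every other row, whose gradient is implicitly zero.
    dense_grad = np.zeros_like(weight)
    dense_grad[indices] = grad_rows
    rows = indices if lazy else np.arange(weight.shape[0])
    for r in rows:
        g = dense_grad[r] * rescale + wd * weight[r]
        m[r] = beta1 * m[r] + (1 - beta1) * g
        v[r] = beta2 * v[r] + (1 - beta2) * g * g
        weight[r] -= lr * m[r] / (np.sqrt(v[r]) + eps)

w = np.ones((6, 3))
m = np.zeros_like(w)
v = np.zeros_like(w)
adam_step(w, m, v, grad_rows=np.ones((2, 3)), indices=np.array([1, 4]), lazy=False)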
77 changes: 72 additions & 5 deletions src/operator/optimizer_op-inl.h
@@ -859,6 +859,71 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
var.data(), req, &out_blob);
}

template<int req>
struct AdamStdDnsRspDnsKernel {
template<typename DType, typename IType, typename RType>
MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
const DType beta1, const DType beta2, const DType lr, const DType wd,
const DType epsilon, const DType rescale_grad) {
using namespace mshadow_op;
const bool non_zero = (i == 0) ? prefix_sum[0] > 0
: prefix_sum[i] > prefix_sum[i-1];

const index_t row_i = i * row_length;
const RType grad_i = (prefix_sum[i]-1) * row_length;
for (index_t j = 0; j < row_length; j++) {
const index_t data_i = row_i + j;
const DType grad_rescaled = non_zero ? static_cast<DType>(
grad_data[grad_i + j] * rescale_grad +
weight_data[data_i] * wd)
: static_cast<DType>(weight_data[data_i] * wd);
if (clip_gradient >= 0.0f) {
mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
clip::Map(grad_rescaled, clip_gradient);
var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
clip::Map(grad_rescaled, clip_gradient));
} else {
mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
var_data[data_i] = beta2 * var_data[data_i] +
(1.f - beta2) * square::Map(grad_rescaled);
}
KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
(square_root::Map(var_data[data_i]) + epsilon));
}
}
};
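
The kernel above is launched over all weight rows and assumes prefix_sum is an inclusive prefix sum of per-row "row has a gradient" flags: the sum increases at row i exactly when that row appears in grad.indices, and prefix_sum[i] - 1 is then the row's position in the compacted gradient data. A small NumPy illustration of this lookup (not MXNet code):

import numpy as np

grad_indices = np.array([1, 4])        # rows present in the row_sparse gradient
num_rows = 6
flags = np.zeros(num_rows, dtype=np.int64)
flags[grad_indices] = 1
prefix_sum = np.cumsum(flags)          # -> [0, 1, 1, 1, 2, 2]

for i in range(num_rows):
    non_zero = prefix_sum[i] > 0 if i == 0 else prefix_sum[i] > prefix_sum[i - 1]
    if non_zero:
        # weight row i reads row (prefix_sum[i] - 1) of grad.data()
        print(i, '->', prefix_sum[i] - 1)   # prints 1 -> 0 and 4 -> 1
    else:
        # gradient is implicitly zero: only weight decay and m/v decay apply
        pass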


template<typename xpu>
void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mean,
const TBlob& var,
const OpReqType& req,
TBlob *out);

template<typename xpu>
inline void AdamStdUpdateRspRspRspImpl(const AdamParam& param,
const OpContext& ctx,
const NDArray& weight,
const NDArray& grad,
const NDArray& mean,
const NDArray& var,
const OpReqType& req,
NDArray *out) {
using namespace mxnet_op;
using namespace rowsparse;
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "AdamStdUpdate", "weights");
mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
var.data(), req, &out_blob);
}

template<typename xpu>
inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
@@ -868,18 +868,20 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &outputs) {
const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
const auto weight_stype = inputs[0].storage_type();
const auto grad_stype = inputs[1].storage_type();
const auto mean_stype = inputs[2].storage_type();
const auto var_stype = inputs[3].storage_type();
const auto out_stype = outputs[0].storage_type();
CHECK_EQ(mean_stype, weight_stype) << "Inconsistent storage type detected between "
<< " mean.stype = " << mean_stype << " and weight.stype = " << weight_stype;
CHECK_EQ(var_stype, weight_stype) << "Inconsistent storage type detected between "
<< " var.stype = " << var_stype << " and weight.stype = " << weight_stype;
NDArray out = outputs[0];
if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
out_stype == kRowSparseStorage) {
NDArray out = outputs[0];
AdamUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
inputs[3], req[0], &out);
} else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage &&
mean_stype == kDefaultStorage && var_stype == kDefaultStorage &&
out_stype == kRowSparseStorage) {
AdamStdUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
inputs[3], req[0], &out);
} else {
LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
}
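AdamUpdateEx now chooses between the two row_sparse implementations based on the storage types of weight, grad, mean, var and the output. A rough Python sketch of that dispatch (illustration only; 'rsp' means row_sparse storage, 'dns' means default dense storage):

def adam_dispatch(weight, grad, mean, var, out):
    # all inputs and the output row_sparse -> lazy update with sparse mean/var
    if all(s == 'rsp' for s in (weight, grad, mean, var, out)):
        return 'AdamUpdateRspRspRspImpl'
    # row_sparse weight/grad/out with dense mean/var -> standard update
    if (weight, grad, out) == ('rsp', 'rsp', 'rsp') and (mean, var) == ('dns', 'dns'):
        return 'AdamStdUpdateRspRspRspImpl'
    raise NotImplementedError('LOG(FATAL): unsupported storage combination')

assert adam_dispatch('rsp', 'rsp', 'dns', 'dns', 'rsp') == 'AdamStdUpdateRspRspRspImpl'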
62 changes: 61 additions & 1 deletion src/operator/optimizer_op.cc
@@ -148,6 +148,62 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
});
}

template<>
void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mean,
const TBlob& var,
const OpReqType& req,
TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
using namespace mshadow;
Stream<cpu>* s = ctx.get_stream<cpu>();
if (req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mean.shape_.Size(), 0);
CHECK_GT(var.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const DType* weight_data = weight.dptr<DType>();
const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
const DType* grad_val = grad.data().dptr<DType>();
DType* mean_data = mean.dptr<DType>();
DType* var_data = var.dptr<DType>();
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = weight.shape_[0];
nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
Tensor<cpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);

nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
// mark row flags
Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
if (grad.storage_initialized()) {
Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
prefix_sum, grad_idx);
// calculate inclusive prefix sum
for (nnvm::dim_t i = 1; i < num_rows; i++) {
prefix_sum[i] += prefix_sum[i - 1];
}
}

Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.rescale_grad));
});
});
});
}
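
The serial loop above turns the per-row flags written by MarkRowFlgKernel into an inclusive prefix sum in place; on the host this is simply a cumulative sum. A NumPy check (illustration only):

import numpy as np

flags = np.array([0, 1, 0, 0, 1, 0])   # 1 where a row appears in grad.indices
serial = flags.copy()
for i in range(1, len(serial)):        # same in-place scan as the C++ loop
    serial[i] += serial[i - 1]
assert (serial == np.cumsum(flags)).all()

The GPU implementation below computes the same scan with cub::DeviceScan::InclusiveSum.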


NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
@@ -329,8 +385,12 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<4, 1, false, true, false>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, 2>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
68 changes: 68 additions & 0 deletions src/operator/optimizer_op.cu
@@ -94,6 +94,74 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
});
}

template<>
void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
const OpContext& ctx,
const TBlob& weight,
const NDArray& grad,
const TBlob& mean,
const TBlob& var,
const OpReqType& req,
TBlob *out) {
using namespace mxnet_op;
using namespace rowsparse;
using namespace mshadow;
Stream<gpu>* s = ctx.get_stream<gpu>();
if (req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mean.shape_.Size(), 0);
CHECK_GT(var.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
const DType* weight_data = weight.dptr<DType>();
const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
const DType* grad_val = grad.data().dptr<DType>();
DType* mean_data = mean.dptr<DType>();
DType* var_data = var.dptr<DType>();
DType* out_data = out->dptr<DType>();
nnvm::dim_t num_rows = weight.shape_[0];
nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
nnvm::dim_t* prefix_sum = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
num_rows,
Stream<gpu>::GetStream(s));
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t) +
temp_storage_bytes), s);
prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + num_rows*sizeof(nnvm::dim_t);
// mark row flags
Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
if (grad.storage_initialized()) {
Kernel<MarkRowFlgKernel, gpu>::Launch(s, grad.aux_shape(kIdx)[0],
prefix_sum, grad_idx);
// calculate inclusive prefix sum
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
num_rows,
Stream<gpu>::GetStream(s));
}

Kernel<AdamStdDnsRspDnsKernel<req_type>, gpu>::Launch(s, num_rows, row_length,
out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.rescale_grad));
});
});
});
}
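
In the GPU path, the first cub::DeviceScan::InclusiveSum call is made with a null buffer and only reports temp_storage_bytes; the single requested workspace then holds both the prefix-sum array and cub's scratch space, split by byte offset. A minimal sketch of that layout (illustration only; the 256-byte scratch size is a made-up stand-in for whatever cub reports):

import numpy as np

num_rows = 6
itemsize = np.dtype(np.int64).itemsize      # stands in for sizeof(nnvm::dim_t)
temp_storage_bytes = 256                    # hypothetical value reported by cub

workspace = np.zeros(num_rows * itemsize + temp_storage_bytes, dtype=np.uint8)
prefix_sum = workspace[:num_rows * itemsize].view(np.int64)   # first region
d_temp_storage = workspace[num_rows * itemsize:]              # cub scratch space
assert prefix_sum.size == num_rows and d_temp_storage.size == temp_storage_bytes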

NNVM_REGISTER_OP(signsgd_update)
.set_attr<FCompute>("FCompute<gpu>", SignSGDUpdate<gpu>);
6 changes: 3 additions & 3 deletions tests/python/unittest/test_optimizer.py
@@ -520,10 +520,10 @@ def test_adam():
not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
if (default_context() == mx.cpu()):
compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
compare_optimizer(opt1(sparse_update=True, **kwarg), opt2(**kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse')
compare_optimizer(opt1(**kwarg), opt2(lazy_update=False, **kwarg), shape,
dtype, w_stype='row_sparse', g_stype='row_sparse')


# Signum
class PySignum(mx.optimizer.Optimizer):
