Commit: refactor
ZiyueHuang committed Mar 21, 2018
1 parent ae77b61 commit 7755f60
Showing 4 changed files with 134 additions and 54 deletions.
78 changes: 78 additions & 0 deletions benchmark/python/sparse/updater.py
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
from mxnet.ndarray.sparse import adam_update
import numpy as np
import argparse

mx.random.seed(0)
np.random.seed(0)

parser = argparse.ArgumentParser(description='Benchmark adam updater')
parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]')
parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]')
parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]')
parser.add_argument('--repeat', type=int, default=1000, help='num repeat')
parser.add_argument('--dense-grad', action='store_true',
                    help='if set to true, both gradient and weight are dense.')
parser.add_argument('--dense-state', action='store_true',
                    help='if set to true, states are dense, indicating standard update')
parser.add_argument('--cpu', action='store_true')


args = parser.parse_args()
dim_in = args.dim_in
dim_out = args.dim_out
nnr = args.nnr
ctx = mx.cpu() if args.cpu else mx.gpu()

ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)

if not args.dense_grad:
    weight = ones.tostype('row_sparse')
    indices = np.arange(dim_in)
    np.random.shuffle(indices)
    indices = np.unique(indices[:nnr])
    indices = mx.nd.array(indices, ctx=ctx)
    grad = mx.nd.sparse.retain(weight, indices)
else:
    weight = ones.copy()
    grad = ones.copy()

if args.dense_state:
    mean = ones.copy()
else:
    mean = ones.tostype('row_sparse')

var = mean.copy()

# warmup
for i in range(10):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()

# measure speed
a = time.time()
for i in range(args.repeat):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
b = time.time()
print(b - a)
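
For reference, a typical invocation of this benchmark with the flags defined above (the values shown are the defaults) is:

    python benchmark/python/sparse/updater.py --dim-in 240000 --dim-out 512 --nnr 5000 --repeat 1000

Adding --cpu runs on CPU instead of GPU, while --dense-grad and --dense-state switch the gradient and the optimizer states to dense storage. The script prints the wall-clock time in seconds for the timed loop; the single wait_to_read() after each loop is enough because MXNet queues the asynchronous updates in order.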
67 changes: 16 additions & 51 deletions src/operator/optimizer_op-inl.h
@@ -749,14 +749,17 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
  });
}

template<int req, typename xpu>
struct AdamDnsRspDnsKernel;

/*!
 * Note: this kernel performs sparse adam update. For each row-slice in row_sparse
 * gradient, it finds the corresponding elements in weight, mean and var and performs
 * the update.
 * The kernel assumes dense weight/mean/var, and row_sparse gradient
 */
template<int req>
struct AdamDnsRspDnsKernelByRow {
struct AdamDnsRspDnsKernel<req, cpu> {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
@@ -787,8 +790,9 @@ struct AdamDnsRspDnsKernelByRow {
  }
};


template<int req>
struct AdamDnsRspDnsKernelByElem {
struct AdamDnsRspDnsKernel<req, gpu> {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
@@ -843,21 +847,16 @@ inline void AdamUpdateDnsRspDnsImpl(const AdamParam& param,
        DType* out_data = out->dptr<DType>();
        nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        size_t num_threads = num_rows;
        if (std::is_same<xpu, gpu>::value) {
          Kernel<AdamDnsRspDnsKernelByElem<req_type>, xpu>::Launch(s, num_rows * row_length,
            row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
            static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
            static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
            static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
            static_cast<DType>(param.rescale_grad));
        } else {
          Kernel<AdamDnsRspDnsKernelByRow<req_type>, xpu>::Launch(s, num_rows, row_length,
            out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
            static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
            static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
            static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
            static_cast<DType>(param.rescale_grad));
          num_threads = num_rows * row_length;
        }
        Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads,
          row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
@@ -893,42 +892,8 @@ inline void AdamUpdateRspRspRspImpl(const AdamParam& param,
                               var.data(), req, &out_blob);
}

template<int req>
struct AdamStdDnsRspDnsKernel {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
    const DType beta1, const DType beta2, const DType lr, const DType wd,
    const DType epsilon, const DType rescale_grad) {
    using namespace mshadow_op;
    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
                                   : prefix_sum[i] > prefix_sum[i-1];

    const index_t row_i = i * row_length;
    const RType grad_i = (prefix_sum[i]-1) * row_length;
    for (index_t j = 0; j < row_length; j++) {
      const index_t data_i = row_i + j;
      const DType grad_rescaled = non_zero ? static_cast<DType>(
                                               grad_data[grad_i + j] * rescale_grad +
                                               weight_data[data_i] * wd)
                                           : static_cast<DType>(weight_data[data_i] * wd);
      if (clip_gradient >= 0.0f) {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
                            clip::Map(grad_rescaled, clip_gradient);
        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
                           clip::Map(grad_rescaled, clip_gradient));
      } else {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
        var_data[data_i] = beta2 * var_data[data_i] +
                           (1.f - beta2) * square::Map(grad_rescaled);
      }
      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
                    (square_root::Map(var_data[data_i]) + epsilon));
    }
  }
};

template<int req, typename xpu>
struct AdamStdDnsRspDnsKernel;

template<typename xpu>
void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
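The shape of the refactor in this file: the two unrelated structs AdamDnsRspDnsKernelByRow and AdamDnsRspDnsKernelByElem become device specializations of a single declared-but-undefined primary template, so the launch site no longer branches on the kernel name, only on the thread count. Below is a minimal, self-contained sketch of that pattern (simplified names and bodies, not the actual MXNet code):

#include <cstdio>

struct cpu {};
struct gpu {};

// Primary template is declared but never defined: instantiating it
// for an unsupported device is a compile-time error.
template<int req, typename xpu>
struct AdamKernelSketch;

template<int req>
struct AdamKernelSketch<req, cpu> {
  // CPU specialization: logical thread i processes one whole row.
  static void Map(int i, int row_length) {
    std::printf("cpu thread %d: row of %d elements\n", i, row_length);
  }
};

template<int req>
struct AdamKernelSketch<req, gpu> {
  // GPU specialization: logical thread i processes one element.
  static void Map(int i, int row_length) {
    std::printf("gpu thread %d: one element\n", i);
  }
};

template<typename Kernel>
void Launch(int num_threads, int row_length) {
  for (int i = 0; i < num_threads; ++i) Kernel::Map(i, row_length);
}

int main() {
  const int num_rows = 2, row_length = 3;
  // Mirrors the refactored launch site: one call, and only the
  // thread count depends on the device.
  Launch<AdamKernelSketch<0, cpu> >(num_rows, row_length);
  Launch<AdamKernelSketch<0, gpu> >(num_rows * row_length, row_length);
  return 0;
}

The usual motivation, which this refactor preserves: a CPU has few threads, so giving each one a full row amortizes per-thread overhead, while a GPU wants many fine-grained threads, so one per element keeps occupancy high.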
39 changes: 38 additions & 1 deletion src/operator/optimizer_op.cc
@@ -148,6 +148,43 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
  });
}

template<int req>
struct AdamStdDnsRspDnsKernel<req, cpu> {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
    const DType beta1, const DType beta2, const DType lr, const DType wd,
    const DType epsilon, const DType rescale_grad) {
    using namespace mshadow_op;
    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
                                   : prefix_sum[i] > prefix_sum[i-1];

    const index_t row_i = i * row_length;
    const RType grad_i = (prefix_sum[i]-1) * row_length;
    for (index_t j = 0; j < row_length; j++) {
      const index_t data_i = row_i + j;
      const DType grad_rescaled = non_zero ? static_cast<DType>(
                                               grad_data[grad_i + j] * rescale_grad +
                                               weight_data[data_i] * wd)
                                           : static_cast<DType>(weight_data[data_i] * wd);
      if (clip_gradient >= 0.0f) {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
                            clip::Map(grad_rescaled, clip_gradient);
        var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
                           clip::Map(grad_rescaled, clip_gradient));
      } else {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
        var_data[data_i] = beta2 * var_data[data_i] +
                           (1.f - beta2) * square::Map(grad_rescaled);
      }
      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
                    (square_root::Map(var_data[data_i]) + epsilon));
    }
  }
};


template<>
void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
                                     const OpContext& ctx,
@@ -193,7 +230,7 @@ void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
          }
        }

        Kernel<AdamStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length,
        Kernel<AdamStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
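One step in this CPU path that rewards a second look is the prefix_sum argument: it is the inclusive prefix sum of row occupancy, so prefix_sum[i] counts the non-zero gradient rows with row id <= i. A row carries a gradient exactly when its prefix-sum entry increments over its predecessor, and (prefix_sum[i] - 1) is then that row's position in the compacted gradient storage; rows without a gradient still receive the weight-decay-only update, which is what makes this the "standard" variant. A small standalone illustration of the lookup (a sketch, not MXNet code):

#include <cstdio>

int main() {
  // 4 weight rows; the row_sparse gradient has rows {1, 3}, stored
  // compactly, so grad_data holds only two rows: grad row 0 and 1.
  const int num_rows = 4;
  const int prefix_sum[num_rows] = {0, 1, 1, 2};  // inclusive prefix sum of occupancy
  for (int i = 0; i < num_rows; ++i) {
    // Same test as AdamStdDnsRspDnsKernel: did the count increment here?
    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
                                   : prefix_sum[i] > prefix_sum[i - 1];
    if (non_zero)
      std::printf("weight row %d <- grad row %d\n", i, prefix_sum[i] - 1);
    else
      std::printf("weight row %d <- no gradient, weight-decay term only\n", i);
  }
  return 0;
}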
4 changes: 2 additions & 2 deletions src/operator/optimizer_op.cu
@@ -95,7 +95,7 @@ void SGDMomStdUpdateDnsRspDnsImpl<gpu>(const SGDMomParam& param,
}

template<int req>
struct AdamStdDnsRspDnsKernelByElem {
struct AdamStdDnsRspDnsKernel<req, gpu> {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
@@ -181,7 +181,7 @@ void AdamStdUpdateDnsRspDnsImpl<gpu>(const AdamParam& param,
                             Stream<gpu>::GetStream(s));
        }

        Kernel<AdamStdDnsRspDnsKernelByElem<req_type>, gpu>::Launch(s, weight.shape_.Size(),
        Kernel<AdamStdDnsRspDnsKernel<req_type, gpu>, gpu>::Launch(s, weight.shape_.Size(),
          row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
