From 36043b9530572e0377fa94cbe743b3fd19d578e9 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sat, 30 Sep 2017 23:25:06 -0700 Subject: [PATCH] Enhancement for distributed sparse linear regression example (#7864) * add log loss * update sparse LR example * add readme * fix typo * fix lint * change name from log loss to nll * lint * enhance test --- benchmark/python/sparse/sparse_end2end.py | 4 +- example/sparse/get_data.py | 11 +- example/sparse/linear_classification.py | 176 +++++++--------------- example/sparse/readme.md | 14 ++ python/mxnet/metric.py | 67 ++++++++ src/io/iter_libsvm.cc | 4 +- tests/python/unittest/test_metric.py | 11 ++ 7 files changed, 160 insertions(+), 127 deletions(-) create mode 100644 example/sparse/readme.md diff --git a/benchmark/python/sparse/sparse_end2end.py b/benchmark/python/sparse/sparse_end2end.py index 0c1699b1daa5..e9efc7577923 100644 --- a/benchmark/python/sparse/sparse_end2end.py +++ b/benchmark/python/sparse/sparse_end2end.py @@ -84,7 +84,7 @@ def next(self): 'data_name': 'avazu-app.t', 'data_origin_name': 'avazu-app.t.bz2', 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", - 'feature_dim': 1000000, + 'feature_dim': 1000001, 'lc': 1719304, } @@ -92,7 +92,7 @@ def next(self): 'data_name': 'kdda.t', 'data_origin_name': 'kdda.t.bz2', 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", - 'feature_dim': 20216830, + 'feature_dim': 20216831, 'lc': 510302, } diff --git a/example/sparse/get_data.py b/example/sparse/get_data.py index 578cf2ce5226..21db06d8e746 100644 --- a/example/sparse/get_data.py +++ b/example/sparse/get_data.py @@ -17,16 +17,17 @@ # pylint: skip-file import os, gzip -import pickle as pickle import sys -def get_libsvm_data(data_dir, data_name, url, data_origin_name): +def get_libsvm_data(data_dir, data_name, url): if not os.path.isdir(data_dir): os.mkdir(data_dir) os.chdir(data_dir) if (not os.path.exists(data_name)): + print("Dataset " + data_name + " not present. Downloading now ...") import urllib - zippath = os.path.join(data_dir, data_origin_name) - urllib.urlretrieve(url, zippath) - os.system("bzip2 -d %r" % data_origin_name) + zippath = os.path.join(data_dir, data_name + ".bz2") + urllib.urlretrieve(url + data_name + ".bz2", zippath) + os.system("bzip2 -d %r" % data_name + ".bz2") + print("Dataset " + data_name + " is now present.") os.chdir("..") diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification.py index 567568c6eb80..188d55f8ae86 100644 --- a/example/sparse/linear_classification.py +++ b/example/sparse/linear_classification.py @@ -18,168 +18,108 @@ import mxnet as mx from mxnet.test_utils import * from get_data import get_libsvm_data -import time import argparse import os parser = argparse.ArgumentParser(description="Run sparse linear classification " \ "with distributed kvstore", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument('--profiler', type=int, default=0, - help='whether to use profiler') -parser.add_argument('--num-epoch', type=int, default=1, +parser.add_argument('--num-epoch', type=int, default=5, help='number of epochs to train') parser.add_argument('--batch-size', type=int, default=8192, help='number of examples per batch') -parser.add_argument('--num-batch', type=int, default=99999999, - help='number of batches per epoch') -parser.add_argument('--dummy-iter', type=int, default=0, - help='whether to use dummy iterator to exclude io cost') -parser.add_argument('--kvstore', type=str, default='dist_sync', - help='what kvstore to use [local, dist_sync, etc]') -parser.add_argument('--log-level', type=str, default='DEBUG', - help='logging level [debug, info, error]') -parser.add_argument('--dataset', type=str, default='avazu', - help='what test dataset to use') - -class DummyIter(mx.io.DataIter): - "A dummy iterator that always return the same batch, used for speed testing" - def __init__(self, real_iter): - super(DummyIter, self).__init__() - self.real_iter = real_iter - self.provide_data = real_iter.provide_data - self.provide_label = real_iter.provide_label - self.batch_size = real_iter.batch_size - - for batch in real_iter: - self.the_batch = batch - break - - def __iter__(self): - return self - - def next(self): - return self.the_batch - -# testing dataset sources -avazu = { - 'data_name': 'avazu-app.t', - 'data_origin_name': 'avazu-app.t.bz2', - 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", - 'feature_dim': 1000000, +parser.add_argument('--kvstore', type=str, default=None, + help='what kvstore to use [local, dist_async, etc]') + +AVAZU = { + 'train': 'avazu-app', + 'test': 'avazu-app.t', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/", + # 1000000 + 1 since LibSVMIter uses zero-based indexing + 'num_features': 1000001, } -kdda = { - 'data_name': 'kdda.t', - 'data_origin_name': 'kdda.t.bz2', - 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", - 'feature_dim': 20216830, -} - -datasets = { 'kdda' : kdda, 'avazu' : avazu } - -def linear_model(feature_dim): - x = mx.symbol.Variable("data", stype='csr') - norm_init = mx.initializer.Normal(sigma=0.01) - weight = mx.symbol.Variable("weight", shape=(feature_dim, 1), init=norm_init, stype='row_sparse') - bias = mx.symbol.Variable("bias", shape=(1,), init=norm_init) - dot = mx.symbol.dot(x, weight) - pred = mx.symbol.broadcast_add(dot, bias) - y = mx.symbol.Variable("softmax_label") - model = mx.symbol.SoftmaxOutput(data=pred, label=y, name="out") - return model +def linear_model(num_features): + # data with csr storage type to enable feeding data with CSRNDArray + x = mx.symbol.Variable("data", stype='csr') + norm_init = mx.initializer.Normal(sigma=0.01) + # weight with row_sparse storage type to enable sparse gradient updates + weight = mx.symbol.Variable("weight", shape=(num_features, 2), init=norm_init, stype='row_sparse') + bias = mx.symbol.Variable("bias", shape=(2, )) + dot = mx.symbol.sparse.dot(x, weight) + pred = mx.symbol.broadcast_add(dot, bias) + y = mx.symbol.Variable("softmax_label") + model = mx.symbol.SoftmaxOutput(data=pred, label=y, multi_output=True, name="out") + return model if __name__ == '__main__': + import logging + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.INFO, format=head) + # arg parser args = parser.parse_args() num_epoch = args.num_epoch - num_batch = args.num_batch kvstore = args.kvstore - profiler = args.profiler > 0 batch_size = args.batch_size - dummy_iter = args.dummy_iter - dataset = args.dataset - log_level = args.log_level # create kvstore - kv = mx.kvstore.create(kvstore) - rank = kv.rank - num_worker = kv.num_workers - - # only print log for rank 0 worker - import logging - if rank != 0: - log_level = logging.ERROR - elif log_level == 'DEBUG': - log_level = logging.DEBUG - else: - log_level = logging.INFO - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=log_level, format=head) + kv = mx.kvstore.create(kvstore) if kvstore else None + rank = kv.rank if kv else 0 + num_worker = kv.num_workers if kv else 1 # dataset - assert(dataset in datasets), "unknown dataset " + dataset - metadata = datasets[dataset] - feature_dim = metadata['feature_dim'] - if logging: - logging.debug('preparing data ... ') + num_features = AVAZU['num_features'] data_dir = os.path.join(os.getcwd(), 'data') - path = os.path.join(data_dir, metadata['data_name']) - if not os.path.exists(path): - get_libsvm_data(data_dir, metadata['data_name'], metadata['url'], - metadata['data_origin_name']) - assert os.path.exists(path) + train_data = os.path.join(data_dir, AVAZU['train']) + val_data = os.path.join(data_dir, AVAZU['test']) + get_libsvm_data(data_dir, AVAZU['train'], AVAZU['url']) + get_libsvm_data(data_dir, AVAZU['test'], AVAZU['url']) # data iterator - train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,), + train_data = mx.io.LibSVMIter(data_libsvm=train_data, data_shape=(num_features,), batch_size=batch_size, num_parts=num_worker, part_index=rank) - if dummy_iter: - train_data = DummyIter(train_data) + eval_data = mx.io.LibSVMIter(data_libsvm=val_data, data_shape=(num_features,), + batch_size=batch_size) # model - model = linear_model(feature_dim) + model = linear_model(num_features) # module mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label']) mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) - mod.init_params(initializer=mx.init.Uniform(scale=.1)) + mod.init_params() sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0, - learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker) + learning_rate=0.001, rescale_grad=1.0/batch_size/num_worker) mod.init_optimizer(optimizer=sgd, kvstore=kv) # use accuracy as the metric - metric = mx.metric.create('Accuracy') - - # start profiler - if profiler: - name = 'profile_output_' + str(num_worker) + '.json' - mx.profiler.profiler_set_config(mode='all', filename=name) - mx.profiler.profiler_set_state('run') + metric = mx.metric.create('log_loss') - logging.debug('start training ...') - start = time.time() + logging.info('Training started ...') data_iter = iter(train_data) for epoch in range(num_epoch): nbatch = 0 - data_iter.reset() metric.reset() for batch in data_iter: nbatch += 1 - row_ids = batch.data[0].indices - # pull sparse weight - index = mod._exec_group.param_names.index('weight') - kv.row_sparse_pull('weight', mod._exec_group.param_arrays[index], - priority=-index, row_ids=[row_ids]) + # for distributed training, we need to explicitly pull sparse weights from kvstore + if kv: + row_ids = batch.data[0].indices + # pull sparse weight based on the indices + index = mod._exec_group.param_names.index('weight') + kv.row_sparse_pull('weight', mod._exec_group.param_arrays[index], + priority=-index, row_ids=[row_ids]) mod.forward_backward(batch) # update parameters mod.update() - # accumulate prediction accuracy + # update training metric mod.update_metric(metric, batch.label) - if nbatch == num_batch: - break - logging.info('epoch %d, %s' % (epoch, metric.get())) - if profiler: - mx.profiler.profiler_set_state('stop') - end = time.time() - time_cost = end - start - logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost)) + if nbatch % 100 == 0: + logging.info('epoch %d batch %d, train log loss = %s' % (epoch, nbatch, metric.get()[1])) + # evaluate metric on validation dataset + score = mod.score(eval_data, ['log_loss']) + logging.info('epoch %d, eval log loss = %s' % (epoch, score[0][1])) + # reset the iterator for next pass of data + data_iter.reset() + logging.info('Training completed.') diff --git a/example/sparse/readme.md b/example/sparse/readme.md new file mode 100644 index 000000000000..14dcbd820522 --- /dev/null +++ b/example/sparse/readme.md @@ -0,0 +1,14 @@ +Example +=========== +This folder contains examples using the sparse feature in MXNet. + +## Linear Classification + +The example utilizes the sparse data loader, sparse operators and a sparse gradient updater to train a linear model on the [Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset. + +- `python linear_classification.py` + +Notes on Distributed Training: + +- For distributed training, please use the `../../tools/launch.py` script to launch a cluster. +- For example, to run two workers and two servers with one machine, run `../../tools/launch.py -n 2 --cluster=local python linear_classification.py --kvstore=dist_async` diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 55d9859c6643..5b0780aeccee 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -917,6 +917,73 @@ def update(self, labels, preds): self.sum_metric += (-numpy.log(prob + self.eps)).sum() self.num_inst += label.shape[0] +@register +@alias('nll_loss') +class NegativeLogLikelihood(EvalMetric): + """Computes the negative log-likelihood loss. + + The negative log-likelihoodd loss over a batch of sample size :math:`N` is given by + + .. math:: + -\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}), + + where :math:`K` is the number of classes, :math:`y_{nk}` is the prediceted probability for + :math:`k`-th class for :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample + :math:`n` belongs to class :math:`k`. + + Parameters + ---------- + eps : float + Negative log-likelihood loss is undefined for predicted value is 0, + so predicted values are added with the small constant. + name : str + Name of this metric instance for display. + output_names : list of str, or None + Name of predictions that should be used when updating with update_dict. + By default include all predictions. + label_names : list of str, or None + Name of labels that should be used when updating with update_dict. + By default include all labels. + + Examples + -------- + >>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])] + >>> labels = [mx.nd.array([0, 1, 1])] + >>> nll_loss = mx.metric.NegativeLogLikelihood() + >>> nll_loss.update(labels, predicts) + >>> print nll_loss.get() + ('nll-loss', 0.57159948348999023) + """ + def __init__(self, eps=1e-12, name='nll-loss', + output_names=None, label_names=None): + super(NegativeLogLikelihood, self).__init__( + name, eps=eps, + output_names=output_names, label_names=label_names) + self.eps = eps + + def update(self, labels, preds): + """Updates the internal evaluation result. + + Parameters + ---------- + labels : list of `NDArray` + The labels of the data. + + preds : list of `NDArray` + Predicted values. + """ + check_label_shapes(labels, preds) + + for label, pred in zip(labels, preds): + label = label.asnumpy() + pred = pred.asnumpy() + + label = label.ravel() + num_examples = pred.shape[0] + assert label.shape[0] == num_examples, (label.shape[0], num_examples) + prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)] + self.sum_metric += (-numpy.log(prob + self.eps)).sum() + self.num_inst += num_examples @register @alias('pearsonr') diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc index ab6cacb11773..3ccbc9cccea7 100644 --- a/src/io/iter_libsvm.cc +++ b/src/io/iter_libsvm.cc @@ -48,11 +48,11 @@ struct LibSVMIterParam : public dmlc::Parameter { // declare parameters DMLC_DECLARE_PARAMETER(LibSVMIterParam) { DMLC_DECLARE_FIELD(data_libsvm) - .describe("The input LibSVM file or a directory path."); + .describe("The input zero-base indexed LibSVM data file or a directory path."); DMLC_DECLARE_FIELD(data_shape) .describe("The shape of one example."); DMLC_DECLARE_FIELD(label_libsvm).set_default("NULL") - .describe("The input LibSVM file or a directory path. " + .describe("The input LibSVM label file or a directory path. " "If NULL, all labels will be read from ``data_libsvm``."); index_t shape1[] = {1}; DMLC_DECLARE_FIELD(label_shape).set_default(TShape(shape1, shape1 + 1)) diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 7ae93bf36299..31f31e6e626d 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -16,6 +16,7 @@ # under the License. import mxnet as mx +import numpy as np import json def check_metric(metric, *args, **kwargs): @@ -31,9 +32,19 @@ def test_metrics(): check_metric('f1') check_metric('perplexity', -1) check_metric('pearsonr') + check_metric('nll_loss') composite = mx.metric.create(['acc', 'f1']) check_metric(composite) +def test_nll_loss(): + metric = mx.metric.create('nll_loss') + pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]) + label = mx.nd.array([2, 1]) + metric.update([label], [pred]) + _, loss = metric.get() + expected_loss = 0.0 + expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2 + assert loss == expected_loss if __name__ == '__main__': import nose