Skip to content

Commit 25f5a26

Browse files
committed
[io] sframe iter
1 parent 08466f8 commit 25f5a26

File tree

6 files changed

+243
-4
lines changed

6 files changed

+243
-4
lines changed

Makefile

+4-2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ ifeq ($(USE_DIST_KVSTORE), 1)
8989
LDFLAGS += $(PS_LDFLAGS_A)
9090
endif
9191

92+
include $(MXNET_PLUGINS)
93+
9294
.PHONY: clean all test lint doc clean_all rcpplint rcppexport roxygen
9395

9496
all: lib/libmxnet.a lib/libmxnet.so $(BIN)
@@ -117,7 +119,7 @@ ifeq ($(USE_TORCH), 1)
117119
ifeq ($(USE_CUDA), 1)
118120
LDFLAGS += -lcutorch -lcunn
119121
endif
120-
122+
121123
TORCH_SRC = $(wildcard plugin/torch/*.cc)
122124
PLUGIN_OBJ += $(patsubst %.cc, build/%.o, $(TORCH_SRC))
123125
TORCH_CUSRC = $(wildcard plugin/torch/*.cu)
@@ -200,7 +202,7 @@ include tests/cpp/unittest.mk
200202
test: $(TEST)
201203

202204
lint: rcpplint
203-
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src plugin scripts python predict/python
205+
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src plugin scripts python predict/python
204206

205207
doc: doxygen
206208

make/config.mk

+5
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,8 @@ EXTRA_OPERATORS =
114114
# whether to use torch integration. This requires installing torch.
115115
USE_TORCH = 0
116116
TORCH_PATH = $(HOME)/torch
117+
118+
# whether to use sframe integration. This requires build sframe
119+
# git@github.com:dato-code/SFrame.git
120+
# SFRAME_PATH = $(HOME)/SFrame
121+
# MXNET_PLUGINS += plugin/sframe/SFrame.mk

plugin/sframe/SFrame.mk

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SFRMAE_SRC = plugin/sframe/iter_sframe.cc
2+
PLUGIN_OBJ += build/plugin/sframe/iter_sframe.o
3+
CFLAGS += -I$(SFRAME_PATH)/oss_src/unity/lib/
4+
CFLAGS += -I$(SFRAME_PATH)/oss_src/
5+
LDFLAGS += -L$(SFRAME_PATH)/release/oss_src/unity/python/sframe/
6+
LDFLAGS += -lunity_shared
7+
LDFLAGS += -lboost_system

plugin/sframe/iter_sframe.cc

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/*!
2+
* Copyright (c) 2015 by Contributors
3+
* \file iter_sframe_image.cc
4+
* \brief
5+
* \author Bing Xu
6+
*/
7+
8+
#include <unity/lib/image_util.hpp>
9+
#include <unity/lib/gl_sframe.hpp>
10+
#include <unity/lib/gl_sarray.hpp>
11+
#include <mxnet/io.h>
12+
#include <dmlc/base.h>
13+
#include <dmlc/io.h>
14+
#include <dmlc/omp.h>
15+
#include <dmlc/logging.h>
16+
#include <dmlc/parameter.h>
17+
#include <string>
18+
#include <memory>
19+
#include "../../src/io/inst_vector.h"
20+
#include "../../src/io/image_recordio.h"
21+
#include "../../src/io/image_augmenter.h"
22+
#include "../../src/io/iter_prefetcher.h"
23+
#include "../../src/io/iter_normalize.h"
24+
#include "../../src/io/iter_batchloader.h"
25+
26+
namespace mxnet {
27+
namespace io {
28+
29+
struct SFrameParam : public dmlc::Parameter<SFrameParam> {
30+
/*! \brief sframe path */
31+
std::string path_sframe;
32+
std::string data_field;
33+
std::string label_field;
34+
TShape data_shape;
35+
TShape label_shape;
36+
DMLC_DECLARE_PARAMETER(SFrameParam) {
37+
DMLC_DECLARE_FIELD(path_sframe).set_default("")
38+
.describe("Dataset Param: path to image dataset sframe");
39+
DMLC_DECLARE_FIELD(data_field).set_default("data")
40+
.describe("Dataset Param: data column in sframe");
41+
DMLC_DECLARE_FIELD(label_field).set_default("label")
42+
.describe("Dataset Param: label column in sframe");
43+
DMLC_DECLARE_FIELD(data_shape)
44+
.describe("Dataset Param: input data instance shape");
45+
DMLC_DECLARE_FIELD(label_shape)
46+
.describe("Dataset Param: input label instance shape");
47+
}
48+
}; // struct SFrameImageParam
49+
50+
class SFrameIterBase : public IIterator<DataInst> {
51+
public:
52+
SFrameIterBase() {}
53+
54+
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
55+
param_.InitAllowUnknown(kwargs);
56+
sframe_ = graphlab::gl_sframe(param_.path_sframe)[{param_.data_field, param_.label_field}];
57+
range_it_.reset(new graphlab::gl_sframe_range(sframe_.range_iterator()));
58+
this->BeforeFirst();
59+
}
60+
61+
virtual ~SFrameIterBase() {}
62+
63+
virtual void BeforeFirst() {
64+
idx_ = 0;
65+
*range_it_ = sframe_.range_iterator();
66+
current_it_ = range_it_->begin();
67+
}
68+
69+
virtual const DataInst &Value(void) const {
70+
return out_;
71+
}
72+
73+
virtual bool Next() = 0;
74+
75+
protected:
76+
/*! \brief index of instance */
77+
index_t idx_;
78+
/*! \brief output of sframe iterator */
79+
DataInst out_;
80+
/*! \brief temp space */
81+
InstVector tmp_;
82+
/*! \brief sframe iter parameter */
83+
SFrameParam param_;
84+
/*! \brief sframe object*/
85+
graphlab::gl_sframe sframe_;
86+
/*! \brief sframe range iterator */
87+
std::unique_ptr<graphlab::gl_sframe_range> range_it_;
88+
/*! \brief current iterator in range iterator */
89+
graphlab::gl_sframe_range::iterator current_it_;
90+
91+
protected:
92+
/*! \brief copy data */
93+
template<int dim>
94+
void Copy_(mshadow::Tensor<cpu, dim> tensor, const graphlab::flex_vec &vec) {
95+
CHECK_EQ(tensor.shape_.Size(), vec.size());
96+
CHECK_EQ(tensor.CheckContiguous(), true);
97+
mshadow::Tensor<cpu, 1> flatten(tensor.dptr_, mshadow::Shape1(tensor.shape_.Size()));
98+
for (index_t i = 0; i < vec.size(); ++i) {
99+
flatten[i] = static_cast<float>(vec[i]);
100+
}
101+
}
102+
}; // class SFrameIterBase
103+
104+
class SFrameImageIter : public SFrameIterBase {
105+
public:
106+
SFrameImageIter() :
107+
augmenter_(new ImageAugmenter()), prnd_(new common::RANDOM_ENGINE(8964)) {}
108+
109+
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
110+
Parent::Init(kwargs);
111+
augmenter_->Init(kwargs);
112+
CHECK_EQ(Parent::param_.data_shape.ndim(), 3)
113+
<< "Image shpae must be (channel, height, width)";
114+
}
115+
116+
bool Next(void) override {
117+
if (Parent::current_it_ == Parent::range_it_->end()) {
118+
return false;
119+
}
120+
graphlab::image_type gl_img = (*Parent::current_it_)[0];
121+
graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
122+
// TODO(bing): check not decoded
123+
// TODO(bing): check img shape
124+
CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";
125+
const unsigned char *raw_data = gl_img.get_image_data();
126+
cv::Mat res;
127+
cv::Mat buf(1, gl_img.m_image_data_size, CV_8U, const_cast<unsigned char*>(raw_data));
128+
res = cv::imdecode(buf, -1);
129+
res = augmenter_->Process(res, prnd_.get());
130+
const int n_channels = res.channels();
131+
if (!tmp_.Size()) {
132+
tmp_.Push(Parent::idx_++,
133+
Parent::param_.data_shape.get<3>(),
134+
Parent::param_.label_shape.get<1>());
135+
}
136+
mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
137+
std::vector<int> swap_indices;
138+
if (n_channels == 1) swap_indices = {0};
139+
if (n_channels == 3) swap_indices = {2, 1, 0};
140+
for (int i = 0; i < res.rows; ++i) {
141+
uchar* im_data = res.ptr<uchar>(i);
142+
for (int j = 0; j < res.cols; ++j) {
143+
for (int k = 0; k < n_channels; ++k) {
144+
data[k][i][j] = im_data[swap_indices[k]];
145+
}
146+
im_data += n_channels;
147+
}
148+
}
149+
mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
150+
Parent::Copy_<1>(label, gl_label);
151+
res.release();
152+
out_ = Parent::tmp_[0];
153+
++current_it_;
154+
return true;
155+
}
156+
157+
private:
158+
/*! \brief parent type */
159+
typedef SFrameIterBase Parent;
160+
/*! \brief image augmenter */
161+
std::unique_ptr<ImageAugmenter> augmenter_;
162+
/*! \brief randim generator*/
163+
std::unique_ptr<common::RANDOM_ENGINE> prnd_;
164+
}; // class SFrameImageIter
165+
166+
class SFrameDataIter : public SFrameIterBase {
167+
public:
168+
bool Next() override {
169+
if (Parent::current_it_ == Parent::range_it_->end()) {
170+
return false;
171+
}
172+
graphlab::flex_vec gl_data = (*Parent::current_it_)[0];
173+
graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
174+
CHECK_EQ(gl_data.size(), Parent::param_.data_shape.Size()) << "Data shape does not match";
175+
CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";
176+
if (!Parent::tmp_.Size()) {
177+
Parent::tmp_.Push(Parent::idx_++,
178+
Parent::param_.data_shape.get<3>(),
179+
Parent::param_.label_shape.get<1>());
180+
}
181+
mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
182+
Parent::Copy_<3>(data, gl_data);
183+
mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
184+
Parent::Copy_<1>(label, gl_label);
185+
out_ = Parent::tmp_[0];
186+
++current_it_;
187+
return true;
188+
}
189+
190+
private:
191+
/*! \brief parent type */
192+
typedef SFrameIterBase Parent;
193+
}; // class SFrameDataIter
194+
195+
DMLC_REGISTER_PARAMETER(SFrameParam);
196+
197+
MXNET_REGISTER_IO_ITER(SFrameImageIter)
198+
.describe("Naive SFrame image iterator prototype")
199+
.add_arguments(SFrameParam::__FIELDS__())
200+
.add_arguments(BatchParam::__FIELDS__())
201+
.add_arguments(PrefetcherParam::__FIELDS__())
202+
.add_arguments(ImageAugmentParam::__FIELDS__())
203+
.add_arguments(ImageNormalizeParam::__FIELDS__())
204+
.set_body([]() {
205+
return new PrefetcherIter(
206+
new BatchLoader(
207+
new ImageNormalizeIter(
208+
new SFrameImageIter())));
209+
});
210+
211+
MXNET_REGISTER_IO_ITER(SFrameDataIter)
212+
.describe("Naive SFrame data iterator prototype")
213+
.add_arguments(SFrameParam::__FIELDS__())
214+
.add_arguments(BatchParam::__FIELDS__())
215+
.add_arguments(PrefetcherParam::__FIELDS__())
216+
.set_body([]() {
217+
return new PrefetcherIter(
218+
new BatchLoader(
219+
new SFrameDataIter()));
220+
});
221+
222+
223+
} // namespace io
224+
} // namespace mxnet
225+

ps-lite

0 commit comments

Comments
 (0)