|
| 1 | +/*! |
| 2 | + * Copyright (c) 2015 by Contributors |
| 3 | + * \file iter_sframe_image.cc |
| 4 | + * \brief |
| 5 | + * \author Bing Xu |
| 6 | +*/ |
| 7 | + |
| 8 | +#include <unity/lib/image_util.hpp> |
| 9 | +#include <unity/lib/gl_sframe.hpp> |
| 10 | +#include <unity/lib/gl_sarray.hpp> |
| 11 | +#include <mxnet/io.h> |
| 12 | +#include <dmlc/base.h> |
| 13 | +#include <dmlc/io.h> |
| 14 | +#include <dmlc/omp.h> |
| 15 | +#include <dmlc/logging.h> |
| 16 | +#include <dmlc/parameter.h> |
| 17 | +#include <string> |
| 18 | +#include <memory> |
| 19 | +#include "../../src/io/inst_vector.h" |
| 20 | +#include "../../src/io/image_recordio.h" |
| 21 | +#include "../../src/io/image_augmenter.h" |
| 22 | +#include "../../src/io/iter_prefetcher.h" |
| 23 | +#include "../../src/io/iter_normalize.h" |
| 24 | +#include "../../src/io/iter_batchloader.h" |
| 25 | + |
| 26 | +namespace mxnet { |
| 27 | +namespace io { |
| 28 | + |
| 29 | +struct SFrameParam : public dmlc::Parameter<SFrameParam> { |
| 30 | + /*! \brief sframe path */ |
| 31 | + std::string path_sframe; |
| 32 | + std::string data_field; |
| 33 | + std::string label_field; |
| 34 | + TShape data_shape; |
| 35 | + TShape label_shape; |
| 36 | + DMLC_DECLARE_PARAMETER(SFrameParam) { |
| 37 | + DMLC_DECLARE_FIELD(path_sframe).set_default("") |
| 38 | + .describe("Dataset Param: path to image dataset sframe"); |
| 39 | + DMLC_DECLARE_FIELD(data_field).set_default("data") |
| 40 | + .describe("Dataset Param: data column in sframe"); |
| 41 | + DMLC_DECLARE_FIELD(label_field).set_default("label") |
| 42 | + .describe("Dataset Param: label column in sframe"); |
| 43 | + DMLC_DECLARE_FIELD(data_shape) |
| 44 | + .describe("Dataset Param: input data instance shape"); |
| 45 | + DMLC_DECLARE_FIELD(label_shape) |
| 46 | + .describe("Dataset Param: input label instance shape"); |
| 47 | + } |
| 48 | +}; // struct SFrameImageParam |
| 49 | + |
| 50 | +class SFrameIterBase : public IIterator<DataInst> { |
| 51 | + public: |
| 52 | + SFrameIterBase() {} |
| 53 | + |
| 54 | + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { |
| 55 | + param_.InitAllowUnknown(kwargs); |
| 56 | + sframe_ = graphlab::gl_sframe(param_.path_sframe)[{param_.data_field, param_.label_field}]; |
| 57 | + range_it_.reset(new graphlab::gl_sframe_range(sframe_.range_iterator())); |
| 58 | + this->BeforeFirst(); |
| 59 | + } |
| 60 | + |
| 61 | + virtual ~SFrameIterBase() {} |
| 62 | + |
| 63 | + virtual void BeforeFirst() { |
| 64 | + idx_ = 0; |
| 65 | + *range_it_ = sframe_.range_iterator(); |
| 66 | + current_it_ = range_it_->begin(); |
| 67 | + } |
| 68 | + |
| 69 | + virtual const DataInst &Value(void) const { |
| 70 | + return out_; |
| 71 | + } |
| 72 | + |
| 73 | + virtual bool Next() = 0; |
| 74 | + |
| 75 | + protected: |
| 76 | + /*! \brief index of instance */ |
| 77 | + index_t idx_; |
| 78 | + /*! \brief output of sframe iterator */ |
| 79 | + DataInst out_; |
| 80 | + /*! \brief temp space */ |
| 81 | + InstVector tmp_; |
| 82 | + /*! \brief sframe iter parameter */ |
| 83 | + SFrameParam param_; |
| 84 | + /*! \brief sframe object*/ |
| 85 | + graphlab::gl_sframe sframe_; |
| 86 | + /*! \brief sframe range iterator */ |
| 87 | + std::unique_ptr<graphlab::gl_sframe_range> range_it_; |
| 88 | + /*! \brief current iterator in range iterator */ |
| 89 | + graphlab::gl_sframe_range::iterator current_it_; |
| 90 | + |
| 91 | + protected: |
| 92 | + /*! \brief copy data */ |
| 93 | + template<int dim> |
| 94 | + void Copy_(mshadow::Tensor<cpu, dim> tensor, const graphlab::flex_vec &vec) { |
| 95 | + CHECK_EQ(tensor.shape_.Size(), vec.size()); |
| 96 | + CHECK_EQ(tensor.CheckContiguous(), true); |
| 97 | + mshadow::Tensor<cpu, 1> flatten(tensor.dptr_, mshadow::Shape1(tensor.shape_.Size())); |
| 98 | + for (index_t i = 0; i < vec.size(); ++i) { |
| 99 | + flatten[i] = static_cast<float>(vec[i]); |
| 100 | + } |
| 101 | + } |
| 102 | +}; // class SFrameIterBase |
| 103 | + |
| 104 | +class SFrameImageIter : public SFrameIterBase { |
| 105 | + public: |
| 106 | + SFrameImageIter() : |
| 107 | + augmenter_(new ImageAugmenter()), prnd_(new common::RANDOM_ENGINE(8964)) {} |
| 108 | + |
| 109 | + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { |
| 110 | + Parent::Init(kwargs); |
| 111 | + augmenter_->Init(kwargs); |
| 112 | + CHECK_EQ(Parent::param_.data_shape.ndim(), 3) |
| 113 | + << "Image shpae must be (channel, height, width)"; |
| 114 | + } |
| 115 | + |
| 116 | + bool Next(void) override { |
| 117 | + if (Parent::current_it_ == Parent::range_it_->end()) { |
| 118 | + return false; |
| 119 | + } |
| 120 | + graphlab::image_type gl_img = (*Parent::current_it_)[0]; |
| 121 | + graphlab::flex_vec gl_label = (*Parent::current_it_)[1]; |
| 122 | + // TODO(bing): check not decoded |
| 123 | + // TODO(bing): check img shape |
| 124 | + CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match"; |
| 125 | + const unsigned char *raw_data = gl_img.get_image_data(); |
| 126 | + cv::Mat res; |
| 127 | + cv::Mat buf(1, gl_img.m_image_data_size, CV_8U, const_cast<unsigned char*>(raw_data)); |
| 128 | + res = cv::imdecode(buf, -1); |
| 129 | + res = augmenter_->Process(res, prnd_.get()); |
| 130 | + const int n_channels = res.channels(); |
| 131 | + if (!tmp_.Size()) { |
| 132 | + tmp_.Push(Parent::idx_++, |
| 133 | + Parent::param_.data_shape.get<3>(), |
| 134 | + Parent::param_.label_shape.get<1>()); |
| 135 | + } |
| 136 | + mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back(); |
| 137 | + std::vector<int> swap_indices; |
| 138 | + if (n_channels == 1) swap_indices = {0}; |
| 139 | + if (n_channels == 3) swap_indices = {2, 1, 0}; |
| 140 | + for (int i = 0; i < res.rows; ++i) { |
| 141 | + uchar* im_data = res.ptr<uchar>(i); |
| 142 | + for (int j = 0; j < res.cols; ++j) { |
| 143 | + for (int k = 0; k < n_channels; ++k) { |
| 144 | + data[k][i][j] = im_data[swap_indices[k]]; |
| 145 | + } |
| 146 | + im_data += n_channels; |
| 147 | + } |
| 148 | + } |
| 149 | + mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back(); |
| 150 | + Parent::Copy_<1>(label, gl_label); |
| 151 | + res.release(); |
| 152 | + out_ = Parent::tmp_[0]; |
| 153 | + ++current_it_; |
| 154 | + return true; |
| 155 | + } |
| 156 | + |
| 157 | + private: |
| 158 | + /*! \brief parent type */ |
| 159 | + typedef SFrameIterBase Parent; |
| 160 | + /*! \brief image augmenter */ |
| 161 | + std::unique_ptr<ImageAugmenter> augmenter_; |
| 162 | + /*! \brief randim generator*/ |
| 163 | + std::unique_ptr<common::RANDOM_ENGINE> prnd_; |
| 164 | +}; // class SFrameImageIter |
| 165 | + |
| 166 | +class SFrameDataIter : public SFrameIterBase { |
| 167 | + public: |
| 168 | + bool Next() override { |
| 169 | + if (Parent::current_it_ == Parent::range_it_->end()) { |
| 170 | + return false; |
| 171 | + } |
| 172 | + graphlab::flex_vec gl_data = (*Parent::current_it_)[0]; |
| 173 | + graphlab::flex_vec gl_label = (*Parent::current_it_)[1]; |
| 174 | + CHECK_EQ(gl_data.size(), Parent::param_.data_shape.Size()) << "Data shape does not match"; |
| 175 | + CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match"; |
| 176 | + if (!Parent::tmp_.Size()) { |
| 177 | + Parent::tmp_.Push(Parent::idx_++, |
| 178 | + Parent::param_.data_shape.get<3>(), |
| 179 | + Parent::param_.label_shape.get<1>()); |
| 180 | + } |
| 181 | + mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back(); |
| 182 | + Parent::Copy_<3>(data, gl_data); |
| 183 | + mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back(); |
| 184 | + Parent::Copy_<1>(label, gl_label); |
| 185 | + out_ = Parent::tmp_[0]; |
| 186 | + ++current_it_; |
| 187 | + return true; |
| 188 | + } |
| 189 | + |
| 190 | + private: |
| 191 | + /*! \brief parent type */ |
| 192 | + typedef SFrameIterBase Parent; |
| 193 | +}; // class SFrameDataIter |
| 194 | + |
| 195 | +DMLC_REGISTER_PARAMETER(SFrameParam); |
| 196 | + |
| 197 | +MXNET_REGISTER_IO_ITER(SFrameImageIter) |
| 198 | +.describe("Naive SFrame image iterator prototype") |
| 199 | +.add_arguments(SFrameParam::__FIELDS__()) |
| 200 | +.add_arguments(BatchParam::__FIELDS__()) |
| 201 | +.add_arguments(PrefetcherParam::__FIELDS__()) |
| 202 | +.add_arguments(ImageAugmentParam::__FIELDS__()) |
| 203 | +.add_arguments(ImageNormalizeParam::__FIELDS__()) |
| 204 | +.set_body([]() { |
| 205 | + return new PrefetcherIter( |
| 206 | + new BatchLoader( |
| 207 | + new ImageNormalizeIter( |
| 208 | + new SFrameImageIter()))); |
| 209 | + }); |
| 210 | + |
| 211 | +MXNET_REGISTER_IO_ITER(SFrameDataIter) |
| 212 | +.describe("Naive SFrame data iterator prototype") |
| 213 | +.add_arguments(SFrameParam::__FIELDS__()) |
| 214 | +.add_arguments(BatchParam::__FIELDS__()) |
| 215 | +.add_arguments(PrefetcherParam::__FIELDS__()) |
| 216 | +.set_body([]() { |
| 217 | + return new PrefetcherIter( |
| 218 | + new BatchLoader( |
| 219 | + new SFrameDataIter())); |
| 220 | + }); |
| 221 | + |
| 222 | + |
| 223 | +} // namespace io |
| 224 | +} // namespace mxnet |
| 225 | + |
0 commit comments