From 05e07287a7fde466e92f2874abcd5b684ba2a49e Mon Sep 17 00:00:00 2001 From: ziheng Date: Mon, 29 May 2017 23:37:08 -0700 Subject: [PATCH] Change Interface of NDArray & TBlob for DLPack Compatible (#6345) * Change Interface of NDArray & TBlob for DLPack Compatible Fix for cudnn operator Fix cpp tests * Update nnvm * Fix for MKL mem * Fix for windows macro * Bump up version number to 0.10.1 * Update NDArray Save&Load * trigger update * Add test for legacy data load * Use LegacyTShapeLoad * trigger update * Update tensor_blob.h --- .gitmodules | 3 + CMakeLists.txt | 1 + Makefile | 6 +- R-package/DESCRIPTION | 2 +- dlpack | 1 + include/mxnet/base.h | 2 +- include/mxnet/c_api.h | 2 +- include/mxnet/ndarray.h | 47 ++++--- include/mxnet/tensor_blob.h | 128 ++++++++++++------ nnvm | 2 +- python/mxnet/libinfo.py | 2 +- .../assembly/linux-x86_64-cpu/pom.xml | 4 +- .../assembly/linux-x86_64-gpu/pom.xml | 4 +- scala-package/assembly/osx-x86_64-cpu/pom.xml | 4 +- scala-package/assembly/pom.xml | 4 +- scala-package/core/pom.xml | 4 +- scala-package/examples/pom.xml | 6 +- .../init-native/linux-x86_64/pom.xml | 4 +- scala-package/init-native/osx-x86_64/pom.xml | 4 +- scala-package/init-native/pom.xml | 4 +- scala-package/init/pom.xml | 4 +- scala-package/macros/pom.xml | 4 +- scala-package/native/linux-x86_64-cpu/pom.xml | 4 +- scala-package/native/linux-x86_64-gpu/pom.xml | 4 +- scala-package/native/osx-x86_64-cpu/pom.xml | 4 +- scala-package/native/pom.xml | 4 +- scala-package/pom.xml | 2 +- scala-package/spark/pom.xml | 4 +- snapcraft.yaml | 2 +- src/c_api/c_api.cc | 6 +- src/c_api/c_api_common.h | 14 +- src/c_api/c_api_ndarray.cc | 1 - src/c_api/c_api_symbolic.cc | 20 +-- src/c_api/c_predict_api.cc | 14 +- src/executor/graph_executor.cc | 4 +- src/io/image_io.cc | 4 +- src/io/iter_batchloader.h | 2 +- src/io/iter_csv.cc | 2 +- src/io/iter_image_recordio_2.cc | 2 +- src/ndarray/ndarray.cc | 30 +++- src/operator/cudnn_convolution-inl.h | 59 +++++--- src/operator/cudnn_deconvolution-inl.h | 38 ++++-- src/operator/custom/custom-inl.h | 11 +- src/operator/custom/custom.cc | 9 +- src/operator/custom/native_op-inl.h | 19 ++- src/operator/custom/ndarray_op-inl.h | 9 +- src/operator/deconvolution-inl.h | 6 +- src/operator/tensor/control_flow_op.h | 2 +- src/operator/tensor/matrix_op-inl.h | 28 ++-- tests/cpp/include/test_util.h | 4 +- tests/python/unittest/legacy_ndarray.v0 | Bin 0 -> 3224 bytes tests/python/unittest/test_ndarray.py | 9 ++ 52 files changed, 384 insertions(+), 175 deletions(-) create mode 160000 dlpack create mode 100644 tests/python/unittest/legacy_ndarray.v0 diff --git a/.gitmodules b/.gitmodules index 08f2bc99f2aa..bfe84d7f0615 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,6 +10,9 @@ [submodule "nnvm"] path = nnvm url = https://github.com/dmlc/nnvm +[submodule "dlpack"] + path = dlpack + url = https://github.com/dmlc/dlpack [submodule "cub"] path = cub url = https://github.com/NVlabs/cub diff --git a/CMakeLists.txt b/CMakeLists.txt index c8260e94e9bc..d0835300edaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -142,6 +142,7 @@ include_directories("mshadow") include_directories("cub") include_directories("nnvm/include") include_directories("dmlc-core/include") +include_directories("dlpack/include") if(NOT MSVC) set(BEGIN_WHOLE_ARCHIVE -Wl,--whole-archive) diff --git a/Makefile b/Makefile index c71cb1398963..501a170abeda 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,10 @@ ifndef NNVM_PATH NNVM_PATH = $(ROOTDIR)/nnvm endif +ifndef DLPACK_PATH + DLPACK_PATH = $(ROOTDIR)/dlpack 
+endif + ifneq ($(USE_OPENMP), 1) export NO_OPENMP = 1 endif @@ -49,7 +53,7 @@ ifeq ($(DEBUG), 1) else CFLAGS += -O3 -DNDEBUG=1 endif -CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -Iinclude $(MSHADOW_CFLAGS) +CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -Iinclude $(MSHADOW_CFLAGS) LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) ifeq ($(DEBUG), 1) NVCCFLAGS = -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 6aed92808020..2c8c8aa04d8d 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -1,7 +1,7 @@ Package: mxnet Type: Package Title: MXNet -Version: 0.10.0 +Version: 0.10.1 Date: 2015-12-23 Author: Tianqi Chen, Qiang Kou, Tong He Maintainer: Qiang Kou diff --git a/dlpack b/dlpack new file mode 160000 index 000000000000..a6e09b58dc00 --- /dev/null +++ b/dlpack @@ -0,0 +1 @@ +Subproject commit a6e09b58dc00ee0065f5b7879800e646fbb01d1e diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 8747109ce564..0c4c9d3daa77 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -87,7 +87,7 @@ /*! \brief minor version */ #define MXNET_MINOR 10 /*! \brief patch version */ -#define MXNET_PATCH 0 +#define MXNET_PATCH 1 /*! \brief mxnet version */ #define MXNET_VERSION (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH) /*! \brief helper for making version number */ diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 1b112abe2ba9..4508a51e64d4 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -390,7 +390,7 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle, const mx_uint **out_pdata); /*! * \brief get the content of the data in NDArray - * \param handle the handle to the narray + * \param handle the handle to the ndarray * \param out_pdata pointer holder to get pointer of data * \return 0 when success, -1 when failure happens */ diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index ea38909d07f1..b8cd550118d3 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -57,10 +57,10 @@ class AutogradRuntime; */ class NDArray { public: - /*! \brief default cosntructor */ + /*! \brief default constructor */ NDArray() { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = MKLMemHolder::create(); + Mkl_mem_ = MKLMemHolder::create(); #endif } /*! @@ -75,7 +75,7 @@ class NDArray { : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc, dtype)), shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = std::make_shared(); + Mkl_mem_ = std::make_shared(); #endif } /*! @@ -89,29 +89,32 @@ class NDArray { : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), offset_(0), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = std::make_shared(); + Mkl_mem_ = std::make_shared(); #endif } /*! * \return the shape of current NDArray */ - inline const TShape &shape() const { + inline const TShape& shape() const { return shape_; } /*! 
* \return the data TBlob */ - inline TBlob data() const { + inline const TBlob& data() const { CheckAndAlloc(); - TBlob res; +#if MKL_EXPERIMENTAL == 1 MSHADOW_TYPE_SWITCH(dtype_, DType, { - res = TBlob(static_cast(ptr_->shandle.dptr) - + offset_, shape_, ptr_->shandle.ctx.dev_mask()); + tblob_ = TBlob(static_cast(ptr_->shandle.dptr) + offset_, + shape_, ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id, Mkl_mem_); + }); +#else + MSHADOW_TYPE_SWITCH(dtype_, DType, { + tblob_ = TBlob(static_cast(ptr_->shandle.dptr) + offset_, + shape_, ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); }); -#if MKL_EXPERIMENTAL == 1 - res.Mkl_mem_ = Mkl_mem_; #endif - return res; + return tblob_; } /*! * \return a chunk of raw data in TBlob @@ -122,8 +125,8 @@ class NDArray { TShape raw_shape(1); raw_shape[0] = length; MSHADOW_TYPE_SWITCH(dtype_, DType, { - res = TBlob(static_cast(ptr_->shandle.dptr) - + offset_ + offset, raw_shape, ptr_->shandle.ctx.dev_mask()); + res = TBlob(static_cast(ptr_->shandle.dptr) + offset_ + offset, + raw_shape, ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); }); #if MKL_EXPERIMENTAL == 1 res.Mkl_mem_ = Mkl_mem_; @@ -326,7 +329,7 @@ class NDArray { ptr_->CheckAndAlloc(); } /*! - * \brief Save list of narray into the Stream.x + * \brief Save list of ndarray into the Stream.x * \param fo The stream of output. * \param data the NDArrays to be saved. * \param names the name of the NDArray, optional, can be zero length. @@ -335,7 +338,7 @@ class NDArray { const std::vector& data, const std::vector& names); /*! - * \brief Load list of narray into from the stream. + * \brief Load list of ndarray into from the stream. * \param fi The stream of the input file. * \param data the NDArrays to be loaded * \param keys the name of the NDArray, if saved in the file. @@ -368,10 +371,10 @@ class NDArray { : static_data(true), delay_alloc(false) { var = Engine::Get()->NewVariable(); - if (data.dev_mask_ == cpu::kDevMask) { + if (data.dev_mask() == cpu::kDevMask) { shandle.ctx = Context::CPU(); } else { - CHECK_EQ(data.dev_mask_, gpu::kDevMask); + CHECK_EQ(data.dev_mask(), gpu::kDevMask); shandle.ctx = Context::GPU(dev_id); } shandle.dptr = data.dptr_; @@ -418,6 +421,14 @@ class NDArray { int dtype_ = -1; /*! \brief node entry for autograd */ autograd::AGNodeEntry entry_; + /*! + * \brief internal TBlob + * \note When user access tblob_ by some const methods like + * NDArray::data(), the dptr in tblob_ still need to be updated + * in case that allocation happens. So we make it mutable for + * this situation. + */ + mutable TBlob tblob_; }; /*! diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index e4e335666d80..d142c20aa30a 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -39,12 +40,6 @@ class TBlob { void *dptr_; /*! \brief shape of the tensor */ TShape shape_; - /*! - * \brief storing the stride information in x dimension - */ - index_t stride_; - /*! \brief device mask of the corresponding device */ - int dev_mask_; /*! \brief type flag of the tensor blob */ int type_flag_; @@ -54,49 +49,61 @@ class TBlob { #endif /*! \brief default constructor, default copy assign will work */ TBlob(void) - : dptr_(NULL), dev_mask_(cpu::kDevMask), + : dptr_(NULL), type_flag_(mshadow::DataType::kFlag) { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = NULL; + Mkl_mem_ = NULL; #endif + SetDLTensor(cpu::kDevMask, 0); } /*! 
* \brief constructor that construct TBlob from contiguous memory * \param dptr the pointer to the memory * \param shape the shape of the data * \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask + * \param dev_id the device id */ template - TBlob(DType *dptr, - const TShape &shape, - int dev_mask) + TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id = -1) : dptr_(dptr), shape_(shape), - stride_(shape[shape.ndim() - 1]), - dev_mask_(dev_mask), type_flag_(mshadow::DataType::kFlag) { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = NULL; + Mkl_mem_ = NULL; #endif + SetDLTensor(dev_mask, dev_id); } - +#if MKL_EXPERIMENTAL == 1 /*! * \brief constructor that construct TBlob from contiguous memory * \param dptr the pointer to the memory * \param shape the shape of the data * \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask - * \param type_flag the type flag. Can be one of enum mshadow::dtype + * \param dev_id the device id + * \param Mkl_mem the mkl memory */ - TBlob(void *dptr, - const TShape &shape, - int dev_mask, - int type_flag) + template + TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id, + std::shared_ptr Mkl_mem) : dptr_(dptr), shape_(shape), - stride_(shape[shape.ndim() - 1]), - dev_mask_(dev_mask), - type_flag_(type_flag) { + type_flag_(mshadow::DataType::kFlag), + Mkl_mem_(Mkl_mem) { + SetDLTensor(dev_mask, dev_id); + } +#endif + /*! + * \brief constructor that construct TBlob from contiguous memory + * \param dptr the pointer to the memory + * \param shape the shape of the data + * \param dev_mask the device mask, can be cpu::kDevMask or gpu::kDevMask + * \param type_flag the type flag. Can be one of enum mshadow::dtype + * \param dev_id the device id + */ + TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id = -1) + : dptr_(dptr), shape_(shape), type_flag_(type_flag) { #if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = NULL; + Mkl_mem_ = NULL; #endif + SetDLTensor(dev_mask, dev_id); } /*! * \brief constructor from tensor @@ -108,9 +115,6 @@ class TBlob { template TBlob(const mshadow::Tensor &src) { // NOLINT(*) *this = src; -#if MKL_EXPERIMENTAL == 1 - Mkl_mem_ = NULL; -#endif } /*! * \brief assignment from tensor @@ -121,20 +125,21 @@ class TBlob { * \return reference of self */ template - inline TBlob - &operator=(const mshadow::Tensor &src) { + inline TBlob &operator=(const mshadow::Tensor &src) { dptr_ = src.dptr_; shape_ = src.shape_; - stride_ = src.stride_; - dev_mask_ = Device::kDevMask; type_flag_ = mshadow::DataType::kFlag; + SetDLTensor(Device::kDevMask, -1); +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = NULL; +#endif return *this; } /*! * \return whether the tensor's memory is continuous */ inline bool CheckContiguous(void) const { - return shape_[shape_.ndim() - 1] == stride_; + return true; } /*! * \brief reshape to shape @@ -144,7 +149,7 @@ class TBlob { inline TBlob reshape(const TShape& shape) const { CHECK_EQ(this->shape_.Size(), shape.Size()) << "Shape size mismatch " << this->shape_.Size() << " v.s. " << shape.Size(); - TBlob ret(this->dptr_, shape, this->dev_mask_, this->type_flag_); + TBlob ret(this->dptr_, shape, this->dev_mask(), this->type_flag_, this->dev_id()); return ret; } /*! 
@@ -157,7 +162,7 @@ class TBlob { template inline mshadow::Tensor FlatTo2D( mshadow::Stream *stream = NULL) const { - CHECK(Device::kDevMask == dev_mask_) + CHECK(Device::kDevMask == this->dev_mask()) << "TBlob.get: device type do not match specified type"; CHECK(mshadow::DataType::kFlag == type_flag_) << "TBlob.get_with_shape: data type do not match specified type." @@ -168,7 +173,9 @@ class TBlob { } #endif return mshadow::Tensor(static_cast(dptr_), - shape_.FlatTo2D(), stride_, stream); + shape_.FlatTo2D(), + shape_[shape_.ndim() - 1], + stream); } /*! * \brief flatten the tensor to 1 dimension, collapse all the dimensions together. @@ -212,6 +219,22 @@ class TBlob { #endif return static_cast(dptr_); } + /*! \brief device mask of the corresponding device */ + inline int dev_mask() const { + return dltensor_.ctx.device_type; + } + /*! \brief device index of the corresponding device */ + inline int dev_id() const { + return dltensor_.ctx.device_id; + } + /*! + * \brief return the corresponding DLTensor + * \return the address of internal DLTensor + */ + inline const DLTensor& dltensor() { + return dltensor_; + } + /*! * \brief fetch the tensor, with respect to specific dimension * if dim do not match the stored dimension, an error will be issued @@ -223,9 +246,10 @@ class TBlob { */ template inline mshadow::Tensor get(mshadow::Stream *stream = NULL) const { - CHECK(Device::kDevMask == dev_mask_) + CHECK(Device::kDevMask == this->dev_mask()) << "TBlob.get: device type do not match specified type"; - return mshadow::Tensor(dptr(), shape_.get(), stride_, stream); + return mshadow::Tensor(dptr(), + shape_.get(), shape_[shape_.ndim() - 1], stream); } /*! * \brief fetch a tensor in given shape @@ -241,7 +265,7 @@ class TBlob { inline mshadow::Tensor get_with_shape( const mshadow::Shape &shape, mshadow::Stream *stream = NULL) const { - CHECK(Device ::kDevMask == dev_mask_) + CHECK(Device::kDevMask == this->dev_mask()) << "TBlob.get: device type do not match specified type"; CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous"; CHECK_EQ(this->shape_.Size(), shape.Size()) @@ -281,6 +305,34 @@ class TBlob { return this->get_with_shape( this->shape_.FlatTo3D(axis_begin, axis_end), stream); } + + private: + static DLDataType DTypeTransform(int type_flag) { + static std::unordered_map + MSHADOW_DTYPE_TO_DLPACK_DTYPE = { + {0, {2, 32, 1}}, // Float32 + {1, {2, 64, 1}}, // Float64 + {2, {2, 16, 1}}, // Float16 + {3, {1, 8, 1}}, // UInt8 + {4, {0, 32, 1}}, // Int32 + {5, {0, 8, 1}} // Int8 + }; + return MSHADOW_DTYPE_TO_DLPACK_DTYPE[type_flag]; + } + + inline void SetDLTensor(int dev_mask, int dev_id) { + dltensor_.data = dptr_; + dltensor_.ctx = DLContext{static_cast(dev_mask), dev_id}; + dltensor_.ndim = shape_.ndim(); + dltensor_.dtype = DTypeTransform(type_flag_); + dltensor_.shape = shape_.data(); + dltensor_.strides = NULL; + dltensor_.byte_offset = 0; + } + + private: + /*! 
\brief corresponding DLTensor of this TBlob */ + DLTensor dltensor_; }; } // namespace mxnet diff --git a/nnvm b/nnvm index b279286304ac..93072dc8733a 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit b279286304ac954098d94a2695bca599e832effb +Subproject commit 93072dc8733aa2a89459ecf16413d96ad0b998db diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py index 57b0a2c18130..a24756632c10 100644 --- a/python/mxnet/libinfo.py +++ b/python/mxnet/libinfo.py @@ -44,4 +44,4 @@ def find_lib_path(): # current version -__version__ = "0.10.0" +__version__ = "0.10.1" diff --git a/scala-package/assembly/linux-x86_64-cpu/pom.xml b/scala-package/assembly/linux-x86_64-cpu/pom.xml index d6639973d5c4..138c5c84304f 100644 --- a/scala-package/assembly/linux-x86_64-cpu/pom.xml +++ b/scala-package/assembly/linux-x86_64-cpu/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-full-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet mxnet-full_2.11-linux-x86_64-cpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Full Linux-x86_64 CPU-only jar diff --git a/scala-package/assembly/linux-x86_64-gpu/pom.xml b/scala-package/assembly/linux-x86_64-gpu/pom.xml index 38b2bd623865..7e818cb28123 100644 --- a/scala-package/assembly/linux-x86_64-gpu/pom.xml +++ b/scala-package/assembly/linux-x86_64-gpu/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-full-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet mxnet-full_2.11-linux-x86_64-gpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Full Linux-x86_64 GPU jar diff --git a/scala-package/assembly/osx-x86_64-cpu/pom.xml b/scala-package/assembly/osx-x86_64-cpu/pom.xml index f72be6dc17ff..ead035668892 100644 --- a/scala-package/assembly/osx-x86_64-cpu/pom.xml +++ b/scala-package/assembly/osx-x86_64-cpu/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-full-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet mxnet-full_2.11-osx-x86_64-cpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Full OSX-x86_64 CPU-only jar diff --git a/scala-package/assembly/pom.xml b/scala-package/assembly/pom.xml index b5b52ff35646..a1009ae6b08c 100644 --- a/scala-package/assembly/pom.xml +++ b/scala-package/assembly/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet mxnet-full-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Full Parent pom diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 592e15b5c49e..7f639b9a8d39 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet mxnet-core_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Core diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml index fa99ed9c44da..bda4fcdab5c4 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-examples_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Examples @@ -84,7 +84,7 @@ package copy-dependencies - + ${project.build.outputDirectory}/lib runtime diff --git a/scala-package/init-native/linux-x86_64/pom.xml b/scala-package/init-native/linux-x86_64/pom.xml index 61f49e598a3d..7e6c02aefd83 100644 --- a/scala-package/init-native/linux-x86_64/pom.xml +++ b/scala-package/init-native/linux-x86_64/pom.xml @@ -6,12 
+6,12 @@ ml.dmlc.mxnet mxnet-scala-init-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml libmxnet-init-scala-linux-x86_64 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Initializer Native Linux-x86_64 http://maven.apache.org diff --git a/scala-package/init-native/osx-x86_64/pom.xml b/scala-package/init-native/osx-x86_64/pom.xml index 449f66e3ba7f..4f5125c06f15 100644 --- a/scala-package/init-native/osx-x86_64/pom.xml +++ b/scala-package/init-native/osx-x86_64/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-scala-init-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml libmxnet-init-scala-osx-x86_64 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Initializer Native OSX-x86_64 http://maven.apache.org diff --git a/scala-package/init-native/pom.xml b/scala-package/init-native/pom.xml index 8e02d45d015f..3ce227a9b587 100644 --- a/scala-package/init-native/pom.xml +++ b/scala-package/init-native/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-scala-init-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Initializer Native Parent pom diff --git a/scala-package/init/pom.xml b/scala-package/init/pom.xml index 44bf7a677abc..9f079565874e 100644 --- a/scala-package/init/pom.xml +++ b/scala-package/init/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-init_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Initializer diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index aec0c2897fe5..fd7fe3e4ab7b 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-macros_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Macros diff --git a/scala-package/native/linux-x86_64-cpu/pom.xml b/scala-package/native/linux-x86_64-cpu/pom.xml index 4aae3d8f1bf9..b2cfa4263cda 100644 --- a/scala-package/native/linux-x86_64-cpu/pom.xml +++ b/scala-package/native/linux-x86_64-cpu/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-scala-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet libmxnet-scala-linux-x86_64-cpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Native Linux-x86_64 CPU-only http://maven.apache.org diff --git a/scala-package/native/linux-x86_64-gpu/pom.xml b/scala-package/native/linux-x86_64-gpu/pom.xml index f0a158031ded..27f9221c3bad 100644 --- a/scala-package/native/linux-x86_64-gpu/pom.xml +++ b/scala-package/native/linux-x86_64-gpu/pom.xml @@ -6,13 +6,13 @@ ml.dmlc.mxnet mxnet-scala-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml ml.dmlc.mxnet libmxnet-scala-linux-x86_64-gpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Native Linux-x86_64 GPU http://maven.apache.org diff --git a/scala-package/native/osx-x86_64-cpu/pom.xml b/scala-package/native/osx-x86_64-cpu/pom.xml index fa82d31ee386..f924106a605c 100644 --- a/scala-package/native/osx-x86_64-cpu/pom.xml +++ b/scala-package/native/osx-x86_64-cpu/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-scala-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml libmxnet-scala-osx-x86_64-cpu - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Native OSX-x86_64 CPU-only http://maven.apache.org diff --git a/scala-package/native/pom.xml b/scala-package/native/pom.xml index dbf286c633e6..0af9e087f906 100644 --- a/scala-package/native/pom.xml +++ 
b/scala-package/native/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-scala-native-parent - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Native Parent pom diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 1eae0b9eb6ed..86d8cfc16a43 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -5,7 +5,7 @@ 4.0.0 ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Parent https://github.com/dmlc/mxnet/tree/master/scala-package MXNet Scala Package diff --git a/scala-package/spark/pom.xml b/scala-package/spark/pom.xml index 9d7b31909dfb..f35cbe45d9de 100644 --- a/scala-package/spark/pom.xml +++ b/scala-package/spark/pom.xml @@ -6,12 +6,12 @@ ml.dmlc.mxnet mxnet-parent_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT ../pom.xml mxnet-spark_2.11 - 0.10.0-SNAPSHOT + 0.10.1-SNAPSHOT MXNet Scala Package - Spark ML diff --git a/snapcraft.yaml b/snapcraft.yaml index a0073f2d4f1a..b9329a0ccd41 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -1,5 +1,5 @@ name: mxnet -version: '0.10.0' +version: '0.10.1' summary: MXNet is a deep learning framework designed for efficiency and flexibility. description: | MXNet is a deep learning framework designed for both efficiency and diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ae7af5bad129..41986a0d577b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -336,12 +336,16 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { const TShape &s = arr->shape(); *out_dim = s.ndim(); - *out_pdata = s.data(); + std::vector& buffer = ret->arg_shape_buffer; + buffer.resize(s.ndim()); + nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data()); + *out_pdata = buffer.data(); } else { *out_dim = 0; } diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index e2e739ae62a4..d8857f80635d 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -62,16 +62,24 @@ struct MXAPIThreadLocalEntry { std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! \brief result holder for returning shape pointer */ std::vector arg_shape_data, out_shape_data, aux_shape_data; + /*! 
\brief uint32_t buffer for returning shape pointer */ + std::vector arg_shape_buffer, out_shape_buffer, aux_shape_buffer; // helper function to setup return value of shape array - inline static void SetupShapeArrayReturn( + inline static void SetupShapeArrayReturnWithBuffer( const std::vector &shapes, std::vector *ndim, - std::vector *data) { + std::vector *data, + std::vector *buffer) { ndim->resize(shapes.size()); data->resize(shapes.size()); + size_t size = 0; + for (const auto& s : shapes) size += s.ndim(); + buffer->resize(size); + uint32_t *ptr = buffer->data(); for (size_t i = 0; i < shapes.size(); ++i) { ndim->at(i) = shapes[i].ndim(); - data->at(i) = shapes[i].data(); + data->at(i) = ptr; + ptr = nnvm::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); } } }; diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index c633e8609cd4..66a237a4bd36 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -396,7 +396,6 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, } } - if (outarray == nullptr) { ret->ret_handles.clear(); for (int i = 0; i < num_visible_outputs; ++i) { diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index f7281c999e6a..fdf095b09361 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -429,14 +429,14 @@ int MXSymbolInferShape(SymbolHandle sym, std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (mx_uint i = 0; i < num_args; ++i) { - arg_shapes[read_only_args[i]] = TShape(arg_shape_data + arg_ind_ptr[i], - arg_shape_data + arg_ind_ptr[i+1]); + arg_shapes[read_only_args[i]] = nnvm::ShapeTypeCast( + arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = TShape(arg_shape_data + arg_ind_ptr[i], - arg_shape_data + arg_ind_ptr[i+1]); + kwargs[keys[i]] = nnvm::ShapeTypeCast( + arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, "InferShape"); } @@ -452,12 +452,12 @@ int MXSymbolInferShape(SymbolHandle sym, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); // copy data back - MXAPIThreadLocalEntry::SetupShapeArrayReturn( - ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data)); - MXAPIThreadLocalEntry::SetupShapeArrayReturn( - ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data)); - MXAPIThreadLocalEntry::SetupShapeArrayReturn( - ret->aux_shapes, &(ret->aux_shape_ndim), &(ret->aux_shape_data)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->arg_shapes, + &(ret->arg_shape_ndim), &(ret->arg_shape_data), &(ret->arg_shape_buffer)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->out_shapes, + &(ret->out_shape_ndim), &(ret->out_shape_data), &(ret->out_shape_buffer)); + MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->aux_shapes, + &(ret->aux_shape_ndim), &(ret->aux_shape_data), &(ret->aux_shape_buffer)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data); diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 26bc44b701e5..1dd784ba2249 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -25,6 +25,8 @@ struct MXAPIPredictor { std::vector arg_arrays; // output shapes std::vector out_shapes; + // uint32_t buffer 
for output shapes + std::vector out_shapes_buffer; // key to arguments std::unordered_map key2arg; // executor @@ -34,6 +36,7 @@ struct MXAPIPredictor { struct MXAPINDList { std::vector keys; std::vector shapes; + std::vector shapes_buffer; std::vector indptr; std::vector data; }; @@ -228,7 +231,11 @@ int MXPredGetOutputShape(PredictorHandle handle, API_BEGIN(); CHECK_LT(out_index, p->out_arrays.size()) << "Index exceed number of outputs"; - *shape_data = p->out_shapes[out_index].data(); + + const TShape& s = p->out_shapes[out_index]; + p->out_shapes_buffer.resize(s.ndim()); + nnvm::ShapeTypeCast(s.begin(), s.end(), p->out_shapes_buffer.data()); + *shape_data = p->out_shapes_buffer.data(); *shape_ndim = p->out_shapes[out_index].ndim(); API_END(); } @@ -322,7 +329,10 @@ int MXNDListGet(NDListHandle handle, << "Index out of range"; *out_key = p->keys[index].c_str(); *out_data = dmlc::BeginPtr(p->data) + p->indptr[index]; - *out_shape = p->shapes[index].data(); + const TShape& s = p->shapes[index]; + p->shapes_buffer.resize(s.ndim()); + nnvm::ShapeTypeCast(s.begin(), s.end(), p->shapes_buffer.data()); + *out_shape = p->shapes_buffer.data(); *out_ndim = p->shapes[index].ndim(); API_END(); } diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 6ba0ff96b382..cdbb129304b1 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -538,9 +538,9 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } if (!allocated) { size_t nword = (bytes + 3) / 4; - CHECK_LE(nword, std::numeric_limits::max()); + CHECK_LE(nword, std::numeric_limits::max()); // allocate float arrays - TShape shape{index_t(nword)}; + TShape shape{static_cast(nword)}; NDArray nd(shape, ctx); data_pool_[i] = nd; // put the new allocated arrays to shared pool diff --git a/src/io/image_io.cc b/src/io/image_io.cc index 9c65edd1aa87..1ef1df1b74bd 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -26,7 +26,7 @@ namespace io { // http://www.64lines.com/jpeg-width-height // Gets the JPEG size from the array of data passed to the function, // file reference: http://www.obrador.com/essentialjpeg/headerinfo.htm -bool get_jpeg_size(const uint8_t* data, uint32_t data_size, uint32_t *width, uint32_t *height) { +bool get_jpeg_size(const uint8_t* data, uint32_t data_size, int64_t *width, int64_t *height) { // Check for valid JPEG image uint32_t i = 0; // Keeps track of the position within the file if (data[i] == 0xFF && data[i+1] == 0xD8 && data[i+2] == 0xFF && data[i+3] == 0xE0) { @@ -63,7 +63,7 @@ bool get_jpeg_size(const uint8_t* data, uint32_t data_size, uint32_t *width, uin } } -bool get_png_size(const uint8_t* data, uint32_t data_size, uint32_t *width, uint32_t *height) { +bool get_png_size(const uint8_t* data, uint32_t data_size, int64_t *width, int64_t *height) { if (data[0] == 0x89 && data[1] == 0x50 && data[2] ==0x4E && data[3] == 0x47) { uint8_t const* p = data + 16; *width = ((p[0]*256 + p[1])*256 + p[2])*256 + p[3]; diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 2b53393679c6..a51e24503785 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -145,7 +145,7 @@ class BatchLoader : public IIterator { shape_[i] = dst_shape; data_[i].resize(mshadow::Shape1(dst_shape.Size()), src_type_flag); unit_size_[i] = src_shape.Size(); - out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag)); + out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag, 0)); } } }; // class 
BatchLoader diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc index 2817b4d8ff51..c43f99911f69 100644 --- a/src/io/iter_csv.cc +++ b/src/io/iter_csv.cc @@ -107,7 +107,7 @@ class CSVIter: public IIterator { << "The data size in CSV do not match size of shape: " << "specified shape=" << shape << ", the csv row-length=" << row.length; const real_t* ptr = row.value; - return TBlob((real_t*)ptr, shape, cpu::kDevMask); // NOLINT(*) + return TBlob((real_t*)ptr, shape, cpu::kDevMask, 0); // NOLINT(*) } CSVIterParam param_; diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index 94019fe293df..ace42855b6a7 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -266,7 +266,7 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { auto dtype = prefetch_param_.dtype ? prefetch_param_.dtype.value() : first_batch.data[i].type_flag_; - out->data.at(i) = NDArray(dst_shape, Context::CPU(), false , src_type_flag); + out->data.at(i) = NDArray(dst_shape, Context::CPU(), false, src_type_flag); unit_size_[i] = src_shape.Size(); } } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index c19a82b164c4..717ba170aaf7 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -4,6 +4,7 @@ * \brief ndarry module of mxnet */ #include +#include #include #include #include @@ -613,8 +614,11 @@ NDArray &NDArray::operator/=(const real_t &src) { return ScalarOpApply(this, src); } +/* magic number for ndarray version 1, with int64_t TShape */ +static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; + void NDArray::Save(dmlc::Stream *strm) const { - // save shape + strm->Write(NDARRAY_V1_MAGIC); shape_.Save(strm); if (is_none()) return; // save context @@ -638,10 +642,28 @@ void NDArray::Save(dmlc::Stream *strm) const { strm->Write(save_data.dptr_, type_size * shape_.Size()); } +bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape) { + uint32_t magic; + if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; + switch (magic) { + case NDARRAY_V1_MAGIC: + return shape->Load(strm); + default: + // meet legacy TShape, magic is ndim here + uint32_t ndim = magic; + *shape = TShape(ndim); + std::vector buffer(ndim); + size_t nread = ndim * sizeof(uint32_t); + if (strm->Read(buffer.data(), nread) != nread) return false; + nnvm::ShapeTypeCast(buffer.begin(), buffer.end(), shape->begin()); + return true; + } +} + bool NDArray::Load(dmlc::Stream *strm) { // load shape TShape shape; - if (!shape.Load(strm)) return false; + if (!LegacyTShapeLoad(strm, &shape)) return false; if (shape.ndim() == 0) { *this = NDArray(); return true; } @@ -710,7 +732,7 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const { TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) << "Memory size do not match"; - TBlob src((void*)data, dshape, cpu::kDevMask, this->dtype_); // NOLINT(*) + TBlob src((void*)data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*) if (this->ctx().dev_mask() == cpu::kDevMask) { this->WaitToWrite(); @@ -739,7 +761,7 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const { TShape dshape = this->shape(); CHECK_EQ(dshape.Size(), size) << "Memory size do not match"; - TBlob dst(data, dshape, cpu::kDevMask, this->dtype_); // NOLINT(*) + TBlob dst(data, dshape, cpu::kDevMask, this->dtype_, 0); // NOLINT(*) if (this->ctx().dev_mask() == cpu::kDevMask) { this->WaitToRead(); diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 96eadcdcca4c..508b1f8be84d 100644 
--- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -33,6 +33,7 @@ class CuDNNConvolutionOp : public Operator { const Context& ctx) { using namespace mshadow; this->param_ = param; + InitBufferForParam(); auto cudnn_forward_compute_type = convertToCuDNNDataType(forward_compute_type); auto cudnn_backward_compute_type = convertToCuDNNDataType(backward_compute_type); // convert MB to words @@ -426,27 +427,28 @@ class CuDNNConvolutionOp : public Operator { // 3d conv #if CUDNN_MAJOR >= 5 CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; + std::vector wshape_buffer(wshape.ndim()); CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, dtype_, CUDNN_TENSOR_NCHW, static_cast(wshape.ndim()), - reinterpret_cast(&wshape[0]))); + CastTShapeToIntPtr(wshape, &wshape_buffer))); #else LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; #endif CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, 3, - reinterpret_cast(¶m_.pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), + param_pad_.data(), + param_stride_.data(), + param_dilate_.data(), CUDNN_CROSS_CORRELATION, cudnn_forward_compute_type)); CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_, 3, - reinterpret_cast(¶m_.pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), + param_pad_.data(), + param_stride_.data(), + param_dilate_.data(), CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); @@ -472,17 +474,26 @@ class CuDNNConvolutionOp : public Operator { data_offset_ = dstride[1] * dshape[1]; out_offset_ = ostride[1] * oshape[1]; - CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, - dtype_, - static_cast(dshape.ndim()), - reinterpret_cast(&dshape[0]), - reinterpret_cast(&dstride[0]))); + std::vector dshape_buffer(dshape.ndim()); + nnvm::ShapeTypeCast(dshape.begin(), dshape.end(), dshape_buffer.data()); + std::vector dstride_buffer(dstride.ndim()); + nnvm::ShapeTypeCast(dstride.begin(), dstride.end(), dstride_buffer.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, + dtype_, + static_cast(dshape.ndim()), + dshape_buffer.data(), + dstride_buffer.data())); + + std::vector oshape_buffer(oshape.ndim()); + nnvm::ShapeTypeCast(oshape.begin(), oshape.end(), oshape_buffer.data()); + std::vector ostride_buffer(ostride.ndim()); + nnvm::ShapeTypeCast(ostride.begin(), ostride.end(), ostride_buffer.data()); CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, - dtype_, - static_cast(oshape.ndim()), - reinterpret_cast(&oshape[0]), - reinterpret_cast(&ostride[0]))); + dtype_, + static_cast(oshape.ndim()), + oshape_buffer.data(), + ostride_buffer.data())); if (!param_.no_bias) { TShape bias = in_shape[conv::kBias]; @@ -661,6 +672,22 @@ class CuDNNConvolutionOp : public Operator { init_temp_size_ = true; } + int *CastTShapeToIntPtr(const TShape& s, std::vector *buffer) { + buffer->resize(s.ndim()); + nnvm::ShapeTypeCast(s.begin(), s.end(), buffer->data()); + return buffer->data(); + } + + void InitBufferForParam() { + CastTShapeToIntPtr(param_.stride, ¶m_stride_); + CastTShapeToIntPtr(param_.dilate, ¶m_dilate_); + CastTShapeToIntPtr(param_.pad, ¶m_pad_); + } + + std::vector param_stride_; + std::vector param_dilate_; + std::vector param_pad_; + bool init_cudnn_; bool init_temp_size_; size_t forward_workspace_; diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h index 99426531beb0..5bba1e5278fa 100644 --- a/src/operator/cudnn_deconvolution-inl.h +++ 
b/src/operator/cudnn_deconvolution-inl.h @@ -30,6 +30,7 @@ class CuDNNDeconvolutionOp : public Operator { const Context& ctx) { using namespace mshadow; this->param_ = param; + InitBufferForParam(); auto cudnn_forward_compute_type = convertToCuDNNDataType(forward_compute_type); auto cudnn_backward_compute_type = convertToCuDNNDataType(backward_compute_type); // convert MB to words @@ -449,27 +450,28 @@ class CuDNNDeconvolutionOp : public Operator { #if CUDNN_MAJOR >= 5 CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; + std::vector wshape_buffer(wshape.ndim()); CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, dtype_, CUDNN_TENSOR_NCHW, static_cast(wshape.ndim()), - reinterpret_cast(&wshape[0]))); + CastTShapeToIntPtr(wshape, &wshape_buffer))); #else LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; #endif CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, 3, reinterpret_cast(&o_pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), + param_stride_.data(), + param_dilate_.data(), CUDNN_CROSS_CORRELATION, cudnn_forward_compute_type)); CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_, 3, reinterpret_cast(&o_pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), + param_stride_.data(), + param_dilate_.data(), CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); @@ -495,17 +497,21 @@ class CuDNNDeconvolutionOp : public Operator { data_offset_ = dstride[1] * dshape[1]; out_offset_ = ostride[1] * oshape[1]; + std::vector dshape_buffer(dshape.ndim()); + std::vector dstride_buffer(dstride.ndim()); CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, dtype_, static_cast(dshape.ndim()), - reinterpret_cast(&dshape[0]), - reinterpret_cast(&dstride[0]))); + CastTShapeToIntPtr(dshape, &dshape_buffer), + CastTShapeToIntPtr(dstride, &dstride_buffer))) + std::vector oshape_buffer(oshape.ndim()); + std::vector ostride_buffer(ostride.ndim()); CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, dtype_, static_cast(oshape.ndim()), - reinterpret_cast(&oshape[0]), - reinterpret_cast(&ostride[0]))); + CastTShapeToIntPtr(oshape, &oshape_buffer), + CastTShapeToIntPtr(ostride, &ostride_buffer))); if (!param_.no_bias) { TShape bias = in_shape[deconv::kBias]; @@ -687,6 +693,20 @@ class CuDNNDeconvolutionOp : public Operator { init_temp_size_ = true; } + int *CastTShapeToIntPtr(const TShape& s, std::vector *buffer) { + buffer->resize(s.ndim()); + nnvm::ShapeTypeCast(s.begin(), s.end(), buffer->data()); + return buffer->data(); + } + + void InitBufferForParam() { + CastTShapeToIntPtr(param_.stride, ¶m_stride_); + CastTShapeToIntPtr(param_.dilate, ¶m_dilate_); + } + + std::vector param_stride_; + std::vector param_dilate_; + bool init_cudnn_; bool init_temp_size_; size_t forward_workspace_; diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h index b9224cd30f48..f640c3abd7a6 100644 --- a/src/operator/custom/custom-inl.h +++ b/src/operator/custom/custom-inl.h @@ -184,11 +184,17 @@ class CustomOpProp : public OperatorProperty { bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { - std::vector shapes; + std::vector shapes; std::vector ndims; + size_t size = 0; + for (const auto& s : *in_shape) size += s.ndim(); + std::vector shapes_buffer(size); + shapes_buffer.resize(size); + uint32_t *ptr = shapes_buffer.data(); for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(iter->data()); + 
shapes.push_back(ptr); ndims.push_back(iter->ndim()); + ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); } shapes.resize(num_inputs_+num_outputs_+num_auxs_); ndims.resize(num_inputs_+num_outputs_+num_auxs_); @@ -284,6 +290,7 @@ class CustomOpProp : public OperatorProperty { std::shared_ptr info_; std::vector > kwargs_; unsigned num_inputs_, num_outputs_, num_auxs_; + mutable std::vector shapes_buffer_; }; // class CustomOpProp } // namespace op } // namespace mxnet diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 06330a4a062e..29f624ead2ad 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -160,11 +160,16 @@ void CustomOp::Backward(const OpContext &ctx, Operator* CustomOpProp::CreateOperatorEx(Context ctx, std::vector *in_shape, std::vector *in_type) const { - std::vector shapes; + std::vector shapes; std::vector ndims; + size_t size = 0; + for (const auto& s : *in_shape) size += s.ndim(); + shapes_buffer_.resize(size); + uint32_t *ptr = shapes_buffer_.data(); for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(iter->data()); + shapes.push_back(ptr); ndims.push_back(iter->ndim()); + ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); } std::string str_ctx; if (ctx.dev_mask() == cpu::kDevMask) { diff --git a/src/operator/custom/native_op-inl.h b/src/operator/custom/native_op-inl.h index b5706205c82b..780b0ae41f67 100644 --- a/src/operator/custom/native_op-inl.h +++ b/src/operator/custom/native_op-inl.h @@ -108,7 +108,8 @@ class NativeOp : public Operator { NativeOpParam param_; std::vector ptrs; std::vector ndims; - std::vector shapes; + std::vector shapes; + std::vector shapes_buffer_; std::vector tags; std::map > > buffer_map; @@ -137,13 +138,18 @@ class NativeOp : public Operator { const std::string &prefix, mshadow::Stream *stream, int tag) { + size_t size = 0; + for (const auto& tblob : vec) size += tblob.shape_.ndim(); + shapes_buffer_.resize(size); + uint32_t *ptr = shapes_buffer_.data(); for (size_t i = 0; i < vec.size(); ++i) { std::stringstream name; name << prefix << i; SyncBuffer(vec[i], name.str(), stream); ptrs.push_back(buffer_map[name.str()].second.dptr_); ndims.push_back(vec[i].ndim()); - shapes.push_back(const_cast(vec[i].shape_.data())); + shapes.push_back(ptr); + ptr = nnvm::ShapeTypeCast(vec[i].shape_.begin(), vec[i].shape_.end(), ptr); tags.push_back(tag); } } @@ -198,11 +204,16 @@ class NativeOpProp : public OperatorProperty { bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { - std::vector shapes; + std::vector shapes; std::vector ndims; + size_t size = 0; + for (const auto& s : *in_shape) size += s.ndim(); + std::vector shapes_buffer(size); + uint32_t *ptr = shapes_buffer.data(); for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(iter->data()); + shapes.push_back(ptr); ndims.push_back(iter->ndim()); + ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); } shapes.resize(param_.num_inputs_+param_.num_outputs_); ndims.resize(param_.num_inputs_+param_.num_outputs_); diff --git a/src/operator/custom/ndarray_op-inl.h b/src/operator/custom/ndarray_op-inl.h index a07a7f781d2d..05b1a3a902e8 100644 --- a/src/operator/custom/ndarray_op-inl.h +++ b/src/operator/custom/ndarray_op-inl.h @@ -110,11 +110,16 @@ class NDArrayOpProp : public OperatorProperty { bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { - 
std::vector shapes; + std::vector shapes; std::vector ndims; + size_t size = 0; + for (const auto& s : *in_shape) size += s.ndim(); + std::vector shapes_buffer(size); + uint32_t *ptr = shapes_buffer.data(); for (auto iter = in_shape->begin(); iter != in_shape->end(); ++iter) { - shapes.push_back(iter->data()); + shapes.push_back(ptr); ndims.push_back(iter->ndim()); + ptr = nnvm::ShapeTypeCast(iter->begin(), iter->end(), ptr); } shapes.resize(param_.num_inputs_+param_.num_outputs_); ndims.resize(param_.num_inputs_+param_.num_outputs_); diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index 771f0e217073..4edeb6979222 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -151,7 +151,8 @@ class DeconvolutionOp : public Operator { Tensor out = out_data[deconv::kOut].get(s); index_t o_pad[2], o_adj[2]; - TShape dshape = {data.size(2), data.size(3)}; + TShape dshape = {static_cast(data.size(2)), + static_cast(data.size(3))}; param_.InferPad(dshape, o_pad, o_adj); Shape<3> wmat_shape = @@ -268,7 +269,8 @@ class DeconvolutionOp : public Operator { << "Must init CuBLAS handle in stream"; #endif index_t o_pad[2], o_adj[2]; - TShape dshape = {data.size(2), data.size(3)}; + TShape dshape = {static_cast(data.size(2)), + static_cast(data.size(3))}; param_.InferPad(dshape, o_pad, o_adj); const index_t nbatch = data.size(0); diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 0ab24899042d..c7fcda0f0c01 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -108,7 +108,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*in_attrs, 0, tshape); return true; } else if ((*in_attrs)[0].ndim() == 1) { - return (*in_attrs)[0].Size() == tshape[0]; + return (*in_attrs)[0].Size() == static_cast(tshape[0]); } return false; } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index d7a591944e47..cdc8819da18e 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -283,8 +283,8 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, } } else { CHECK_EQ(shp.ndim(), param.axes.ndim()); - for (index_t i = 0; i < shp.ndim(); ++i) { - CHECK(param.axes[i] < shp.ndim()); + for (size_t i = 0; i < shp.ndim(); ++i) { + CHECK(param.axes[i] < static_cast(shp.ndim())); ret[i] = shp[param.axes[i]]; } } @@ -1387,11 +1387,13 @@ void RepeatOpForward(const nnvm::NodeAttrs& attrs, std::pair rshapes = ReshapeInputOutputForRepeatOp(ishape, axisOpt, repeats); // reshaped input tblob - TBlob iblob(inputs[0].dptr_, rshapes.first, inputs[0].dev_mask_, inputs[0].type_flag_); + TBlob iblob(inputs[0].dptr_, rshapes.first, inputs[0].dev_mask(), + inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; // reshaped output tblob - TBlob oblob(outputs[0].dptr_, rshapes.second, outputs[0].dev_mask_, outputs[0].type_flag_); + TBlob oblob(outputs[0].dptr_, rshapes.second, outputs[0].dev_mask(), + outputs[0].type_flag_, outputs[0].dev_id()); std::vector newOutputs = {oblob}; BroadcastCompute(attrs, ctx, newInputs, req, newOutputs); @@ -1429,11 +1431,13 @@ void RepeatOpBackward(const nnvm::NodeAttrs& attrs, ReshapeInputOutputForRepeatOp(oshape, axisOpt, repeats); // reshaped output grad tblob - TBlob oblob(outputs[0].dptr_, rshapes.first, outputs[0].dev_mask_, outputs[0].type_flag_); + TBlob oblob(outputs[0].dptr_, rshapes.first, outputs[0].dev_mask(), + outputs[0].type_flag_, 
outputs[0].dev_id()); std::vector newOutputs = {oblob}; // reshaped input grad tblob - TBlob iblob(inputs[0].dptr_, rshapes.second, inputs[0].dev_mask_, inputs[0].type_flag_); + TBlob iblob(inputs[0].dptr_, rshapes.second, inputs[0].dev_mask(), + inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; ReduceAxesComputeImpl( @@ -1563,10 +1567,12 @@ void TileOpForward(const nnvm::NodeAttrs& attrs, std::pair rshapes = ReshapeInputOutputForTileOp(ishape, reps); // reshaped input tblob - TBlob iblob(inputs[0].dptr_, rshapes.first, inputs[0].dev_mask_, inputs[0].type_flag_); + TBlob iblob(inputs[0].dptr_, rshapes.first, inputs[0].dev_mask(), + inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; // reshaped output tblob - TBlob oblob(outputs[0].dptr_, rshapes.second, outputs[0].dev_mask_, outputs[0].type_flag_); + TBlob oblob(outputs[0].dptr_, rshapes.second, outputs[0].dev_mask(), + outputs[0].type_flag_, outputs[0].dev_id()); std::vector newOutputs = {oblob}; BroadcastCompute(attrs, ctx, newInputs, req, newOutputs); @@ -1603,10 +1609,12 @@ void TileOpBackward(const nnvm::NodeAttrs& attrs, std::pair rshapes = ReshapeInputOutputForTileOp(oshape, reps); // reshaped output grad tblob - TBlob oblob(outputs[0].dptr_, rshapes.first, outputs[0].dev_mask_, outputs[0].type_flag_); + TBlob oblob(outputs[0].dptr_, rshapes.first, outputs[0].dev_mask(), + outputs[0].type_flag_, outputs[0].dev_id()); std::vector newOutputs = {oblob}; // reshaped input grad tblob - TBlob iblob(inputs[0].dptr_, rshapes.second, inputs[0].dev_mask_, inputs[0].type_flag_); + TBlob iblob(inputs[0].dptr_, rshapes.second, inputs[0].dev_mask(), + inputs[0].type_flag_, inputs[0].dev_id()); std::vector newInputs = {iblob}; ReduceAxesComputeImpl( diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h index 6b87312e174a..b0e4c866f9de 100644 --- a/tests/cpp/include/test_util.h +++ b/tests/cpp/include/test_util.h @@ -160,14 +160,14 @@ inline StreamType& print_blob(StreamType *_os, const TBlob &blob, if (dim == 1) { // probably a tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), TShape(3), blob.dev_mask_); + TBlob changed(blob.dptr(), TShape(3), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; return print_blob(&os, changed, false, false); } else if (dim == 2) { // probably a tensor (mshadow::Tensor is deprecated) - TBlob changed(blob.dptr(), TShape(4), blob.dev_mask_); + TBlob changed(blob.dptr(), TShape(4), blob.dev_mask(), blob.dev_id()); changed.shape_[0] = 1; changed.shape_[1] = 1; changed.shape_[2] = blob.shape_[0]; diff --git a/tests/python/unittest/legacy_ndarray.v0 b/tests/python/unittest/legacy_ndarray.v0 new file mode 100644 index 0000000000000000000000000000000000000000..f4306d8372021bec350f9faea44ef0b07c96d939 GIT binary patch literal 3224 zcmeI!u}d3a7{K8Rh&VWe;?Th%gF^=ghYk*cF?4X~kinrthYSu54h{|u4o8s^q?8~< zf|L?Ukx)trA|fIs6cG^-5h)^4LXi@rNDwKZNa+*MA^8{jecW@$eK+2Z_wGR=YFpiR zcdhN1j>oOtQKz(nm=aZLL}O8uB*hSEvgDW{CQpGPCCXITq{<#goKmAsgCY}3DDj>TtWxG9pZH9LHP+c+lP$LSLY1%Vu*)9%9B{}H$DHtuQ_eW&J2ifA!B6U3 z@{21PTyw*3n*8CGznO#=5YxK0uh2wjqP+;EggFuBL`M{%i7+QRk_dCcw>A7vv{w=4 L#Q)mFeFy&lZ0mW> literal 0 HcmV?d00001 diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 7f0a1d2b6301..fcc7d70f20fe 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -230,6 +230,15 @@ def test_ndarray_saveload(): assert np.sum(x.asnumpy() 
!= y.asnumpy()) == 0 os.remove(fname) +def test_ndarray_legacy_load(): + data = [] + for i in range(6): + data.append(mx.nd.arange(128)) + path = os.path.dirname(os.path.realpath(__file__)) + legacy_data = mx.nd.load(os.path.join(path, 'legacy_ndarray.v0')) + assert len(data) == len(legacy_data) + for i in range(len(data)): + assert same(data[i].asnumpy(), legacy_data[i].asnumpy()) def test_ndarray_slice(): shape = (10,)
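
Usage notes (illustrative sketches, not part of the patch):

With this change TBlob no longer stores stride_ and dev_mask_ directly; device information now lives in an embedded DLTensor (dltensor_), kept in sync by SetDLTensor(), and is read back through the new dev_mask()/dev_id() accessors. A minimal C++ sketch of the resulting API, assuming only the headers modified above:

    #include <dmlc/logging.h>
    #include <mxnet/tensor_blob.h>
    #include <vector>

    void InspectBlob() {
      using namespace mxnet;
      std::vector<float> buf(2 * 3, 0.0f);
      TShape shape{2, 3};
      // dev_id is now threaded through the constructor instead of being dropped.
      TBlob blob(buf.data(), shape, cpu::kDevMask, 0 /* dev_id */);

      CHECK_EQ(blob.dev_mask(), cpu::kDevMask);        // backed by dltensor_.ctx.device_type
      CHECK_EQ(blob.dev_id(), 0);                      // backed by dltensor_.ctx.device_id
      const DLTensor& dlt = blob.dltensor();           // zero-copy DLPack view of the same memory
      CHECK_EQ(dlt.ndim, 2);
      CHECK_EQ(static_cast<int>(dlt.dtype.bits), 32);  // float32 maps to {kDLFloat, 32, 1} via DTypeTransform
    }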
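
NDArray::data() now returns a const TBlob& into the mutable tblob_ cache, so its dptr is refreshed on every call (allocation may happen lazily via CheckAndAlloc()). Existing callers that copied the result keep working; a small sketch, assuming an already-constructed NDArray:

    #include <mxnet/ndarray.h>

    void DescribeArray(const mxnet::NDArray& arr) {
      // Copying the returned blob (the old calling convention) is still valid.
      mxnet::TBlob blob = arr.data();
      const DLTensor& dlt = blob.dltensor();   // DLPack view of the same memory, no extra copy
      LOG(INFO) << "ndim=" << dlt.ndim
                << " dev_mask=" << blob.dev_mask()
                << " dev_id=" << blob.dev_id();
    }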
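
The recurring nnvm::ShapeTypeCast calls exist because TShape dimensions are now 64-bit while the C API, the cuDNN descriptor setters (see CastTShapeToIntPtr/InitBufferForParam), and the custom/native-op callbacks still expect 32-bit arrays; instead of reinterpret_cast-ing the raw shape data, each call site copies through a correctly typed buffer. A condensed sketch of the idiom, assuming nnvm/tuple.h provides ShapeTypeCast as used above:

    #include <nnvm/tuple.h>
    #include <cstdint>
    #include <vector>

    // Element-wise cast of an int64-based TShape into a uint32 buffer, mirroring
    // the arg_shape_buffer / out_shapes_buffer handling in the C API changes.
    std::vector<uint32_t> ToUint32Shape(const nnvm::TShape& s) {
      std::vector<uint32_t> buffer(s.ndim());
      nnvm::ShapeTypeCast(s.begin(), s.end(), buffer.data());
      return buffer;
    }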
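
On serialization: NDArray::Save now writes NDARRAY_V1_MAGIC (0xF993fac8) ahead of the int64-based shape, and LegacyTShapeLoad keeps pre-patch files readable because a v0 file begins with the first shape's ndim, a small value that cannot realistically collide with the magic; test_ndarray_legacy_load exercises that path against the checked-in legacy_ndarray.v0 blob. A hypothetical round trip through the new format (the file name is illustrative only):

    #include <dmlc/io.h>
    #include <mxnet/ndarray.h>
    #include <memory>

    void RoundTrip(const mxnet::NDArray& arr) {
      {
        std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create("tmp.nd", "w"));
        arr.Save(fo.get());            // writes NDARRAY_V1_MAGIC, then the shape and data
      }
      mxnet::NDArray loaded;
      std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create("tmp.nd", "r"));
      CHECK(loaded.Load(fi.get()));    // Load also accepts files written before this patch
    }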