Automatic Layout Management (apache#20718)
* Automatic Layout Management

Originally authored by Dawid Tracz <dtracz@nvidia.com>

* Fix clang-format

* Fix clang-format in mshadow

* Print layout name instead of a number

* Generalize NHWC target layout to other dimensions

* Change layout optimization API

* Add layout optimization tests

* Add backward check to tests

* Generalize tests to 1..3 spatial dims

* Add NWC layout to ConvolutionParams

* Enable layout optimization tests only with cuDNN

Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
mk-61 and Vladimir Cherepanov authored Dec 2, 2021
1 parent f60c1d2 commit 40359ce
Showing 23 changed files with 737 additions and 6 deletions.
60 changes: 60 additions & 0 deletions 3rdparty/mshadow/mshadow/base.h
@@ -496,6 +496,8 @@ const int index_type_flag = DataType<lapack_index_t>::kFlag;

/*! layout flag */
enum LayoutFlag {
kUNKNOWN = -1,

kNCHW = 0,
kNHWC,
kCHWN,
@@ -509,6 +511,64 @@ enum LayoutFlag {
kCDHWN
};

inline LayoutFlag layoutFlag(std::string layoutstr) {
switch (layoutstr.length()) {
case 4:
if (layoutstr == "NHWC")
return kNHWC;
if (layoutstr == "NCHW")
return kNCHW;
if (layoutstr == "CHWN")
return kCHWN;
return kUNKNOWN;
case 3:
if (layoutstr == "NWC")
return kNWC;
if (layoutstr == "NCW")
return kNCW;
if (layoutstr == "CWN")
return kCWN;
return kUNKNOWN;
case 5:
if (layoutstr == "NDHWC")
return kNDHWC;
if (layoutstr == "NCDHW")
return kNCDHW;
if (layoutstr == "CDHWN")
return kCDHWN;
return kUNKNOWN;
default:
return kUNKNOWN;
}
}

inline std::string toString(LayoutFlag layout) {
switch (layout) {
case kUNKNOWN:
return "";
case kNCHW:
return "NCHW";
case kNHWC:
return "NHWC";
case kCHWN:
return "CHWN";
case kNCW:
return "NCW";
case kNWC:
return "NWC";
case kCWN:
return "CWN";
case kNCDHW:
return "NCDHW";
case kNDHWC:
return "NDHWC";
case kCDHWN:
return "CDHWN";
default:
return "";
}
}
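For illustration, a minimal usage sketch of the two helpers above (not part of the diff), assuming they sit in the mshadow namespace next to LayoutFlag and that mshadow/base.h is on the include path:

// sketch.cc -- round-trip a layout name through the new helpers
#include <iostream>
#include <mshadow/base.h>  // assumed include path

int main() {
  mshadow::LayoutFlag flag = mshadow::layoutFlag("NHWC");          // -> kNHWC
  std::cout << mshadow::toString(flag) << std::endl;               // prints "NHWC"
  std::cout << (mshadow::layoutFlag("HWCN") == mshadow::kUNKNOWN)  // unrecognized string
            << std::endl;                                          // prints 1
  return 0;
}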

template<int layout>
struct LayoutType;

91 changes: 91 additions & 0 deletions 3rdparty/mshadow/mshadow/tensor.h
@@ -390,6 +390,97 @@ inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout
return dst2;
}

/*!
 * \brief Returns the axes of the transpose operation that converts
 *        the source layout into the destination layout.
 * \param src_layout input layout
 * \param dst_layout output layout
 * \return vector of the requested index type describing the transpose axes
 */
template <typename dim_t>
inline std::vector<dim_t> getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout) {
auto apply = [](const std::vector<dim_t>& v, const std::vector<dim_t>& op) {
CHECK_EQ(v.size(), op.size()) << "Layout ndims does not match";
std::vector<dim_t> ret(v.size());
for (size_t i = 0; i < v.size(); i++) {
ret[i] = v[op[i]];
}
return ret;
};
std::vector<dim_t> axes;
  // first, the axes that transpose the source layout (each `case`) to channels-last (NWC / NHWC / NDHWC)
switch (src_layout) {
case kUNKNOWN:
LOG(FATAL) << "Unknown source layout";
break;
case kNHWC:
axes = std::vector<dim_t>({0, 1, 2, 3});
break;
case kNCHW:
axes = std::vector<dim_t>({0, 2, 3, 1});
break;
case kCHWN:
axes = std::vector<dim_t>({3, 1, 2, 0});
break;
case kNWC:
axes = std::vector<dim_t>({0, 1, 2});
break;
case kNCW:
axes = std::vector<dim_t>({0, 2, 1});
break;
case kCWN:
axes = std::vector<dim_t>({2, 1, 0});
break;
case kNDHWC:
axes = std::vector<dim_t>({0, 1, 2, 3, 4});
break;
case kNCDHW:
axes = std::vector<dim_t>({0, 2, 3, 4, 1});
break;
case kCDHWN:
axes = std::vector<dim_t>({4, 1, 2, 3, 0});
break;
default:
LOG(FATAL) << "Invalid source layout " << src_layout;
}
  // then compose with the axes that transpose channels-last back to the destination layout (each `case`)
switch (dst_layout) {
case kUNKNOWN:
LOG(FATAL) << "Unknown destination layout";
break;
case kNHWC:
axes = apply(axes, {0, 1, 2, 3});
break;
case kNCHW:
axes = apply(axes, {0, 3, 1, 2});
break;
case kCHWN:
axes = apply(axes, {3, 1, 2, 0});
break;
case kNWC:
axes = apply(axes, {0, 1, 2});
break;
case kNCW:
axes = apply(axes, {0, 2, 1});
break;
case kCWN:
axes = apply(axes, {2, 1, 0});
break;
case kNDHWC:
axes = apply(axes, {0, 1, 2, 3, 4});
break;
case kNCDHW:
axes = apply(axes, {0, 4, 1, 2, 3});
break;
case kCDHWN:
axes = apply(axes, {4, 1, 2, 3, 0});
break;
default:
LOG(FATAL) << "Invalid destination layout " << src_layout;
}
return axes;
}
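A small sketch (not part of the diff) of what getTranspAxes produces; the include path and the mshadow namespace placement are assumptions:

// sketch.cc -- axes for converting NCHW data to NHWC
#include <cstdint>
#include <iostream>
#include <vector>
#include <mshadow/tensor.h>  // assumed include path

int main() {
  // axes[i] = index of the NCHW axis that supplies axis i of the NHWC tensor
  std::vector<int64_t> axes =
      mshadow::getTranspAxes<int64_t>(mshadow::kNCHW, mshadow::kNHWC);
  for (int64_t a : axes)
    std::cout << a << ' ';  // prints: 0 2 3 1
  std::cout << std::endl;
  return 0;
}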

/*!
 * \brief computation stream structure, used for asynchronous computations
*/
10 changes: 10 additions & 0 deletions include/mxnet/c_api.h
@@ -3161,6 +3161,16 @@ MXNET_DLL int MXCUDAProfilerStart();
*/
MXNET_DLL int MXCUDAProfilerStop();

/*!
* \brief Turns on or off Layout Optimization
*/
MXNET_DLL int MXSetOptimizeLayout(bool val);

/*!
* \brief Get current Layout Optimization status
*/
MXNET_DLL int MXGetOptimizeLayout(bool* val);
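A hedged sketch (not part of the diff) of how a frontend might call the two new entry points; the MXGetLastError error-handling pattern is the usual MXNet C API convention and is assumed to apply here as well:

// sketch.cc -- toggle and query layout optimization through the C API
#include <mxnet/c_api.h>
#include <cstdio>

int main() {
  if (MXSetOptimizeLayout(true) != 0)            // enable layout optimization
    std::printf("error: %s\n", MXGetLastError());
  bool enabled = false;
  if (MXGetOptimizeLayout(&enabled) == 0)        // read the flag back
    std::printf("layout optimization: %s\n", enabled ? "on" : "off");
  return 0;
}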

#ifdef __cplusplus
}
#endif // __cplusplus
8 changes: 6 additions & 2 deletions python/mxnet/amp/amp.py
@@ -307,7 +307,7 @@ def warn_if_model_exists():
return

def init(target_dtype='float16', target_precision_ops=None,
conditional_fp32_ops=None, fp32_ops=None):
conditional_fp32_ops=None, fp32_ops=None, layout_optimization=False):
"""Initialize AMP (automatic mixed precision).
This needs to be done before model creation.
@@ -333,7 +333,11 @@ def init(target_dtype='float16', target_precision_ops=None,
assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
"AMP currently supports only float16 or bfloat16 as a target_dtype"
_amp_initialized = True
logging.info("Using AMP")
log_msg = "Using AMP"
if layout_optimization:
log_msg += "\n - layout optimization: enabled"
check_call(_LIB.MXSetOptimizeLayout(ctypes.c_bool(True)))
logging.info(log_msg)
if target_dtype == "bfloat16":
target_dtype = bfloat16
else:
13 changes: 13 additions & 0 deletions src/c_api/c_api.cc
@@ -55,6 +55,7 @@
#include "../operator/tvmop/op_module.h"
#include "../operator/subgraph/partitioner/custom_subgraph_property.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../common/alm.h"
#include "../common/utils.h"
#include "../profiler/profiler.h"
#include "../serialization/cnpy.h"
@@ -4004,3 +4005,15 @@ int MXCUDAProfilerStop() {
#endif
API_END();
}

int MXSetOptimizeLayout(bool val) {
API_BEGIN();
mxnet::alm::ALMParams::get().optimize = val;
API_END();
}

int MXGetOptimizeLayout(bool* val) {
API_BEGIN();
*val = mxnet::alm::ALMParams::get().optimize;
API_END();
}
