 #include <vector>
 
 #include "paddle/fluid/operators/math/math_function.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace framework {
@@ -88,5 +91,85 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, |
   out->set_layout(expected_kernel_type.data_layout_);
 }
 
+#ifdef PADDLE_WITH_MKLDNN
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::reorder;
+
+void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
+  switch (type) {
+    case mkldnn::memory::data_type::f32:
+      return platform::to_void_cast(tensor.data<float>());
+    case mkldnn::memory::data_type::s8:
+      return platform::to_void_cast(tensor.data<char>());
+    case mkldnn::memory::data_type::u8:
+      return platform::to_void_cast(tensor.data<unsigned char>());
+    case mkldnn::memory::data_type::s16:
+      return platform::to_void_cast(tensor.data<int16_t>());
+    case mkldnn::memory::data_type::s32:
+      return platform::to_void_cast(tensor.data<int32_t>());
+    default:
+      PADDLE_THROW("wrong mkldnn type provided");
+  }
+}
+#endif
+
+void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
+                               const OpKernelType& expected_kernel_type,
+                               const Tensor& in, Tensor* out) {
+  auto in_layout = kernel_type_for_var.data_layout_;
+  auto out_layout = expected_kernel_type.data_layout_;
+
+  PADDLE_ENFORCE(
+      in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
+      "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
+      "non-MKLDNN");
+
+#ifdef PADDLE_WITH_MKLDNN
+  PADDLE_ENFORCE(in.format() != memory::format::format_undef &&
+                     in.format() != memory::format::any,
+                 "Input tensor should have specified memory format");
+
+  // Set default as NCHW in case not specified
+  out_layout =
+      out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
+
+  auto& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
+      pool.Get(expected_kernel_type.place_));
+  auto& cpu_engine = dev_ctx->GetEngine();
+
+  std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
+  std::vector<int> out_tz = in_tz;
+
+  memory::data_type in_type = ToMKLDNNDataType(in.type());
+  PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
+                 "Input tensor type is not supported: %s", in.type().name());
+  memory::data_type out_type = in_type;
+
+  memory::format in_format =
+      in_tz.size() == 2 ? memory::format::nc : in.format();
+  memory::format out_format =
+      out_tz.size() == 2 ? memory::format::nc : ToMKLDNNFormat(out_layout);
+
+  void* in_data = GetDataFromTensor(in, in_type);
+
+  // Output tensor has the same dims as the input; reorder doesn't change dims
+  out->Resize(in.dims());
+
+  auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
+
+  auto in_memory = memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
+  auto out_memory =
+      memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
+
+  platform::Reorder(in_memory, out_memory);
+
+  out->set_layout(out_layout);
+  // Reset format since the out tensor will be fed to a non-MKLDNN op kernel
+  out->set_format(memory::format::format_undef);
+#endif
+}
+
 }  // namespace framework
 }  // namespace paddle
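
For context, here is a minimal sketch (not part of this patch) of how a caller might dispatch between the new MKLDNN-aware transform and the existing TransDataLayout, based on the layouts recorded in the two kernel types. The helper name TransformLayoutIfNeeded is hypothetical, and the include path is assumed to be the header that declares these functions.

#include "paddle/fluid/framework/data_layout_transform.h"

namespace paddle {
namespace framework {

// Hypothetical dispatch helper, for illustration only.
void TransformLayoutIfNeeded(const OpKernelType& kernel_type_for_var,
                             const OpKernelType& expected_kernel_type,
                             const Tensor& in, Tensor* out) {
  auto in_layout = kernel_type_for_var.data_layout_;
  auto out_layout = expected_kernel_type.data_layout_;

  if (in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN) {
    // Leaving the MKLDNN domain: reorder the blocked MKLDNN memory back to a
    // plain layout (NCHW by default) so non-MKLDNN kernels can consume it.
    TransDataLayoutFromMKLDNN(kernel_type_for_var, expected_kernel_type, in,
                              out);
  } else if (in_layout != out_layout) {
    // Plain-to-plain layout change (e.g. NCHW <-> NHWC).
    TransDataLayout(kernel_type_for_var, expected_kernel_type, in, out);
  } else {
    // Same layout: no reorder needed, share the underlying buffer.
    out->ShareDataWith(in);
    out->set_layout(out_layout);
  }
}

}  // namespace framework
}  // namespace paddle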