Dnnl training #6045

Merged
merged 13 commits on Jan 30, 2021
Changes from 1 commit
Isolate training code and code cleanup
* Do not build if dnnl_gpu_runtime is set together with enable_training; the
  training code does not support dnnl_gpu_runtime yet.
* Isolated training code inside ifdefs so that it won't affect the
  project if built without training enabled
* Inadvertent whitespace changes were removed to make the code review simpler
* Undid some code reordering that was not needed
* Comments added to closing #endif statements to simplify reading complex ifdefs
* Modified the GetPrimitiveDesc functions to return shared_ptr instead of a raw
  pointer. This matches what was done in the Pool code and is safer memory
  handling (see the sketch below the commit details).

Signed-off-by: George Nash <george.nash@intel.com>
jeyblu authored and georgen117 committed Jan 5, 2021
commit d180030ff609eec7e23615b1fa5a3c9ae2db3686
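
A minimal sketch of the two patterns this commit describes: training-only code isolated behind `ENABLE_TRAINING` with commented `#endif` lines, and `GetPrimitiveDesc` returning a `shared_ptr` instead of a raw pointer. This is illustrative only, simplified from the DnnlRelu changes in the diff below; `ExampleKernel` and `fwd_pd_` are placeholder names, not the real class or member names.

```cpp
// Illustrative sketch only, not the actual ONNX Runtime source.
// ExampleKernel and fwd_pd_ are placeholders; the real kernels are DnnlRelu,
// DnnlConv, etc. in onnxruntime/core/providers/dnnl/subgraph/.
#include <memory>
#include "dnnl.hpp"

class ExampleKernel {
 public:
#ifdef ENABLE_TRAINING
  // Training-only accessor: the backward-pass kernel needs the forward
  // primitive descriptor, so shared ownership is handed out rather than a
  // raw pointer.
  std::shared_ptr<dnnl::eltwise_forward::primitive_desc> GetPrimitiveDesc() {
    return fwd_pd_;
  }
#endif  // ENABLE_TRAINING

 private:
#ifndef ENABLE_TRAINING
  // Inference-only builds keep sole ownership of the primitive descriptor.
  std::unique_ptr<dnnl::eltwise_forward::primitive_desc> fwd_pd_;
#else
  // Training builds share the descriptor with the backward-pass kernel.
  std::shared_ptr<dnnl::eltwise_forward::primitive_desc> fwd_pd_;
#endif  // ENABLE_TRAINING
};
```

Everything training-specific sits between the guards, so a build without `--enable_training` compiles the same code paths as before this change.
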
1 change: 0 additions & 1 deletion BUILD.md
@@ -43,7 +43,6 @@
* [ROCM](#ROCM)
* [Intel DNNL/MKL-ML](#dnnl-training)

***
# Inferencing
## Start: Baseline CPU

10 changes: 7 additions & 3 deletions cmake/external/dnnl.cmake
@@ -15,9 +15,13 @@ else()
endif()
endif()

if (onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND onnxruntime_DNNL_OPENCL_ROOT STREQUAL "")
message(FATAL_ERROR "onnxruntime_DNNL_OPENCL_ROOT required for onnxruntime_DNNL_GPU_RUNTIME")
elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl")
if(onnxruntime_USE_DNNL AND onnxruntime_ENABLE_TRAINING AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl")
message(FATAL_ERROR "--enable_training not supported with dnnl GPU runtime. Remove '--enable_training' or remove '--dnnl_gpu_runtime'.")
elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND onnxruntime_DNNL_OPENCL_ROOT STREQUAL "")
message(FATAL_ERROR "--dnnl_opencl_root required")
elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "" AND NOT (onnxruntime_DNNL_OPENCL_ROOT STREQUAL ""))
message(FATAL_ERROR "--dnnl_gpu_runtime required")
elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND NOT (onnxruntime_DNNL_OPENCL_ROOT STREQUAL ""))
file(TO_CMAKE_PATH ${onnxruntime_DNNL_OPENCL_ROOT} onnxruntime_DNNL_OPENCL_ROOT)
set(DNNL_OCL_INCLUDE_DIR ${onnxruntime_DNNL_OPENCL_ROOT}/include)
set(DNNL_GPU_CMAKE_ARGS "-DDNNL_GPU_RUNTIME=OCL " "-DOPENCLROOT=${onnxruntime_DNNL_OPENCL_ROOT}")
30 changes: 19 additions & 11 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
@@ -126,34 +126,37 @@ void DNNLExecutionProvider::CreateOrUpdateDnnlNode(const Node* node,
const auto& node_inputs = node->InputDefs();
sub_var.outputs.push_back(node->OutputDefs()[0]->Name());


if (!fused) {
ort_dnnl::DnnlNode dnnl_node;
dnnl_node.name = node->OpType();
// When running training mode the backward pass will need to access the forwardpass
// operations. Store the index of the node and the the list of input nodes.
// The input nodes can be used to find the forward pass node.
// The onnx node index is being used instead of the subgraph index because forwardpass
// and backward pass nodes are likly to span beyond the subgraph.
// When running training mode the backward pass will need to access the forwardpass
// operations. Store the index of the node and the the list of input nodes.
// The input nodes can be used to find the forward pass node.
// The onnx node index is being used instead of the subgraph index because forwardpass
// and backward pass nodes are likly to span beyond the subgraph.
#ifdef ENABLE_TRAINING
dnnl_node.onnx_index = node->Index();
for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) {
ort_dnnl::InputNode input_node;
input_node.index = (*iter).Index();
input_node.op_type = (*iter).OpType();
dnnl_node.input_nodes.push_back(input_node);
}
#endif //ENABLE_TRAINING

dnnl_node.num_inputs = static_cast<int>(node->InputDefs().size());
dnnl_node.input_start_index = static_cast<int>(sub_var.inputs.size()) - 1;
dnnl_node.node_index = static_cast<int>(subgraph_ptr->dnnl_nodes.size()) + 1;
const auto& node_outputs = node->OutputDefs();
dnnl_node.output_name = node_outputs[0]->Name();
#ifdef ENABLE_TRAINING
dnnl_node.num_outputs = static_cast<int>(node->OutputDefs().size());
if (dnnl_node.num_outputs > 1) {
for (auto n : node_outputs) {
dnnl_node.output_names.push_back(n->Name());
}
}
#endif //ENABLE_TRAINING

if (node->OpType() == "Conv") {
dnnl_node.weight_name = node->InputDefs()[1]->Name();
@@ -265,17 +268,17 @@ std::vector<std::unique_ptr<ComputeCapability>> DNNLExecutionProvider::GetCapabi
fused = true;
}
}
#endif
//TODO: Support this in training phase so that a valid entry would be added to the forward kernel map for this fusion. Without this, it would error out
//in the respective backward pass
#endif // !ENABLE_TRAINING
// TODO: Support this in training phase so that a valid entry would be added to the forward kernel map for this fusion. Without this, it would error out
// in the respective backward pass
#ifndef ENABLE_TRAINING
if (sub_var.subgraph_node_indexes.size() > 1 && node->OpType() == "Relu") {
if (subgraph_ptr->dnnl_nodes.back().name == "Conv-BatchNormalization" || subgraph_ptr->dnnl_nodes.back().name == "BatchNormalization" || subgraph_ptr->dnnl_nodes.back().name == "Conv") {
subgraph_ptr->dnnl_nodes.back().name += "-Relu";
fused = true;
}
}
#endif
#endif // !ENABLE_TRAINING

// Create Dnnl node:
// Update inputs, outputs and parent nodes
@@ -418,14 +421,19 @@ void DNNLExecutionProvider::CreateMetaDef(const GraphViewer& graph_viewer,
auto itr = std::find(sub_var.outputs_as_input_other_node.begin(),
sub_var.outputs_as_input_other_node.end(), mklnode.output_name);
if (itr == sub_var.outputs_as_input_other_node.end()) {
#ifndef ENABLE_TRAINING
meta_def->outputs().push_back(mklnode.output_name);
mklnode.output_index = static_cast<int>(meta_def->outputs().size()) - 1;
#else
if (mklnode.num_outputs == 1) {
meta_def->outputs().push_back(mklnode.output_name);
} else {
for (auto output : mklnode.output_names) {
meta_def->outputs().push_back(output);
}
}
mklnode.output_index = static_cast<int>(meta_def->outputs().size()) - 1;
mklnode.output_index = static_cast<int>(meta_def->outputs().size()) - 1;
#endif // ENABLE_TRAINING
}
}

25 changes: 17 additions & 8 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.h
@@ -13,7 +13,6 @@
#include "core/providers/dnnl/subgraph/subgraph.h"
#include "core/platform/ort_mutex.h"


namespace dnnl {
struct memory;
};
@@ -101,7 +100,7 @@ class DNNLExecutionProvider : public IExecutionProvider {
}

std::stack<std::shared_ptr<ort_dnnl::DnnlKernel>> fwd_conv_stack;
#endif // ENABLE_TRAINING
#endif // ENABLE_TRAINING
private:
// dnnl weights(filer data) memory blocks from first iteration
// saved by weights name
@@ -114,13 +113,14 @@
// Conv+BathNorm fusion bias memory buffer.
std::vector<IAllocatorUniquePtr<void>> biass_buffers_;
OrtMutex mutex_;

#ifdef ENABLE_TRAINING
// map used to hold and lookup forward DnnlKernels. This should only be needed in when
// running in training mode.The backward Kernels need access the forward kernals; typically
// to obtain the forward primitive description but it may be need for other items like
// accessing workspace memory.
std::map<onnxruntime::NodeIndex, std::shared_ptr<ort_dnnl::DnnlKernel>> fwd_kernal_map_;
#endif
#endif // ENABLE_TRAINING
// SUBGRAPH
private:
static int GetOnnxOpSet(const GraphViewer& graph_viewer) {
@@ -161,17 +161,21 @@ class DNNLExecutionProvider : public IExecutionProvider {
}
if (node->OpType().find("Pool") != std::string::npos) {
auto node_inputs = node->InputDefs();
#ifdef ENABLE_TRAINING
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() < 3) {
#else
if (node_inputs[0]->Shape() != nullptr && node_inputs[0]->Shape()->dim_size() <= 3) {
#endif // ENABLE_TRAINING
supported = false;
}
#ifdef ENABLE_TRAINING

#ifdef ENABLE_TRAINING
if (node->OutputDefs().size() > 2)
supported = false;
#else
#else
if (node->OutputDefs().size() > 1)
supported = false;
#endif

#endif // ENABLE_TRAINING
}
return supported;
}
@@ -197,9 +201,14 @@
}

private:
// supported Dnnl Operators
// supported Dnnl Operators
#ifdef ENABLE_TRAINING
std::set<std::string> dnnl_ops_ = {"Conv", "ConvGrad", "BatchNormalization", "Relu", "ReluGrad", "Sum",
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "MaxPoolGrad", "LRN"};
#else
std::set<std::string> dnnl_ops_ = {"Conv", "BatchNormalization", "Relu", "Sum",
"AveragePool", "GlobalMaxPool", "GlobalAveragePool", "MaxPool", "LRN"};
#endif // ENABLE_TRAINING

mutable std::unordered_map<std::string, std::shared_ptr<ort_dnnl::Subgraph>> mkl_subgraphs_;
};
11 changes: 8 additions & 3 deletions onnxruntime/core/providers/dnnl/subgraph/dnnl_activations.h
@@ -181,17 +181,22 @@ class DnnlRelu : public DnnlKernel {

return Status::OK();
}

dnnl::eltwise_forward::primitive_desc* GetPrimitiveDesc() {
return relu_fwd_pd_.get();
#ifdef ENABLE_TRAINING
std::shared_ptr <dnnl::eltwise_forward::primitive_desc> GetPrimitiveDesc() {
return relu_fwd_pd_;
}
#endif // ENABLE_TRAINING

private:
std::shared_ptr<dnnl::memory> src_mem_;
std::shared_ptr<dnnl::memory> src_mem_gpu_;

std::unique_ptr<dnnl::eltwise_forward::desc> fwd_desc_;
#ifndef ENABLE_TRAINING
std::unique_ptr<dnnl::eltwise_forward::primitive_desc> relu_fwd_pd_;
#else
std::shared_ptr<dnnl::eltwise_forward::primitive_desc> relu_fwd_pd_;
#endif // ENABLE_TRAINING
std::unique_ptr<dnnl::primitive> relu_fwd_;

std::unique_ptr<dnnl::memory::desc> src_md_;
25 changes: 17 additions & 8 deletions onnxruntime/core/providers/dnnl/subgraph/dnnl_conv.h
@@ -257,6 +257,7 @@ class DnnlConv : public DnnlKernel {
dnnl::memory::desc({bias_dims_mkl}, DnnnType<T>(), dnnl::memory::format_tag::any));

dnnl::memory::dims conv_zero_padding = {0, 0};

#ifdef ENABLE_TRAINING
if (!bias_dims_mkl.empty()) {
fwd_desc_ = onnxruntime::make_unique<dnnl::convolution_forward::desc>(
Expand Down Expand Up @@ -287,7 +288,7 @@ class DnnlConv : public DnnlKernel {
*filter_md_, *primitive_dst_md_, strides_mkl,
dilations_mkl, padding_left_mkl, padding_right_mkl));
}
#endif //ENABLE_TRAINING
#endif //ENABLE_TRAINING

if (fuse_relu_) {
dnnl::primitive_attr attr;
@@ -463,6 +464,7 @@
TensorShape W(xshape, xdim);

const int group_mkl = static_cast<int>(group_);
const T* filter_data = const_cast<T*>(ort.GetTensorData<T>(input_tensor));

dnnl::memory::dims filter_dims_mkl;
if (group_mkl == 1) {
@@ -473,7 +475,6 @@
filter_dims_mkl.insert(filter_dims_mkl.end(), W.GetDims().begin() + 1, W.GetDims().end());
}

const T* filter_data = const_cast<T*>(ort.GetTensorData<T>(input_tensor));
{
// lock to make sure reordering is done only once
std::lock_guard<OrtMutex> lock(provider_->GetMutex());
@@ -499,12 +500,12 @@
.execute(dnnl_engine_gpu_, src, *filter_dst_mem);
}

// Do not use cached weights if running training since weight is changed each iteration
// Do not use cached weights if running training since weight is changed each iteration
#ifndef ENABLE_TRAINING
provider_->SetWeightsMemoryBuffer(mklnode_ptr_->weight_name, filter_dst_mem);
#else
filter_dst_mem_ = filter_dst_mem;
#endif // !ENABLE_TRAINING
#endif // !ENABLE_TRAINING
}
}
}
@@ -597,10 +598,12 @@
}
return Status::OK();
}

dnnl::convolution_forward::primitive_desc* GetPrimitiveDesc() {
return conv_fwd_pd_.get();
#ifdef ENABLE_TRAINING
std::shared_ptr <dnnl::convolution_forward::primitive_desc> GetPrimitiveDesc() {
return conv_fwd_pd_;
}
#endif

private:
void ReadAttributes(const NodeAttributes& attributes,
const std::string attributes_prefix = "") override {
@@ -665,7 +668,8 @@
dnnl::memory::format_tag filter_format_;
#ifdef ENABLE_TRAINING
std::shared_ptr<dnnl::memory> filter_dst_mem_;
#endif
#endif // ENABLE_TRAINING

std::shared_ptr<dnnl::memory> src_mem_from_;
std::unique_ptr<dnnl::memory> src_mem_to_;

Expand All @@ -687,7 +691,12 @@ class DnnlConv : public DnnlKernel {
std::unique_ptr<dnnl::memory::desc> filter_md_;
std::unique_ptr<dnnl::memory::desc> bias_md_;

#ifndef ENABLE_TRAINING
std::unique_ptr<dnnl::convolution_forward::primitive_desc> conv_fwd_pd_;
#else
std::shared_ptr<dnnl::convolution_forward::primitive_desc> conv_fwd_pd_;
#endif // ENABLE_TRAINING

std::unique_ptr<dnnl::primitive> conv_fwd_;

dnnl::engine dnnl_engine_cpu_;
10 changes: 4 additions & 6 deletions onnxruntime/core/providers/dnnl/subgraph/dnnl_func_kernel.cc
@@ -62,7 +62,6 @@ class SubgraphPrimitive : public PrimitiveBase {
for (size_t i = 0; i < context_.net.size(); ++i) {
context_.net.at(i).execute(*context_.stream, context_.net_args.at(i));
}

return Status::OK();
}

@@ -78,7 +77,7 @@ class SubgraphPrimitive : public PrimitiveBase {
kernel = std::make_shared<DnnlConv<T>>(dnnl_node, params.provider, *params.attributes, os.str());
#ifdef ENABLE_TRAINING
params.provider->fwd_conv_stack.emplace(kernel);
#endif
#endif // ENABLE_TRAINING
for (auto index : dnnl_node.parent_nodes) {
kernel->parents_.push_back(context_.kernels[index]);
}
@@ -103,7 +102,7 @@
// figure out way to read the training_mode parameter from
// onnxruntime\core\framwork\run_options.h
params.provider->SetForwardKernel(dnnl_node.onnx_index, kernel);
#endif
#endif // ENABLE_TRAINING
for (auto index : dnnl_node.parent_nodes) {
kernel->parents_.push_back(context_.kernels[index]);
}
@@ -153,7 +152,7 @@
kernel = std::make_shared<DnnlPool<T>>(dnnl_node, params.provider, *params.attributes, os.str());
#ifdef ENABLE_TRAINING
params.provider->SetForwardKernel(dnnl_node.onnx_index, kernel);
#endif
#endif // ENABLE_TRAINING
for (auto index : dnnl_node.parent_nodes) {
kernel->parents_.push_back(context_.kernels[index]);
}
@@ -296,7 +295,6 @@ class SubgraphPrimitivePool : public PrimitivePool<T> {
std::string dims_str;
for (auto i = 0; i < params.subgraph->dnnl_nodes[0].num_inputs; i++) {
const OrtValue* input_tensor = ort.KernelContext_GetInput(context, i);

auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
auto tensor_shape = ort.GetTensorShape(tensor_info);
ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
@@ -342,7 +340,7 @@ Status DnnlFuncKernel<T>::Compute(const OrtCustomOpApi* api, OrtKernelContext* c
std::unique_ptr<SubgraphPrimitive<T>> primitive = onnxruntime::make_unique<SubgraphPrimitive<T>>(api, context, params_);
#else
SubgraphPrimitive<T>* primitive = SubgraphPrimitivePool<T>::Get(api, context, params_);
#endif
#endif // ENABLE_TRAINING

primitive->UpdateProvider(params_);
status = primitive->Compute(api, context);