Initial PR for VSINPU execution provider #20903

Merged
merged 3 commits on Jun 29, 2024
Changes from 1 commit
Add gather/squeeze/unsqueeze/tile etc ops
chenfeiyue-cfy committed Jun 18, 2024
commit 38592651f952aa51946230fb724835c00241a548
44 changes: 22 additions & 22 deletions include/onnxruntime/core/providers/vsinpu/vsinpu_provider_factory.h
@@ -1,26 +1,26 @@
/****************************************************************************
- *
- * Copyright (c) 2023 Vivante Corporation
- *
- * Permission is hereby granted, free of charge, to any person obtaining a
- * copy of this software and associated documentation files (the "Software"),
- * to deal in the Software without restriction, including without limitation
- * the rights to use, copy, modify, merge, publish, distribute, sublicense,
- * and/or sell copies of the Software, and to permit persons to whom the
- * Software is furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
- * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
- * DEALINGS IN THE SOFTWARE.
- *
- *****************************************************************************/
+ *
+ * Copyright (c) 2023 Vivante Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ *
+ *****************************************************************************/
#include "onnxruntime_c_api.h"

#ifdef __cplusplus
8 changes: 7 additions & 1 deletion onnxruntime/core/framework/node_unit.cc
@@ -285,7 +285,7 @@ void NodeUnit::InitForSingleNode() {
const auto& output_defs = target_node_.OutputDefs();
const auto& node_attrs = target_node_.GetAttributes();
auto qlinear_type = GetQLinearOpType(target_node_);
- if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
+ if (qlinear_type == QLinearOpType::Unknown) {
// Not a Qlinear op, add all inputs / outputs
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
@@ -351,6 +351,12 @@ void NodeUnit::InitForSingleNode() {
NodeUnitIODef::QuantParam{*input_defs[1],
input_defs.size() == 3 ? input_defs[2] : nullptr,
axis}});
+ } else if (IsVariadicQLinearOp(qlinear_type)) {
+   size_t input_num = (input_defs.size() - 2) / 3;
+   for (size_t i = 0; i < input_num; i++) {
+     inputs_.push_back(NodeUnitIODef{*input_defs[3 * i + 2], NodeUnitIODef::QuantParam{*input_defs[3 * i + 3], input_defs[3 * i + 4]}});
+   }
+   outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[0], input_defs[1]}});
} else {
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
}
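Context for the variadic branch added above: variadic QLinear ops (QLinearConcat being the canonical example) order their input defs as two leading output-quantization entries followed by one (tensor, scale, zero_point) triple per data input, which is exactly what the 3 * i + 2/3/4 indexing unpacks. A minimal sketch of that layout arithmetic (the helper name is illustrative, not part of this PR):

#include <cassert>
#include <cstddef>

// Assumed def layout, per the QLinearConcat schema:
//   defs[0] = y_scale, defs[1] = y_zero_point,
//   defs[3*i + 2] = tensor_i, defs[3*i + 3] = scale_i, defs[3*i + 4] = zero_point_i
size_t NumVariadicQLinearInputs(size_t total_defs) {
  assert(total_defs >= 5 && (total_defs - 2) % 3 == 0);  // at least one full triple
  return (total_defs - 2) / 3;
}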
@@ -91,8 +91,8 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial
return false;
if (!has_initialized_quant_param(*input.quant_param->zero_point, initializers))
return false;
- if (input.quant_param->zero_point->Type() != input.node_arg.Type()){
-   LOGS_DEFAULT(ERROR)<<"Invalid input type because the data type mismatch with its' quant param type.";
+ if (input.quant_param->zero_point->Type() != input.node_arg.Type()) {
+   LOGS_DEFAULT(ERROR) << "Invalid input type because the data type mismatch with its' quant param type.";
return false;
}
}
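For reference, the check above enforces the ONNX Q/DQ contract that a zero-point tensor carries the same element type as the data it quantizes, so mismatched models are rejected before graph building. A minimal sketch of the comparison (hypothetical free function, assuming NodeArg::Type() returns a pointer to the ONNX type string such as "tensor(uint8)"):

#include <string>
#include "core/graph/node_arg.h"

// Returns true when the quantized data and its zero point report the same
// ONNX type string; Type() may be null for args without type information.
bool ZeroPointTypeMatches(const onnxruntime::NodeArg& data,
                          const onnxruntime::NodeArg& zero_point) {
  return data.Type() != nullptr && zero_point.Type() != nullptr &&
         *data.Type() == *zero_point.Type();
}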
@@ -28,9 +28,9 @@ namespace onnxruntime {
namespace vsi {
namespace npu {
class ClipOpBuilder final : public BaseOpBuilder {
- bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
+ bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
const Node* node) const override {
-   if (node->SinceVersion() > 6) {
+   if (node->SinceVersion() > 6) {
if (node->InputDefs().size() > 1 && !Contains(graph_viewer.GetAllInitializedTensors(), node->InputDefs()[1]->Name())) {
LOGS_DEFAULT(WARNING) << "Min/Max value must be const input or attribute.";
return false;
@@ -36,7 +36,6 @@ class DequantizeLinearOpBuilder : public BaseOpBuilder {
};
bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers,
const NodeUnit& node_unit) const override {
-
auto input_type = node_unit.Inputs()[0].node_arg.Type();
if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) {
LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : "
@@ -36,7 +36,7 @@ class BatchNormOpBuilder : public BaseOpBuilder {
mean_tensor = 3,
var_tensor = 4
};
- int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override{ return 9; }
+ int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 9; }

bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
const Node* node) const override {
24 changes: 12 additions & 12 deletions onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h
@@ -38,40 +38,40 @@ class ReduceMeanOpBuilder : public BaseOpBuilder {
return true;
}
bool HandleBuildOp(vsi::npu::GraphEP* graph_ep,
- std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
- std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
- const NodeUnit& node_unit) override {
+ std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
+ std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
+ const NodeUnit& node_unit) override {
LOGS_DEFAULT(INFO) << "Creating ReduceMean Op.";

NodeAttrHelper helper(node_unit.GetNode());
std::vector<int64_t> def_axes;
auto input_shape_size = inputs[0]->GetShape().size();

if (node_unit.SinceVersion() < 18 && helper.HasAttr("axes")) {
-   def_axes = helper.Get("axes", def_axes);
+   def_axes = helper.Get("axes", def_axes);
} else if (inputs.size() > 1) {
-   def_axes.resize(inputs[1]->GetSpec().GetElementNum());
-   inputs[1]->CopyDataFromTensor(def_axes.data());
+   def_axes.resize(inputs[1]->GetSpec().GetElementNum());
+   inputs[1]->CopyDataFromTensor(def_axes.data());
} else {
-   for (int64_t i = 0; i < input_shape_size; ++i) {
-     def_axes.push_back(i);
-   }
+   for (int64_t i = 0; i < input_shape_size; ++i) {
+     def_axes.push_back(i);
+   }
}

std::vector<int32_t> axes(def_axes.begin(), def_axes.end());
axes = util::ReverseAxis(axes, input_shape_size);

if (helper.HasAttr("noop_with_empty_axes") && inputs.size() == 1 && helper.Get("noop_with_empty_axes", 0) == 1) {
-   outputs[0] = inputs[0];
-   return true;
+   outputs[0] = inputs[0];
+   return true;
}

bool keepdims = helper.Get("keepdims", 1) == 1;
auto op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::ReduceMean>(axes, keepdims);
(*op).BindInput(inputs[0]).BindOutputs(outputs);
graph_ep->GetOps().push_back(std::move(op));
return true;
- }
+ }
};
} // namespace npu

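A note on the util::ReverseAxis call above: TIM-VX orders tensor dimensions innermost-first, the reverse of ONNX, so every ONNX reduce axis has to be flipped before the op is bound. A sketch of the mapping the helper presumably performs (assumed behavior; the actual implementation lives in vsinpu_util):

#include <cstdint>
#include <vector>

// Map ONNX axes to TIM-VX order for a rank-`rank` tensor: normalize negative
// axes first, then flip so ONNX axis a becomes rank - 1 - a.
std::vector<int32_t> ReverseAxisSketch(std::vector<int32_t> axes, size_t rank) {
  for (auto& a : axes) {
    if (a < 0) a += static_cast<int32_t>(rank);
    a = static_cast<int32_t>(rank) - 1 - a;
  }
  return axes;
}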
@@ -66,8 +66,8 @@ class ResizeOpBuilder : public BaseOpBuilder {
return false;
}
auto& cooridinate = helper.Get("coordinate_transoformation_mode", "half_pixel");
- if (cooridinate != "align_corners" && cooridinate != "half_pixel" && cooridinate != "half_pixel_symmetric") {
-   LOGS_DEFAULT(WARNING) << "Only support half_pixel_symmetric and align_corners attributes now.";
+ if (cooridinate != "align_corners" && cooridinate != "half_pixel") {
+   LOGS_DEFAULT(WARNING) << "Only support half_pixel and align_corners attributes now.";
return false;
}
if (helper.Get("keep_aspect_ratio_policy", "stretch") != "stretch") {
@@ -100,7 +100,7 @@

auto resize_type = onnx_mode == "nearest" ? tim::vx::ResizeType::NEAREST_NEIGHBOR : tim::vx::ResizeType::BILINEAR;
bool align_corners = coordinate_transformation == "align_corners";
- bool half_pixel_center = coordinate_transformation == "half_pixel_symmetric";
+ bool half_pixel_center = coordinate_transformation == "half_pixel";
std::shared_ptr<tim::vx::Operation> op = nullptr;
if (is_1dresize) {
int target_size;
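For the two coordinate_transformation_mode values the builder now accepts, the ONNX Resize spec defines the source-coordinate mapping below (reference math only, not the EP's kernel):

// half_pixel:    x_original = (x_resized + 0.5) / scale - 0.5
// align_corners: x_original = x_resized * (length_original - 1) / (length_resized - 1)
float HalfPixelSource(float x_resized, float scale) {
  return (x_resized + 0.5f) / scale - 0.5f;
}

float AlignCornersSource(float x_resized, float length_original, float length_resized) {
  return length_resized > 1.0f
             ? x_resized * (length_original - 1.0f) / (length_resized - 1.0f)
             : 0.0f;
}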
@@ -103,7 +103,7 @@ static const std::map<std::string, createIOpBuildItemFunc> reg = {
REGISTER_OP_BUILDER("Tile", TileOpBuilder),
REGISTER_OP_BUILDER("Squeeze", SqueezeOpBuilder),
REGISTER_OP_BUILDER("Unsqueeze", UnsqueezeOpBuilder),
- REGISTER_OP_BUILDER("Resize",ResizeOpBuilder),
+ REGISTER_OP_BUILDER("Resize", ResizeOpBuilder),
REGISTER_OP_BUILDER("Cast", CastOpBuilder),

#undef REGISTER_OP_BUILDER
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/vsinpu/patches/AccuracyCorrection.patch
@@ -0,0 +1,26 @@
diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc
index 47c18c478d..93b44501cd 100644
--- a/onnxruntime/test/providers/checkers.cc
+++ b/onnxruntime/test/providers/checkers.cc
@@ -195,7 +195,7 @@ struct TensorCheck<uint8_t> {
// For any other EPs, we still expect an exact match for the results
// TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513
if ((provider_type == kNnapiExecutionProvider || provider_type == kDmlExecutionProvider ||
- provider_type == kXnnpackExecutionProvider) &&
+ provider_type == kXnnpackExecutionProvider || provider_type == kVSINPUExecutionProvider) &&
(has_abs_err || has_rel_err)) {
double threshold = has_abs_err ? *(params.absolute_error)
: 0.0;
diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc
index 2bc0df5e36..7beb78c2ff 100644
--- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc
@@ -498,7 +498,7 @@ class QLinearConvOpTester {
// NOTE, for now the tolerance will only apply if the NNAPI is actually used,
// if for any reason the execution falls back to CPU, we still expect an exact match
// See, 'void Check<uint8_t>(...' in onnxruntime/test/providers/provider_test_utils.cc
-#if defined(USE_NNAPI) || defined(USE_DML)
+#if defined(USE_NNAPI) || defined(USE_DML) || defined(USE_VSINPU)
// TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513
abs_error = 1.0f;
#endif
22 changes: 0 additions & 22 deletions onnxruntime/core/providers/vsinpu/patches/int8_checker_hack.patch

This file was deleted.

@@ -1,15 +1,19 @@
import sys

import numpy as np
from numpy.linalg import norm


def read_values(filename):
-     with open(filename, 'r') as file:
+     with open(filename) as file:
values = np.array([float(line.strip()) for line in file])
return values


def cosine_similarity(vec1, vec2):
return np.dot(vec1, vec2) / (norm(vec1) * norm(vec2))


if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: python cosine_similarity.py <file1> <file2>")
@@ -1,13 +1,16 @@
import sys


def read_values(filename):
-     with open(filename, 'r') as file:
+     with open(filename) as file:
values = [(float(line.strip()), i + 1) for i, line in enumerate(file)]
return values


def top_n(values, N):
return sorted(values, key=lambda x: x[0], reverse=True)[:N]


def compare_files(cpu_file, npu_file, N):
cpu_values = read_values(cpu_file)
npu_values = read_values(npu_file)
@@ -18,6 +21,7 @@ def compare_files(cpu_file, npu_file, N):
print(f"Top-{N} values in {cpu_file}: {cpu_topn}")
print(f"Top-{N} values in {npu_file}: {npu_topn}")


if __name__ == "__main__":
if len(sys.argv) != 4:
print("Usage: python compare_topn.py <N> <cpu_file> <npu_file>")
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc
@@ -22,9 +22,10 @@
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
- #include "vsinpu_ep_graph.h"
- #include "builders/op_builder_factory.h"
- #include "vsinpu_util.h"
+ #include <algorithm>
+ #include "core/providers/vsinpu/vsinpu_ep_graph.h"
+ #include "core/providers/vsinpu/builders/op_builder_factory.h"
+ #include "core/providers/vsinpu/vsinpu_util.h"
#include "core/framework/node_unit.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
@@ -66,7 +67,6 @@ bool GraphEP::Prepare() {
add_quantized_input(*node_unit, 2);
}
}
-
} // All quantized inputs is recorded
return true;
}
14 changes: 8 additions & 6 deletions onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h
@@ -25,7 +25,9 @@
#pragma once
#include <map>
#include <vector>
-
+ #include <string>
+ #include <memory>
+ #include <unordered_map>
#include "builders/op_builder.h"
#include "tim/vx/context.h"
#include "tim/vx/graph.h"
@@ -49,8 +51,8 @@ struct NodeIOInfo {

class GraphEP {
public:
- GraphEP(const GraphViewer& graph_viewer);
- ~GraphEP(){};
+ explicit GraphEP(const GraphViewer& graph_viewer);
+ ~GraphEP() {}

bool Prepare();

@@ -65,18 +67,18 @@

bool& GetCompiled() { return compiled_; }
std::shared_ptr<tim::vx::Graph>& GetGraph() { return graph_; }
- std::vector<std::shared_ptr<tim::vx::Operation>>& GetOps() { return ops_;}
+ std::vector<std::shared_ptr<tim::vx::Operation>>& GetOps() { return ops_; }
std::map<std::string, std::shared_ptr<tim::vx::Tensor>>& GetTensors() {
return tensors_;
}

std::vector<std::shared_ptr<GraphIOInfo>>& GetGraphInputs() {
return graph_inputs_;
- };
+ }

std::vector<std::shared_ptr<GraphIOInfo>>& GetGraphOutputs() {
return graph_outputs_;
- };
+ }

void UpdateTensorMap(const std::string& name, const std::shared_ptr<tim::vx::Tensor>& dst_tensor);
