Skip to content

Commit

Permalink
[PyOV] Missed API for Node (openvinotoolkit#24427)
Browse files Browse the repository at this point in the history
### Details:
 - Add missed methods for Node:
"get_input_element_type",
"get_input_partial_shape",
"get_input_shape",
"set_output_type",
"set_output_size",
"validate_and_infer_types"

### Tickets:
 - CVS-141050

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
  • Loading branch information
akuporos and mlukasze authored May 15, 2024
1 parent 026ac9e commit 460604c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 35 deletions.
98 changes: 85 additions & 13 deletions src/bindings/python/src/pyopenvino/graph/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "pyopenvino/graph/any.hpp"
#include "pyopenvino/graph/node.hpp"
#include "pyopenvino/graph/rt_map.hpp"
#include "pyopenvino/utils/utils.hpp"

class PyNode : public ov::Node {
public:
Expand Down Expand Up @@ -178,7 +179,7 @@ void regclass_graph_Node(py::module m) {
:param index: Index of Input.
:type index: int
:return: Tensor of the input i
:return: Tensor of the input index
:rtype: openvino._pyopenvino.DescriptorTensor
)");
node.def("get_element_type",
Expand Down Expand Up @@ -216,6 +217,63 @@ void regclass_graph_Node(py::module m) {
:return: Number of inputs.
:rtype: int
)");
node.def("get_input_element_type",
&ov::Node::get_input_element_type,
py::arg("index"),
R"(
Returns the element type for input index
:param index: Index of the input.
:type index: int
:return: Type of the input index
:rtype: openvino.Type
)");
node.def("get_input_partial_shape",
&ov::Node::get_input_partial_shape,
py::arg("index"),
R"(
Returns the partial shape for input index
:param index: Index of the input.
:type index: int
:return: PartialShape of the input index
:rtype: openvino.PartialShape
)");
node.def("get_input_shape",
&ov::Node::get_input_shape,
py::arg("index"),
R"(
Returns the shape for input index
:param index: Index of the input.
:type index: int
:return: Shape of the input index
:rtype: openvino.Shape
)");
node.def("set_output_type",
&ov::Node::set_output_type,
py::arg("index"),
py::arg("element_type"),
py::arg("shape"),
R"(
Sets output's element type and shape.
:param index: Index of the output.
:type index: int
:param element_type: Element type of the output.
:type element_type: openvino.Type
:param shape: Shape of the output.
:type shape: openvino.PartialShape
)");
node.def("set_output_size",
&ov::Node::set_output_size,
py::arg("size"),
R"(
Sets the number of outputs
:param size: number of outputs.
:type size: int
)");
node.def("get_output_size",
&ov::Node::get_output_size,
R"(
Expand All @@ -228,45 +286,45 @@ void regclass_graph_Node(py::module m) {
&ov::Node::get_output_element_type,
py::arg("index"),
R"(
Returns the element type for output i
Returns the element type for output index
:param index: Index of the output.
:type index: int
:return: Type of the output i
:return: Type of the output index
:rtype: openvino.runtime.Type
)");
node.def("get_output_shape",
&ov::Node::get_output_shape,
py::arg("index"),
R"(
Returns the shape for output i
Returns the shape for output index
:param index: Index of the output.
:return: Shape of the output i
:type index: int
:return: Shape of the output index
:rtype: openvino.runtime.Shape
)");
node.def("get_output_partial_shape",
&ov::Node::get_output_partial_shape,
py::arg("index"),
R"(
Returns the partial shape for output i
Returns the partial shape for output index
:param index: Index of the output.
:type index: int
:return: PartialShape of the output i
:return: PartialShape of the output index
:rtype: openvino.runtime.PartialShape
)");
node.def("get_output_tensor",
&ov::Node::get_output_tensor,
py::arg("index"),
py::return_value_policy::reference_internal,
R"(
Returns the tensor for output i
Returns the tensor for output index
:param index: Index of the output.
:type index: int
:return: Tensor of the output i
:return: Tensor of the output index
:rtype: openvino._pyopenvino.DescriptorTensor
)");
node.def("get_type_name",
Expand Down Expand Up @@ -382,10 +440,24 @@ void regclass_graph_Node(py::module m) {
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
self->visit_attributes(dict_deserializer);
});
node.def("set_arguments", [](const std::shared_ptr<ov::Node>& self, const ov::OutputVector& arguments) {
return self->set_arguments(arguments);
});
node.def("validate", [](const std::shared_ptr<ov::Node>& self) {
Common::utils::deprecation_warning("validate",
"2024.4",
"Please use 'constructor_validate_and_infer_types' method instead.");
return self->constructor_validate_and_infer_types();
});
node.def("constructor_validate_and_infer_types", [](const std::shared_ptr<ov::Node>& self) {
return self->constructor_validate_and_infer_types();
});
node.def(
"validate_and_infer_types",
[](const std::shared_ptr<ov::Node>& self) {
return self->validate_and_infer_types();
},
R"(
Verifies that attributes and inputs are consistent and computes output shapes and element types.
Must be implemented by concrete child classes so that it can be run any number of times.
Throws if the node is invalid.
)");
}
27 changes: 7 additions & 20 deletions src/bindings/python/tests/test_graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,12 @@ def test_result():

def test_node_friendly_name():
dummy_node = ops.parameter(shape=[1], name="dummy_name")

assert (dummy_node.friendly_name == "dummy_name")

dummy_node.set_friendly_name("changed_name")

assert (dummy_node.get_friendly_name() == "changed_name")

dummy_node.friendly_name = "new_name"

assert (dummy_node.get_friendly_name() == "new_name")


Expand Down Expand Up @@ -393,30 +390,20 @@ def test_node_output():
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]


def test_node_input_size():
node = ops.add([1], [2])
assert node.get_input_size() == 2


def test_node_input_values():
shapes = [Shape([3]), Shape([3])]
data1 = np.array([1, 2, 3])
data2 = np.array([3, 2, 1])
data1 = np.array([1, 2, 3], dtype=np.int64)
data2 = np.array([3, 2, 1], dtype=np.int64)

node = ops.add(data1, data2)

assert node.get_input_size() == 2
assert node.get_input_element_type(0) == Type.i64
assert node.get_input_partial_shape(0) == PartialShape([3])
assert node.get_input_shape(1) == Shape([3])

assert np.equal(
[input_node.get_shape() for input_node in node.input_values()],
shapes,
).all()

assert np.equal(
[node.input_value(i).get_shape() for i in range(node.get_input_size())],
shapes,
).all()

assert np.equal([input_node.get_shape() for input_node in node.input_values()], shapes,).all()
assert np.equal([node.input_value(i).get_shape() for i in range(node.get_input_size())], shapes,).all()
assert np.allclose(
[input_node.get_node().get_vector() for input_node in node.input_values()],
[data1, data2],
Expand Down
6 changes: 5 additions & 1 deletion src/bindings/python/tests/test_graph/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import openvino.runtime.opset8 as ov
from openvino import Model, Shape
Expand Down Expand Up @@ -96,7 +97,10 @@ def test_loop_inputs_are_nodes():
loop.set_invariant_input(y_i, param_y.output(0))
loop.set_merged_input(m_body, param_m.output(0), zo.output(0))
loop.set_special_body_ports([-1, 0])
loop.validate()
with pytest.warns(DeprecationWarning, match="validate is deprecated and will be removed in version 2024.4."):
loop.validate()

loop.constructor_validate_and_infer_types()

out0 = loop.get_iter_value(body_condition.output(0), -1)
out1 = loop.get_iter_value(zo.output(0), -1)
Expand Down
5 changes: 4 additions & 1 deletion src/bindings/python/tests/test_graph/test_node_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def test_node_factory_empty_topk_with_args_and_attrs():
node.set_attribute("axis", 1)
node.set_attribute("mode", "max")
node.set_attribute("sort", "value")
node.validate()
with pytest.warns(DeprecationWarning, match="validate is deprecated and will be removed in version 2024.4."):
node.validate()

node.constructor_validate_and_infer_types()

assert node.get_type_name() == "TopK"
assert node.get_output_size() == 2
Expand Down

0 comments on commit 460604c

Please sign in to comment.