@@ -83,25 +83,35 @@ class OpenvinoBackend final : public ::executorch::runtime::BackendInterface {
 
   auto infer_request = execution_handle->infer_request;
 
-  // Assume first argument is the input tensor
-  auto input_tensor = args[0]->toTensor();
-  ov::Shape input_shape(input_tensor.sizes().begin(), input_tensor.sizes().end());
+  size_t num_inputs = infer_request->get_compiled_model().inputs().size();
+  size_t num_outputs = infer_request->get_compiled_model().outputs().size();
 
-  // Convert input tensor to OpenVINO tensor
-  ov::element::Type ov_type = convert_to_openvino_type(input_tensor.scalar_type());
-  ov::Tensor ov_input_tensor(ov_type, input_shape, input_tensor.mutable_data_ptr());
+  // Set inputs
+  for (size_t i = 0; i < num_inputs; i++) {
+    auto input_tensor = args[i]->toTensor();
+    ov::Shape input_shape(input_tensor.sizes().begin(), input_tensor.sizes().end());
 
-  // infer_request->set_tensor("input", ov_input_tensor);
-  infer_request->set_input_tensor(0, ov_input_tensor);
+    // Convert input tensor to OpenVINO tensor
+    ov::element::Type ov_type = convert_to_openvino_type(input_tensor.scalar_type());
+    ov::Tensor ov_input_tensor(ov_type, input_shape, input_tensor.mutable_data_ptr());
 
-  // Execute the inference
-  infer_request->infer();
+    infer_request->set_input_tensor(i, ov_input_tensor);
+  }
+
+  // Set outputs
+  for (size_t i = 0; i < num_outputs; i++) {
+    auto output_tensor = args[num_inputs + i]->toTensor();
+    ov::Shape output_shape(output_tensor.sizes().begin(), output_tensor.sizes().end());
 
-  // Retrieve and copy output
-  auto output_tensor = args[1]->toTensor(); // Assume second argument is the output
-  ov::Tensor ov_output_tensor = infer_request->get_output_tensor(0); // get_tensor("output");
+    // Convert output tensor to OpenVINO tensor
+    ov::element::Type ov_type = convert_to_openvino_type(output_tensor.scalar_type());
+    ov::Tensor ov_output_tensor(ov_type, output_shape, output_tensor.mutable_data_ptr());
 
-  std::memcpy(output_tensor.mutable_data_ptr(), ov_output_tensor.data(), ov_output_tensor.get_byte_size());
+    infer_request->set_output_tensor(i, ov_output_tensor);
+  }
+
+  // Execute the inference
+  infer_request->infer();
 
   return Error::Ok;
 }
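
The key change: instead of assuming exactly one input and one output and copying the result out with std::memcpy, the new code queries the compiled model for its input and output counts, then wraps each ExecuTorch tensor in an ov::Tensor view over the same memory via set_input_tensor/set_output_tensor, so infer() reads inputs and writes outputs in place with no extra copy. The hunk calls convert_to_openvino_type() but does not define it; as a reading aid, here is a minimal sketch of what such a helper might look like, assuming a straightforward ScalarType-to-ov::element mapping (the PR's actual helper may cover a different set of types):

  #include <openvino/openvino.hpp>
  #include <executorch/runtime/core/exec_aten/exec_aten.h>

  // Hypothetical sketch of convert_to_openvino_type(); the mapping below
  // is an assumption, not the PR's actual implementation.
  static ov::element::Type convert_to_openvino_type(
      executorch::aten::ScalarType scalar_type) {
    switch (scalar_type) {
      case executorch::aten::ScalarType::Float:
        return ov::element::f32; // 32-bit float
      case executorch::aten::ScalarType::Int:
        return ov::element::i32; // 32-bit signed int
      case executorch::aten::ScalarType::Long:
        return ov::element::i64; // 64-bit signed int
      case executorch::aten::ScalarType::Char:
        return ov::element::i8; // 8-bit signed int
      case executorch::aten::ScalarType::Byte:
        return ov::element::u8; // 8-bit unsigned int
      case executorch::aten::ScalarType::Bool:
        return ov::element::boolean;
      default:
        return ov::element::undefined; // caller should treat as an error
    }
  }

Note that this zero-copy approach relies on the ExecuTorch output tensors staying alive and correctly sized for the duration of infer(); the ov::Tensor constructor used here does not take ownership of the underlying buffer.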