@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
 
+#include <cstddef>
 #include <vector>
 
 #include "tensorflow/lite/core/c/common.h"
@@ -50,6 +51,26 @@ class VectorOfTensors {
       all_shape_ptr_.push_back(&all_shape_[i]);
     }
   }
+
+  explicit VectorOfTensors(const std::vector<TfLiteTensor*>& tensors) {
+    int num_tensors = tensors.size();
+
+    all_data_.reserve(num_tensors);
+    all_shape_.reserve(num_tensors);
+    all_shape_ptr_.reserve(num_tensors);
+
+    for (auto* t : tensors) {
+      all_data_.push_back(GetTensorData<T>(t));
+      all_shape_.push_back(GetTensorShape(t));
+    }
+
+    // Taking the pointer from inside a std::vector is only OK if the vector is
+    // never modified, so we populate all_shape in the previous loop and then we
+    // are free to grab iterators here.
+    for (int i = 0; i < num_tensors; ++i) {
+      all_shape_ptr_.push_back(&all_shape_[i]);
+    }
+  }
   // Return a pointer to the data pointers of all tensors in the list. For
   // example:
   //   float* const* f = v.data();
@@ -62,6 +83,8 @@ class VectorOfTensors {
   // dims[1] are the dimensions of the second tensor in the list.
   const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
 
+  size_t size() const { return all_data_.size(); }
+
  private:
   std::vector<T*> all_data_;
   std::vector<RuntimeShape> all_shape_;
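
A minimal usage sketch (not part of the diff above) of how the new overload and accessor might be called from a kernel that has already collected its tensor pointers. `ExampleUse` is a hypothetical function written only for illustration; everything else (`VectorOfTensors`, `RuntimeShape`, `TfLiteTensor`, `data()`, `shapes()`, `size()`) comes from the header being modified.

```cpp
#include <cstddef>
#include <vector>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"

namespace tflite {

// Hypothetical caller, for illustration only.
void ExampleUse(const std::vector<TfLiteTensor*>& outputs) {
  // Previously this class had to be built from a TfLiteContext plus a
  // TfLiteIntArray of tensor indices; the new constructor accepts the
  // tensor pointers directly.
  VectorOfTensors<float> all_outputs(outputs);

  // data() and shapes() expose parallel arrays of raw data pointers and
  // shapes; the new size() accessor reports how many tensors were captured.
  float* const* data = all_outputs.data();
  const RuntimeShape* const* shapes = all_outputs.shapes();
  for (size_t i = 0; i < all_outputs.size(); ++i) {
    // e.g. hand data[i] and *shapes[i] to a reference kernel here.
    (void)data[i];
    (void)shapes[i];
  }
}

}  // namespace tflite
```

The new constructor keeps the same two-pass population as the existing one: `all_shape_` is reserved and fully filled before any addresses are taken, so the `RuntimeShape*` entries stored in `all_shape_ptr_` remain valid.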