Skip to content

Commit 49ab008

Browse files
authored
Automated sync from github.com/tensorflow/tensorflow (#2280)
BUG=automated sync from upstream NO_CHECK_TFLITE_FILES=automated sync from upstream
1 parent ea4c7a6 commit 49ab008

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tensorflow/lite/kernels/internal/portable_tensor.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
1616
#define TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
1717

18+
#include <cstddef>
1819
#include <vector>
1920

2021
#include "tensorflow/lite/core/c/common.h"
@@ -50,6 +51,26 @@ class VectorOfTensors {
5051
all_shape_ptr_.push_back(&all_shape_[i]);
5152
}
5253
}
54+
55+
explicit VectorOfTensors(const std::vector<TfLiteTensor*>& tensors) {
56+
int num_tensors = tensors.size();
57+
58+
all_data_.reserve(num_tensors);
59+
all_shape_.reserve(num_tensors);
60+
all_shape_ptr_.reserve(num_tensors);
61+
62+
for (auto* t : tensors) {
63+
all_data_.push_back(GetTensorData<T>(t));
64+
all_shape_.push_back(GetTensorShape(t));
65+
}
66+
67+
// Taking the pointer from inside a std::vector is only OK if the vector is
68+
// never modified, so we populate all_shape in the previous loop and then we
69+
// are free to grab iterators here.
70+
for (int i = 0; i < num_tensors; ++i) {
71+
all_shape_ptr_.push_back(&all_shape_[i]);
72+
}
73+
}
5374
// Return a pointer to the data pointers of all tensors in the list. For
5475
// example:
5576
// float* const* f = v.data();
@@ -62,6 +83,8 @@ class VectorOfTensors {
6283
// dims[1] are the dimensions of the second tensor in the list.
6384
const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
6485

86+
size_t size() const { return all_data_.size(); }
87+
6588
private:
6689
std::vector<T*> all_data_;
6790
std::vector<RuntimeShape> all_shape_;

0 commit comments

Comments
 (0)