Skip to content

Commit 1646461

Browse files
committed
Introduce NumericTensor class
This commit defines the new NumericTensor<T> class as a subclass of Tensor class. NumericTensor<T> extends Tensor class by adding a member function to access element values in a tensor.
1 parent ada04ca commit 1646461

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

cpp/src/arrow/tensor-test.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,56 @@ TEST(TestTensor, ZeroDimensionalTensor) {
104104
ASSERT_EQ(t.strides().size(), 1);
105105
}
106106

107+
TEST(TestNumericTensor, ElementAccess) {
108+
std::vector<int64_t> shape = {3, 4};
109+
110+
std::vector<int64_t> values_i64 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
111+
std::shared_ptr<Buffer> buffer_i64(Buffer::Wrap(values_i64));
112+
NumericTensor<Int64Type> t_i64(buffer_i64, shape);
113+
114+
ASSERT_EQ(1, t_i64.Value({0, 0}));
115+
ASSERT_EQ(5, t_i64.Value({1, 0}));
116+
ASSERT_EQ(6, t_i64.Value({1, 1}));
117+
ASSERT_EQ(11, t_i64.Value({2, 2}));
118+
119+
std::vector<float> values_f32 = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f,
120+
7.1f, 8.1f, 9.1f, 10.1f, 11.1f, 12.1f};
121+
std::shared_ptr<Buffer> buffer_f32(Buffer::Wrap(values_f32));
122+
NumericTensor<FloatType> t_f32(buffer_f32, shape);
123+
124+
ASSERT_EQ(1.1f, t_f32.Value({0, 0}));
125+
ASSERT_EQ(5.1f, t_f32.Value({1, 0}));
126+
ASSERT_EQ(6.1f, t_f32.Value({1, 1}));
127+
ASSERT_EQ(11.1f, t_f32.Value({2, 2}));
128+
}
129+
130+
TEST(TestNumericTensor, ElementAccessWithStrides) {
131+
std::vector<int64_t> shape = {3, 4};
132+
133+
const int64_t i64_size = sizeof(int64_t);
134+
std::vector<int64_t> values_i64 = {1, 2, 3, 4, 0, 0, 5, 6, 7,
135+
8, 0, 0, 9, 10, 11, 12, 0, 0};
136+
std::vector<int64_t> strides_i64 = {i64_size * 6, i64_size};
137+
std::shared_ptr<Buffer> buffer_i64(Buffer::Wrap(values_i64));
138+
NumericTensor<Int64Type> t_i64(buffer_i64, shape, strides_i64);
139+
140+
ASSERT_EQ(1, t_i64.Value({0, 0}));
141+
ASSERT_EQ(5, t_i64.Value({1, 0}));
142+
ASSERT_EQ(6, t_i64.Value({1, 1}));
143+
ASSERT_EQ(11, t_i64.Value({2, 2}));
144+
145+
const int64_t f32_size = sizeof(float);
146+
std::vector<float> values_f32 = {1.1f, 2.1f, 3.1f, 4.1f, 0.0f, 0.0f,
147+
5.1f, 6.1f, 7.1f, 8.1f, 0.0f, 0.0f,
148+
9.1f, 10.1f, 11.1f, 12.1f, 0.0f, 0.0f};
149+
std::vector<int64_t> strides_f32 = {f32_size * 6, f32_size};
150+
std::shared_ptr<Buffer> buffer_f32(Buffer::Wrap(values_f32));
151+
NumericTensor<FloatType> t_f32(buffer_f32, shape, strides_f32);
152+
153+
ASSERT_EQ(1.1f, t_f32.Value({0, 0}));
154+
ASSERT_EQ(5.1f, t_f32.Value({1, 0}));
155+
ASSERT_EQ(6.1f, t_f32.Value({1, 1}));
156+
ASSERT_EQ(11.1f, t_f32.Value({2, 2}));
157+
}
158+
107159
} // namespace arrow

cpp/src/arrow/tensor.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include "arrow/compare.h"
2828
#include "arrow/type.h"
29+
#include "arrow/type_traits.h"
2930
#include "arrow/util/checked_cast.h"
3031
#include "arrow/util/logging.h"
3132

@@ -121,4 +122,58 @@ Type::type Tensor::type_id() const { return type_->id(); }
121122

122123
bool Tensor::Equals(const Tensor& other) const { return TensorEquals(*this, other); }
123124

125+
// ----------------------------------------------------------------------
126+
// NumericTensor
127+
128+
template <typename TYPE>
129+
NumericTensor<TYPE>::NumericTensor(const std::shared_ptr<Buffer>& data,
130+
const std::vector<int64_t>& shape)
131+
: NumericTensor(data, shape, {}, {}) {}
132+
133+
template <typename TYPE>
134+
NumericTensor<TYPE>::NumericTensor(const std::shared_ptr<Buffer>& data,
135+
const std::vector<int64_t>& shape,
136+
const std::vector<int64_t>& strides)
137+
: NumericTensor(data, shape, strides, {}) {}
138+
139+
template <typename TYPE>
140+
NumericTensor<TYPE>::NumericTensor(const std::shared_ptr<Buffer>& data,
141+
const std::vector<int64_t>& shape,
142+
const std::vector<int64_t>& strides,
143+
const std::vector<std::string>& dim_names)
144+
: Tensor(TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names) {}
145+
146+
template <typename TYPE>
147+
int64_t NumericTensor<TYPE>::CalculateValueOffset(
148+
const std::vector<int64_t>& index) const {
149+
int64_t offset = 0;
150+
if (strides_.size() > 0) {
151+
for (size_t i = 0; i < index.size(); ++i) {
152+
offset += index[i] * strides_[i];
153+
}
154+
} else {
155+
for (size_t i = 0; i < index.size(); ++i) {
156+
offset = index[i] + offset * shape_[i];
157+
}
158+
offset *= static_cast<int64_t>(sizeof(value_type));
159+
}
160+
161+
return offset;
162+
}
163+
164+
// ----------------------------------------------------------------------
165+
// Instantiate templates
166+
167+
template class ARROW_TEMPLATE_EXPORT NumericTensor<UInt8Type>;
168+
template class ARROW_TEMPLATE_EXPORT NumericTensor<UInt16Type>;
169+
template class ARROW_TEMPLATE_EXPORT NumericTensor<UInt32Type>;
170+
template class ARROW_TEMPLATE_EXPORT NumericTensor<UInt64Type>;
171+
template class ARROW_TEMPLATE_EXPORT NumericTensor<Int8Type>;
172+
template class ARROW_TEMPLATE_EXPORT NumericTensor<Int16Type>;
173+
template class ARROW_TEMPLATE_EXPORT NumericTensor<Int32Type>;
174+
template class ARROW_TEMPLATE_EXPORT NumericTensor<Int64Type>;
175+
template class ARROW_TEMPLATE_EXPORT NumericTensor<HalfFloatType>;
176+
template class ARROW_TEMPLATE_EXPORT NumericTensor<FloatType>;
177+
template class ARROW_TEMPLATE_EXPORT NumericTensor<DoubleType>;
178+
124179
} // namespace arrow

cpp/src/arrow/tensor.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class ARROW_EXPORT Tensor {
6262
Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
6363
const std::vector<int64_t>& shape, const std::vector<int64_t>& strides);
6464

65-
/// Constructor with strides and dimension names
65+
/// Constructor with non-negative strides and dimension names
6666
Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
6767
const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
6868
const std::vector<std::string>& dim_names);
@@ -114,6 +114,34 @@ class ARROW_EXPORT Tensor {
114114
ARROW_DISALLOW_COPY_AND_ASSIGN(Tensor);
115115
};
116116

117+
template <typename TYPE>
118+
class ARROW_EXPORT NumericTensor : public Tensor {
119+
public:
120+
using TypeClass = TYPE;
121+
using value_type = typename TypeClass::c_type;
122+
123+
/// Constructor with no dimension names or strides, data assumed to be row-major
124+
NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape);
125+
126+
/// Constructor with non-negative strides
127+
NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
128+
const std::vector<int64_t>& strides);
129+
130+
/// Constructor with non-negative strides and dimension names
131+
NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
132+
const std::vector<int64_t>& strides,
133+
const std::vector<std::string>& dim_names);
134+
135+
const value_type& Value(const std::vector<int64_t>& index) const {
136+
int64_t offset = CalculateValueOffset(index);
137+
const value_type* ptr = reinterpret_cast<const value_type*>(raw_data() + offset);
138+
return *ptr;
139+
}
140+
141+
protected:
142+
int64_t CalculateValueOffset(const std::vector<int64_t>& index) const;
143+
};
144+
117145
} // namespace arrow
118146

119147
#endif // ARROW_TENSOR_H

0 commit comments

Comments
 (0)