Commit f1c9661

add a candidate dense tensor class, test=develop (#28)
1 parent ce210b4 commit f1c9661

20 files changed: +838 additions, −15 deletions

paddle/pten/common/data_type.h

Lines changed: 5 additions & 2 deletions
@@ -75,17 +75,20 @@ inline size_t SizeOf(DataType data_type) {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Data type %d is not supported by tensor.",
           static_cast<int>(data_type)));
-      return 0;
   }
+  return 0;
 }
 
 #define PT_FOR_EACH_DATA_TYPE(_) \
   _(bool, DataType::BOOL)        \
   _(int8_t, DataType::INT8)      \
   _(uint8_t, DataType::UINT8)    \
   _(int16_t, DataType::INT16)    \
-  _(int, DataType::INT32)        \
+  _(uint16_t, DataType::UINT16)  \
+  _(int32_t, DataType::INT32)    \
+  _(uint32_t, DataType::UINT32)  \
   _(int64_t, DataType::INT64)    \
+  _(uint64_t, DataType::UINT64)  \
   _(bfloat16, DataType::BFLOAT16) \
   _(float16, DataType::FLOAT16)  \
   _(float, DataType::FLOAT32)    \
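
Note: PT_FOR_EACH_DATA_TYPE is an X-macro that pairs each C++ type with its DataType enum value; the hunk above extends it with the unsigned fixed-width types. A minimal sketch of how such a macro is typically consumed is given below; the PT_SIZE_CASE helper and SizeOfSketch function are illustrative names, not part of this commit:

// Illustrative consumer of the X-macro: one switch case per (cpp_type, enum) pair.
#define PT_SIZE_CASE(cpp_type, data_type) \
  case data_type:                         \
    return sizeof(cpp_type);

inline size_t SizeOfSketch(DataType dtype) {
  switch (dtype) {
    PT_FOR_EACH_DATA_TYPE(PT_SIZE_CASE)
    default:
      return 0;  // unsupported types are left to the caller's error handling
  }
}
#undef PT_SIZE_CASE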

paddle/pten/core/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,5 @@
+add_subdirectory(candidate)
+
 IF(WITH_MKLDNN)
   set(MKLDNN_CTX_DEPS mkldnn)
 ELSE()
@@ -15,3 +17,5 @@ cc_library(dense_tensor SRCS dense_tensor.cc DEPS enforce data_type ddim allocat
 
 cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce)
 cc_library(kernel_context SRCS kernel_context.cc DEPS enforce device_context)
+
+cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)

paddle/pten/core/allocator.h

Lines changed: 8 additions & 6 deletions
@@ -23,6 +23,8 @@ namespace pten {
 /// deallocation and construction/destruction of objects.
 class RawAllocator {
  public:
+  using Place = paddle::platform::Place;
+
   /// \brief Default destructor.
   virtual ~RawAllocator() = default;
 
@@ -43,7 +45,7 @@ class RawAllocator {
 
   /// \brief Get the place value of the allocator and the allocation.
   /// \return The place value of the allocator and the allocation.
-  virtual const paddle::platform::Place& place() const = 0;
+  virtual const Place& place() const = 0;
 };
 
 /// \brief Fancy pointer with context. The use of this data type
@@ -52,24 +54,24 @@ class RawAllocator {
 /// support being inherited.
 class Allocation final {
  public:
+  using Place = paddle::platform::Place;
   using DeleterFnPtr = void (*)(void*);
 
   Allocation() = default;
   Allocation(Allocation&&) = default;
   Allocation& operator=(Allocation&&) = default;
 
-  Allocation(void* data, const paddle::platform::Place& place)
-      : data_(data), place_(place) {}
+  Allocation(void* data, const Place& place) : data_(data), place_(place) {}
 
   Allocation(void* data,
              void* ctx,
              DeleterFnPtr ctx_deleter,
-             const paddle::platform::Place& place)
+             const Place& place)
       : data_(data), ctx_(ctx, ctx_deleter), place_(place) {}
 
   void* operator->() const noexcept { return data_; }
   operator bool() const noexcept { return data_ || ctx_.Get(); }
-  const paddle::platform::Place& place() const noexcept { return place_; }
+  const Place& place() const noexcept { return place_; }
 
   void Clear() noexcept {
     data_ = nullptr;
@@ -132,7 +134,7 @@ class Allocation final {
   Context ctx_;
   // TODO(Shixiaowei02): Enum needs to be used instead to reduce
   // the construction overhead by more than 50%.
-  paddle::platform::Place place_;
+  Place place_;
 };
 
 inline void swap(Allocation::Context& a, Allocation::Context& b) noexcept {
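
Note: the change above only introduces a Place alias for paddle::platform::Place inside RawAllocator and Allocation; behavior is unchanged. A hedged usage sketch of the Allocation fancy pointer follows, assuming a plain malloc'd host buffer and paddle::platform::CPUPlace; the buffer, deleter, and function name are illustrative, not part of this commit:

#include <cstdlib>

#include "paddle/pten/core/allocator.h"

// Illustrative only: wrap an externally owned host buffer in a pten::Allocation.
void AllocationSketch() {
  void* buffer = std::malloc(64);
  paddle::platform::CPUPlace cpu;
  // The captureless lambda decays to DeleterFnPtr; it frees the buffer together
  // with the allocation's context.
  pten::Allocation holder(buffer, buffer, [](void* p) { std::free(p); }, cpu);
  const pten::Allocation::Place& where = holder.place();  // the CPU place passed above
  bool engaged = static_cast<bool>(holder);               // true: data or context is set
  (void)where;
  (void)engaged;
}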
paddle/pten/core/candidate/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+cc_library(pten_dense_tensor SRCS dense_tensor.cc DEPS tensor_base)
paddle/pten/core/candidate/dense_tensor.cc

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/candidate/dense_tensor.h"
+
+namespace pten {
+namespace candidate {
+
+DenseTensorMeta::DenseTensorMeta(DataType type, const DDim& dims)
+    : dims(dims), type(type) {}
+DenseTensorMeta::DenseTensorMeta(DataType type,
+                                 const DDim& dims,
+                                 DataLayout layout)
+    : dims(dims), type(type), layout(layout) {}
+DenseTensorMeta::DenseTensorMeta(DataType type,
+                                 const DDim& dims,
+                                 DataLayout layout,
+                                 const std::vector<std::vector<size_t>>& lod)
+    : dims(dims), type(type), layout(layout), lod(lod) {}
+
+bool DenseTensorMeta::valid() const noexcept {
+  bool valid{true};
+  valid = valid && (type != DataType::UNDEFINED);
+  valid = valid && (layout != DataLayout::UNDEFINED);
+  valid = valid && (is_scalar || product(dims));
+  return valid;
+}
+
+DenseTensor::DenseTensor(const std::shared_ptr<Allocator>& a,
+                         const DenseTensorMeta& meta)
+    : meta_(meta),
+      storage_(
+          make_intrusive<TensorStorage>(a, SizeOf(data_type()) * numel())) {}
+
+DenseTensor::DenseTensor(const std::shared_ptr<Allocator>& a,
+                         DenseTensorMeta&& meta)
+    : meta_(std::move(meta)),
+      storage_(
+          make_intrusive<TensorStorage>(a, SizeOf(data_type()) * numel())) {}
+
+DenseTensor::DenseTensor(intrusive_ptr<Storage> storage,
+                         const DenseTensorMeta& meta)
+    : meta_(meta), storage_(std::move(storage)) {}
+
+DenseTensor::DenseTensor(intrusive_ptr<Storage> storage, DenseTensorMeta&& meta)
+    : meta_(std::move(meta)), storage_(std::move(storage)) {}
+
+int64_t DenseTensor::numel() const {
+  if (meta_.is_scalar) {
+    return 1;
+  }
+  return product(meta_.dims);
+}
+
+bool DenseTensor::SharesStorageWith(const DenseTensor& b) const {
+  return storage_.get() == b.storage_.get() && storage_.get() != nullptr;
+}
+
+template <typename T>
+T* DenseTensor::mutable_data(size_t request_bytes) {
+  PADDLE_ENFORCE(
+      valid(),
+      paddle::platform::errors::PreconditionNotMet(
+          "The meta data must be valid when call the mutable data function."));
+  PADDLE_ENFORCE_NOT_NULL(
+      storage_,
+      paddle::platform::errors::PreconditionNotMet(
+          "The storage must be valid when call the mutable data function."));
+  PADDLE_ENFORCE(
+      (data_type() == paddle::experimental::CppTypeToDataType<T>::Type()),
+      paddle::platform::errors::PreconditionNotMet(
+          "The type of data we are trying to retrieve does not match the "
+          "type of data currently contained in the container."));
+  size_t bytes = numel() * SizeOf(data_type());
+  if (request_bytes) {
+    PADDLE_ENFORCE_GE(request_bytes,
+                      bytes,
+                      paddle::platform::errors::InvalidArgument(
+                          "The reserved size %d should be enough to meet the "
+                          "volume required by metadata %d.",
+                          request_bytes,
+                          bytes));
+    bytes = request_bytes;
+  }
+  if (storage_->size() < bytes) {
+    storage_->Realloc(bytes);
+  }
+  return static_cast<T*>(storage_->data());
+}
+
+template <typename T>
+const T* DenseTensor::data() const {
+  PADDLE_ENFORCE_NOT_NULL(
+      storage_,
+      paddle::platform::errors::PreconditionNotMet(
+          "The storage must be valid when call the mutable data function."));
+  PADDLE_ENFORCE(
+      (data_type() == paddle::experimental::CppTypeToDataType<T>::Type()),
+      paddle::platform::errors::PreconditionNotMet(
+          "The type of data we are trying to retrieve does not match the "
+          "type of data currently contained in the container."));
+  return static_cast<const T*>(storage_->data());
+}
+
+void DenseTensor::check_memory_size() const {
+  size_t bytes = numel() * SizeOf(data_type());
+  PADDLE_ENFORCE_GE(memory_size(),
+                    bytes,
+                    paddle::platform::errors::InvalidArgument(
+                        "The memory size %d should be enough to meet the "
+                        "volume required by metadata %d.",
+                        memory_size(),
+                        bytes));
+}
+
+#define DATA_MEMBER_FUNC_INSTANTIATION(dtype)                      \
+  template dtype* DenseTensor::mutable_data(size_t request_bytes); \
+  template const dtype* DenseTensor::data() const;
+
+DATA_MEMBER_FUNC_INSTANTIATION(int8_t);
+DATA_MEMBER_FUNC_INSTANTIATION(uint8_t);
+DATA_MEMBER_FUNC_INSTANTIATION(int16_t);
+DATA_MEMBER_FUNC_INSTANTIATION(uint16_t);
+DATA_MEMBER_FUNC_INSTANTIATION(int32_t);
+DATA_MEMBER_FUNC_INSTANTIATION(uint32_t);
+DATA_MEMBER_FUNC_INSTANTIATION(int64_t);
+DATA_MEMBER_FUNC_INSTANTIATION(uint64_t);
+DATA_MEMBER_FUNC_INSTANTIATION(float);
+DATA_MEMBER_FUNC_INSTANTIATION(double);
+
+#undef DATA_MEMBER_FUNC_INSTANTIATION
+
+}  // namespace candidate
+}  // namespace pten
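
Note: a hedged usage sketch of the candidate DenseTensor added above. The concrete Allocator instance is left as a parameter because this commit does not add one, and make_ddim, DataType::FLOAT32, and DataLayout::NCHW are assumed to come from the existing Paddle/pten headers; the Sketch function itself is illustrative:

#include <memory>

#include "paddle/pten/core/candidate/dense_tensor.h"

namespace sketch {

// Illustrative only: build a 2x3 float tensor and touch its buffer.
void Use(const std::shared_ptr<pten::Allocator>& alloc /* assumed to exist */) {
  pten::candidate::DenseTensorMeta meta(pten::DataType::FLOAT32,
                                        paddle::framework::make_ddim({2, 3}),
                                        pten::DataLayout::NCHW);
  pten::candidate::DenseTensor t(alloc, meta);
  // Passing 0 sizes the buffer from the metadata: numel() * SizeOf(FLOAT32).
  float* data = t.mutable_data<float>(0);
  data[0] = 1.0f;
  // A mismatched call such as t.data<double>() would trip the dtype
  // PADDLE_ENFORCE check shown in the diff above.
}

}  // namespace sketch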
