Skip to content

Commit fd97d7d

Browse files
authored
[IR] Value system && Operation (PaddlePaddle#51992)
* add Value OpResult OpOperand class * add Value OpResult OpOperand class * fix bug * fix bug * add utils * refine code * add ptr offset and reset method * add value impl * fix bug * refine comment of ValueImpl * refine code of OpResult * refine code of Value * add some comment * fix cpu compile bug * refine code * add op * add method for op & test value * refine unittest * refine code by comment * refine code * refine code * refine code * refine code
1 parent 8cbeefe commit fd97d7d

15 files changed

+988
-32
lines changed

paddle/ir/builtin_attribute.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/ir/attribute.h"
1818
#include "paddle/ir/builtin_attribute_storage.h"
19+
#include "paddle/ir/utils.h"
1920

2021
namespace ir {
2122
///
@@ -82,15 +83,11 @@ class DictionaryAttribute : public ir::Attribute {
8283
} // namespace ir
8384

8485
namespace std {
85-
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
86-
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
87-
}
88-
8986
template <>
9087
struct hash<ir::NamedAttribute> {
9188
std::size_t operator()(const ir::NamedAttribute &obj) const {
92-
return hash_combine(std::hash<ir::Attribute>()(obj.name_),
93-
std::hash<ir::Attribute>()(obj.value_));
89+
return ir::hash_combine(std::hash<ir::Attribute>()(obj.name_),
90+
std::hash<ir::Attribute>()(obj.value_));
9491
}
9592
};
9693
} // namespace std

paddle/ir/builtin_attribute_storage.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/ir/builtin_attribute_storage.h"
1616
#include "paddle/ir/builtin_attribute.h"
17+
#include "paddle/ir/utils.h"
1718

1819
namespace ir {
1920

@@ -32,7 +33,7 @@ DictionaryAttributeStorage::DictionaryAttributeStorage(const ParamKey &key) {
3233
std::size_t DictionaryAttributeStorage::HashValue(const ParamKey &key) {
3334
std::size_t hash_value = key.size();
3435
for (auto iter = key.begin(); iter != key.end(); ++iter) {
35-
hash_value = hash_combine(
36+
hash_value = ir::hash_combine(
3637
hash_value,
3738
std::hash<NamedAttribute>()(NamedAttribute(iter->first, iter->second)));
3839
}

paddle/ir/builtin_attribute_storage.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ struct DictionaryAttributeStorage : public AttributeStorage {
8383
uint32_t size() const { return size_; }
8484

8585
private:
86-
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
87-
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
88-
}
89-
9086
NamedAttribute *data_;
9187
uint32_t size_;
9288
};

paddle/ir/builtin_type_storage.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <type_traits>
1818

1919
#include "paddle/ir/type.h"
20+
#include "paddle/ir/utils.h"
2021

2122
namespace std {
2223
///
@@ -109,20 +110,22 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
109110
std::size_t hash_value = 0;
110111
// hash dtype
111112
hash_value =
112-
hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
113+
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
113114
// hash dims
114-
hash_value = hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
115-
// hash layout
116115
hash_value =
117-
hash_combine(hash_value,
118-
std::hash<std::underlying_type<DataLayout>::type>()(
119-
static_cast<std::underlying_type<DataLayout>::type>(
120-
std::get<2>(key))));
116+
ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
117+
// hash layout
118+
hash_value = ir::hash_combine(
119+
hash_value,
120+
std::hash<std::underlying_type<DataLayout>::type>()(
121+
static_cast<std::underlying_type<DataLayout>::type>(
122+
std::get<2>(key))));
121123
// hash lod
122-
hash_value = hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
124+
hash_value =
125+
ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
123126
// hash offset
124127
hash_value =
125-
hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
128+
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
126129
return hash_value;
127130
}
128131

@@ -146,11 +149,6 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
146149
DataLayout layout_;
147150
LoD lod_;
148151
size_t offset_;
149-
150-
private:
151-
static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
152-
return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
153-
}
154152
};
155153

156154
} // namespace ir

paddle/ir/op_base.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/ir/operation.h"
18+
19+
namespace ir {
20+
class OpBase {
21+
public:
22+
Operation *operation() { return operation_; }
23+
24+
explicit operator bool() { return operation() != nullptr; }
25+
26+
operator Operation *() const { return operation_; }
27+
28+
Operation *operator->() const { return operation_; }
29+
30+
protected:
31+
explicit OpBase(Operation *operation) : operation_(operation) {}
32+
33+
private:
34+
Operation *operation_;
35+
};
36+
37+
} // namespace ir

paddle/ir/operation.cc

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/ir/operation.h"
16+
#include "paddle/ir/utils.h"
17+
18+
namespace ir {
19+
// Allocate the required memory based on the size and number of inputs, outputs,
20+
// and operators, and construct it in the order of: OpOutlineResult,
21+
// OpInlineResult, Operation, Operand.
22+
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
23+
const std::vector<ir::Type> &output_types,
24+
ir::DictionaryAttribute attribute) {
25+
// 1. Calculate the required memory size for OpResults + Operation +
26+
// OpOperands.
27+
uint32_t num_results = output_types.size();
28+
uint32_t num_operands = inputs.size();
29+
uint32_t max_inline_result_num =
30+
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
31+
size_t result_mem_size =
32+
num_results > max_inline_result_num
33+
? sizeof(detail::OpOutlineResultImpl) *
34+
(num_results - max_inline_result_num) +
35+
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
36+
: sizeof(detail::OpInlineResultImpl) * num_results;
37+
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
38+
size_t op_mem_size = sizeof(Operation);
39+
size_t base_size = result_mem_size + op_mem_size + operand_mem_size;
40+
// 2. Malloc memory.
41+
char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
42+
// 3.1. Construct OpResults.
43+
for (size_t idx = num_results; idx > 0; idx--) {
44+
if (idx > max_inline_result_num) {
45+
new (base_ptr)
46+
detail::OpOutlineResultImpl(output_types[idx - 1], idx - 1);
47+
base_ptr += sizeof(detail::OpOutlineResultImpl);
48+
} else {
49+
new (base_ptr) detail::OpInlineResultImpl(output_types[idx - 1], idx - 1);
50+
base_ptr += sizeof(detail::OpInlineResultImpl);
51+
}
52+
}
53+
// 3.2. Construct Operation.
54+
Operation *op =
55+
new (base_ptr) Operation(num_results, num_operands, attribute);
56+
base_ptr += sizeof(Operation);
57+
// 3.3. Construct OpOperands.
58+
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
59+
throw("The address of OpOperandImpl must be divisible by 8.");
60+
}
61+
for (size_t idx = 0; idx < num_operands; idx++) {
62+
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
63+
base_ptr += sizeof(detail::OpOperandImpl);
64+
}
65+
VLOG(4) << "Construct an Operation: " << op->print();
66+
return op;
67+
}
68+
69+
// Call destructors for OpResults, Operation, and OpOperands in sequence, and
70+
// finally free memory.
71+
void Operation::destroy() {
72+
// 1. Get aligned_ptr by result_num.
73+
uint32_t max_inline_result_num =
74+
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
75+
size_t result_mem_size =
76+
num_results_ > max_inline_result_num
77+
? sizeof(detail::OpOutlineResultImpl) *
78+
(num_results_ - max_inline_result_num) +
79+
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
80+
: sizeof(detail::OpInlineResultImpl) * num_results_;
81+
char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
82+
// 2.1. Deconstruct OpResult.
83+
char *base_ptr = aligned_ptr;
84+
for (size_t idx = num_results_; idx > 0; idx--) {
85+
if (!reinterpret_cast<detail::OpResultImpl *>(base_ptr)->use_empty()) {
86+
throw("Cannot destroy a value that still has uses!");
87+
}
88+
if (idx > max_inline_result_num) {
89+
reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
90+
->~OpOutlineResultImpl();
91+
base_ptr += sizeof(detail::OpOutlineResultImpl);
92+
} else {
93+
reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
94+
->~OpInlineResultImpl();
95+
base_ptr += sizeof(detail::OpInlineResultImpl);
96+
}
97+
}
98+
// 2.2. Deconstruct Operation.
99+
if (reinterpret_cast<uintptr_t>(base_ptr) !=
100+
reinterpret_cast<uintptr_t>(this)) {
101+
throw("Operation address error");
102+
}
103+
reinterpret_cast<Operation *>(base_ptr)->~Operation();
104+
base_ptr += sizeof(Operation);
105+
// 2.3. Deconstruct OpOpOerand.
106+
for (size_t idx = 0; idx < num_operands_; idx++) {
107+
reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
108+
base_ptr += sizeof(detail::OpOperandImpl);
109+
}
110+
// 3. Free memory.
111+
VLOG(4) << "Destroy an Operation: {ptr = "
112+
<< reinterpret_cast<void *>(aligned_ptr)
113+
<< ", size = " << result_mem_size << "}";
114+
aligned_free(reinterpret_cast<void *>(aligned_ptr));
115+
}
116+
117+
Operation::Operation(uint32_t num_results,
118+
uint32_t num_operands,
119+
ir::DictionaryAttribute attribute) {
120+
if (!attribute) {
121+
throw("unexpected null attribute dictionary");
122+
}
123+
num_results_ = num_results;
124+
num_operands_ = num_operands;
125+
attribute_ = attribute;
126+
}
127+
128+
ir::OpResult Operation::GetResultByIndex(uint32_t index) {
129+
if (index >= num_results_) {
130+
throw("index exceeds OP output range.");
131+
}
132+
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
133+
char *ptr = nullptr;
134+
if (index > max_inline_idx) {
135+
ptr = reinterpret_cast<char *>(this) -
136+
(max_inline_idx + 1) * sizeof(detail::OpInlineResultImpl) -
137+
(index - max_inline_idx) * sizeof(detail::OpOutlineResultImpl);
138+
} else {
139+
ptr = reinterpret_cast<char *>(this) -
140+
(index + 1) * sizeof(detail::OpInlineResultImpl);
141+
}
142+
if (index > max_inline_idx) {
143+
detail::OpOutlineResultImpl *result_impl_ptr =
144+
reinterpret_cast<detail::OpOutlineResultImpl *>(ptr);
145+
return ir::OpResult(result_impl_ptr);
146+
} else {
147+
detail::OpInlineResultImpl *result_impl_ptr =
148+
reinterpret_cast<detail::OpInlineResultImpl *>(ptr);
149+
return ir::OpResult(result_impl_ptr);
150+
}
151+
}
152+
153+
std::string Operation::print() {
154+
std::stringstream result;
155+
result << "{ " << num_results_ << " outputs, " << num_operands_
156+
<< " inputs } : ";
157+
result << "[ ";
158+
for (size_t idx = num_results_; idx > 0; idx--) {
159+
result << GetResultByIndex(idx - 1).impl_ << ", ";
160+
}
161+
result << "] = ";
162+
result << this << "( ";
163+
for (size_t idx = 0; idx < num_operands_; idx++) {
164+
result << reinterpret_cast<void *>(reinterpret_cast<char *>(this) +
165+
sizeof(Operation) +
166+
idx * sizeof(detail::OpOperandImpl))
167+
<< ", ";
168+
}
169+
result << ")";
170+
return result.str();
171+
}
172+
173+
} // namespace ir

paddle/ir/operation.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/ir/builtin_attribute.h"
18+
#include "paddle/ir/type.h"
19+
#include "paddle/ir/value_impl.h"
20+
21+
namespace ir {
22+
23+
class alignas(8) Operation final {
24+
public:
25+
///
26+
/// \brief Malloc memory and construct objects in the following order:
27+
/// OpResultImpls|Operation|OpOperandImpls.
28+
///
29+
static Operation *create(const std::vector<ir::OpResult> &inputs,
30+
const std::vector<ir::Type> &output_types,
31+
ir::DictionaryAttribute attribute);
32+
33+
void destroy();
34+
35+
ir::OpResult GetResultByIndex(uint32_t index);
36+
37+
std::string print();
38+
39+
ir::DictionaryAttribute attribute() { return attribute_; }
40+
41+
uint32_t num_results() { return num_results_; }
42+
43+
uint32_t num_operands() { return num_operands_; }
44+
45+
private:
46+
Operation(uint32_t num_results,
47+
uint32_t num_operands,
48+
ir::DictionaryAttribute attribute);
49+
50+
ir::DictionaryAttribute attribute_;
51+
52+
uint32_t num_results_ = 0;
53+
54+
uint32_t num_operands_ = 0;
55+
};
56+
57+
} // namespace ir

paddle/ir/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
cc_test_old(type_test SRCS type_test.cc DEPS new_ir gtest)
22
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS new_ir gtest)
3+
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS new_ir gtest)

0 commit comments

Comments
 (0)