Skip to content

Commit ecf892f

Browse files
authored
[PHI]Add new Tensor type and migrate save_combine kernel (#47856)
* add new tensor * fix windows compile bugs * fix ci bugs * fix ci bugs * fix ci bugs * perfect according comment * fix ci compile bugs * add raw tensor * fix ci bugs * modify code by comment * delete String
1 parent 737fbdb commit ecf892f

30 files changed

+627
-117
lines changed

cmake/operators.cmake

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,35 @@ function(find_register FILENAME PATTERN OUTPUT)
2626
PARENT_SCOPE)
2727
endfunction()
2828

29+
# find_phi_register(FILENAME ADD_PATH)
#
# Scans the source file FILENAME for a PHI kernel registration of the form
#   PD_REGISTER_KERNEL(kernel_name, backend, layout, ...
# and, when one is found, appends the matching forward declaration
#   PD_DECLARE_KERNEL(kernel_name, backend, layout);
# to the file ADD_PATH.
#
# Notes:
# - string(REGEX MATCH) yields a single match, so only the first registration
#   matched in FILENAME is processed.
# - Nothing is appended when FILENAME contains no matching registration.
function(find_phi_register FILENAME ADD_PATH)
  file(READ "${FILENAME}" CONTENT)

  # Capture everything from the macro name up to and including the third
  # argument (kernel name, backend, layout).
  string(
    REGEX
      MATCH
      "PD_REGISTER_KERNEL\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*"
      register
      "${CONTENT}")
  if(NOT register STREQUAL "")
    # Drop the macro prefix, then turn the comma-separated arguments into a
    # CMake list so that each argument can be fetched by index.
    string(REPLACE "PD_REGISTER_KERNEL(" "" register "${register}")
    string(REPLACE "," ";" register "${register}")
    # Remove whitespace and line-continuation backslashes, then the
    # "//cuda_only" marker that remains once its whitespace has been stripped.
    string(REGEX REPLACE "[ \\\t\r\n]+" "" register "${register}")
    string(REGEX REPLACE "//cuda_only" "" register "${register}")
    list(GET register 0 kernel_name)
    list(GET register 1 kernel_backend)
    list(GET register 2 kernel_layout)

    file(
      APPEND "${ADD_PATH}"
      "PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n"
    )
  endif()
endfunction()
57+
2958
function(op_library TARGET)
3059
# op_library is a function to create op library. The interface is same as
3160
# cc_library, but it also handles splitting GPU/CPU code and links some common libraries
@@ -371,6 +400,8 @@ function(op_library TARGET)
371400
foreach(cc_src ${cc_srcs})
372401
# pybind USE_OP_ITSELF
373402
set(op_name "")
403+
# Add PHI Kernel Registry Message
404+
find_phi_register(${cc_src} ${pybind_file})
374405
find_register(${cc_src} "REGISTER_OPERATOR" op_name)
375406
if(NOT ${op_name} EQUAL "")
376407
file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n")
@@ -408,6 +439,8 @@ function(op_library TARGET)
408439
# message("cu_srcs ${cu_srcs}")
409440
foreach(cu_src ${cu_srcs})
410441
set(op_name "")
442+
# Add PHI Kernel Registry Message
443+
find_phi_register(${cu_src} ${pybind_file})
411444
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
412445
if(NOT ${op_name} EQUAL "")
413446
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n")

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto
115115
cc_library(
116116
string_array
117117
SRCS string_array.cc
118-
DEPS utf8proc)
118+
DEPS utf8proc phi_enforce)
119119

120120
cc_library(
121121
data_type
@@ -233,7 +233,8 @@ cc_test(
233233
cc_library(
234234
var_type_traits
235235
SRCS var_type_traits.cc
236-
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor)
236+
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor
237+
extended_tensor)
237238
if(WITH_GPU)
238239
target_link_libraries(var_type_traits dynload_cuda)
239240
endif()

paddle/fluid/framework/operator.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License. */
2424
#include "paddle/fluid/framework/details/nan_inf_utils.h"
2525
#include "paddle/fluid/framework/op_call_stack.h"
2626
#include "paddle/fluid/framework/phi_utils.h"
27+
#include "paddle/fluid/framework/raw_tensor.h"
2728
#include "paddle/fluid/framework/shape_inference.h"
2829
#include "paddle/fluid/framework/transfer_scope_cache.h"
2930
#include "paddle/fluid/framework/unused_var_check.h"
@@ -3008,6 +3009,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
30083009
need_prepare_phi_data_ = true;
30093010
tensor_in = &(var->Get<framework::LoDTensorArray>());
30103011
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
3012+
} else if (var->IsType<framework::Vocab>()) {
3013+
tensor_in = &(var->Get<framework::Vocab>());
3014+
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
30113015
} else {
30123016
PADDLE_THROW(platform::errors::Unimplemented(
30133017
"Unsupported input `%s` type when call pt kernel.",
@@ -3057,6 +3061,13 @@ void OperatorWithKernel::BuildPhiKernelContext(
30573061
// Note: If the input LoDTensorArray size is 0, the output
30583062
// LoDTensorArray is also 0
30593063
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
3064+
} else if (var->template IsType<paddle::framework::RawTensor>()) {
3065+
tensor_out = var->template GetMutable<paddle::framework::RawTensor>();
3066+
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
3067+
} else if (!var->IsInitialized()) {
3068+
// The following is for RAW type of var
3069+
tensor_out = var->template GetMutable<paddle::framework::RawTensor>();
3070+
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
30603071
} else {
30613072
PADDLE_THROW(platform::errors::Unimplemented(
30623073
"Unsupported output `%s` type when call pt kernel.",
@@ -3156,6 +3167,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
31563167
}
31573168
}
31583169
break;
3170+
31593171
case phi::AttributeType::SCALARS: {
31603172
PADDLE_ENFORCE_NE(
31613173
attr_iter,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

#include "paddle/phi/core/extended_tensor.h"
#include "paddle/utils/any.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
25+
/// \brief Fluid Kernel and PHI Kernel will be unified in the future.
26+
/// So, we need a class in PHI that can represent the RAW type in Fluid.
27+
/// The RawTensor is for PHI Kernel that has RAW type arguments.
28+
class RawTensor : public phi::ExtendedTensor,
29+
public phi::TypeInfoTraits<phi::TensorBase, RawTensor> {
30+
public:
31+
RawTensor() = default;
32+
33+
RawTensor(RawTensor&& other) = default;
34+
35+
RawTensor(const RawTensor& other) = default;
36+
37+
RawTensor& operator=(RawTensor&& other) = default;
38+
39+
/// \brief Destroy the RawTensor and release exclusive resources.
40+
virtual ~RawTensor() = default;
41+
42+
public:
43+
/// \brief Returns the name of the class for type traits.
44+
/// \return The name of the class.
45+
static const char* name() { return "RawTensor"; }
46+
47+
template <typename T>
48+
T* GetMutable() {
49+
if (!data_.empty()) {
50+
try {
51+
return paddle::any_cast<T*>(data_);
52+
} catch (paddle::bad_any_cast&) {
53+
PADDLE_THROW(phi::errors::InvalidArgument(
54+
"Invalid data type error, expected %s, actual %s.",
55+
typeid(T).name(),
56+
data_type_.name()));
57+
}
58+
}
59+
T* created_data = new T();
60+
data_ = created_data;
61+
data_deleter_ = [created_data]() { delete created_data; };
62+
data_type_ = std::type_index(typeid(T));
63+
return created_data;
64+
}
65+
66+
template <typename T>
67+
bool IsType() const {
68+
return std::type_index(typeid(T)) == data_type_;
69+
}
70+
71+
private:
72+
paddle::any data_;
73+
std::function<void(void)> data_deleter_;
74+
std::type_index data_type_ = std::type_index(typeid(void));
75+
};
76+
77+
} // namespace framework
78+
} // namespace paddle

paddle/fluid/framework/string_array.h

100755100644
Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,82 @@ limitations under the License. */
2020
#include <string>
2121
#include <unordered_map>
2222
#include <vector>
23+
#include "paddle/phi/core/extended_tensor.h"
2324

2425
namespace paddle {
2526
namespace framework {
2627

28+
class Vocab : public phi::ExtendedTensor,
29+
public phi::TypeInfoTraits<phi::TensorBase, Vocab> {
30+
public:
31+
Vocab() = default;
32+
33+
Vocab(Vocab&& other) = default;
34+
35+
Vocab(const Vocab& other) = default;
36+
37+
Vocab& operator=(const Vocab& other) = default;
38+
39+
Vocab& operator=(Vocab&& other) = default;
40+
41+
Vocab& operator=(
42+
const std::unordered_map<std::wstring, std::int32_t>& other) {
43+
this->data_ = other;
44+
return *this;
45+
}
46+
47+
/// \brief Destroy the Vocab and release exclusive resources.
48+
virtual ~Vocab() = default;
49+
50+
public:
51+
/// \brief Returns the name of the class for type traits.
52+
/// \return The name of the class.
53+
static const char* name() { return "Vocab"; }
54+
55+
size_t size() const { return data_.size(); }
56+
57+
void clear() { data_.clear(); }
58+
59+
void emplace(const std::wstring& key, std::int32_t value) {
60+
data_.emplace(key, value);
61+
}
62+
63+
std::int32_t at(const std::wstring& key) { return data_.at(key); }
64+
65+
std::int32_t at(const std::wstring& key) const { return data_.at(key); }
66+
67+
std::unordered_map<std::wstring, std::int32_t>::iterator find(
68+
const std::wstring& key) {
69+
return data_.find(key);
70+
}
71+
72+
std::unordered_map<std::wstring, std::int32_t>::const_iterator find(
73+
const std::wstring& key) const {
74+
return data_.find(key);
75+
}
76+
77+
std::unordered_map<std::wstring, std::int32_t>::iterator begin() {
78+
return data_.begin();
79+
}
80+
81+
std::unordered_map<std::wstring, std::int32_t>::const_iterator begin() const {
82+
return data_.begin();
83+
}
84+
85+
std::unordered_map<std::wstring, std::int32_t>::iterator end() {
86+
return data_.end();
87+
}
88+
89+
std::unordered_map<std::wstring, std::int32_t>::const_iterator end() const {
90+
return data_.end();
91+
}
92+
93+
private:
94+
std::unordered_map<std::wstring, std::int32_t> data_;
95+
};
96+
2797
using String = std::string;
2898
using Strings = std::vector<std::string>;
29-
using Vocab = std::unordered_map<std::wstring, std::int32_t>;
3099

31100
// Convert the std::string type to the std::wstring type.
32101
bool ConvertStrToWstr(const std::string& src, std::wstring* res);

paddle/fluid/framework/var_type_traits.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
4242
#endif
4343

44+
#include "paddle/fluid/framework/raw_tensor.h"
4445
#include "paddle/fluid/operators/cuda_graph_with_in_out.h"
4546

4647
namespace paddle {

paddle/fluid/framework/var_type_traits.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "paddle/fluid/framework/feed_fetch_type.h"
2525
#include "paddle/fluid/framework/lod_tensor_array.h"
26+
#include "paddle/fluid/framework/raw_tensor.h"
2627
#include "paddle/fluid/framework/string_array.h"
2728
#include "paddle/fluid/platform/place.h"
2829
#ifdef PADDLE_WITH_CUDA
@@ -219,7 +220,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
219220
float,
220221
Vocab,
221222
std::vector<int>,
222-
std::vector<float>>;
223+
std::vector<float>,
224+
RawTensor>;
223225
template <typename T>
224226
struct VarTypeTrait {
225227
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");

paddle/fluid/framework/var_type_traits_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#if defined(PADDLE_WITH_XPU_BKCL)
3939
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
4040
#endif
41+
#include "paddle/fluid/framework/raw_tensor.h"
4142

4243
namespace paddle {
4344
namespace framework {

paddle/fluid/framework/variable_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace framework {
2222
TEST(Variable, GetMutable) {
2323
std::unique_ptr<Variable> v(new Variable());
2424

25-
auto* t = v->GetMutable<std::string>();
25+
auto* t = v->GetMutable<String>();
2626
*t = "1234";
2727

28-
const auto& tt = v->Get<std::string>();
28+
const auto& tt = v->Get<String>();
2929
EXPECT_EQ("1234", tt);
3030

3131
try {

paddle/fluid/imperative/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ cc_library(
55
cc_library(
66
var_helper
77
SRCS var_helper.cc
8-
DEPS tensor selected_rows)
8+
DEPS tensor selected_rows extended_tensor)
99
if(WITH_XPU)
1010
cc_library(
1111
prepared_operator

0 commit comments

Comments
 (0)