Skip to content

Commit 3b25afb

Browse files
authored
[Phi] Support construct Scalar by using Non-CPU Tensor (PaddlePaddle#41765) (PaddlePaddle#41963)
* support construct scalar using non-cpu tensor * fix bugs when run unittest * fix compile bugs * fix bugs when run ci * fix compile bugs * fix bugs when move copy * perfect unit test * perfect unittest * update according to comment * add target dependency * deal with conflict * fix bugs when run unit test * fix unit test bugs
1 parent 9a75b4b commit 3b25afb

21 files changed

+449
-95
lines changed

paddle/fluid/platform/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,13 @@ add_subdirectory(profiler)
192192

193193
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
194194
if(WITH_GPU)
195-
nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler)
195+
nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler stats)
196196
nv_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
197197
elseif(WITH_ROCM)
198-
hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler)
198+
hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler stats)
199199
hip_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
200200
else()
201-
cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler)
201+
cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler stats)
202202
cc_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info place)
203203
endif()
204204

paddle/phi/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ add_subdirectory(tools)
2323
add_subdirectory(tests)
2424

2525
# make an unity target for compile deps
26-
set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor)
26+
set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor api_scalar)
2727
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
2828
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
2929

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_conte
164164
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
165165
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
166166
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform)
167-
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
167+
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform tensor_copy)
168168

169169
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
170170
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils)
@@ -173,3 +173,5 @@ cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw p
173173
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
174174
cc_library(strings_api SRCS ${strings_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils)
175175
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api strings_api)
176+
cc_library(tensor_copy SRCS tensor_copy.cc DEPS phi_tensor_raw copy_kernel kernel_dispatch api_gen_utils)
177+
cc_library(api_scalar SRCS scalar.cc DEPS tensor_copy)

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/phi/api/lib/api_gen_utils.h"
1818
#include "paddle/phi/api/lib/data_transform.h"
1919
#include "paddle/phi/api/lib/kernel_dispatch.h"
20+
#include "paddle/phi/api/lib/tensor_copy.h"
2021
#include "paddle/phi/api/lib/utils/storage.h"
2122
#include "paddle/phi/common/type_traits.h"
2223
#include "paddle/phi/core/compat/convert_utils.h"
@@ -243,35 +244,8 @@ std::vector<std::vector<Tensor>> conv2d_grad_impl(
243244
}
244245

245246
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
246-
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
247-
kernel_key_set.backend_set =
248-
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
249-
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
250-
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
251-
"copy", kernel_key);
252-
253-
VLOG(6) << "copy API kernel key: " << kernel_key;
254-
VLOG(6) << "copy API kernel: " << kernel;
255-
256-
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
257-
258-
auto dense_x = TensorToDenseTensor(x);
259-
260247
Tensor out;
261-
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
262-
phi::MetaTensor meta_out(kernel_out);
263-
phi::UnchangedInferMeta(*dense_x, &meta_out);
264-
265-
using kernel_signature = void (*)(const platform::DeviceContext&,
266-
const phi::DenseTensor&,
267-
phi::Place,
268-
bool,
269-
phi::DenseTensor*);
270-
271-
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
272-
273-
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
274-
248+
copy(x, place, blocking, &out);
275249
return out;
276250
}
277251

paddle/phi/api/lib/scalar.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/common/scalar.h"
16+
17+
#include "paddle/phi/api/lib/tensor_copy.h"
18+
#include "paddle/phi/common/place.h"
19+
#include "paddle/phi/core/enforce.h"
20+
21+
namespace paddle {
22+
namespace experimental {
23+
24+
// Specialization of the ScalarBase constructor for the API-level Tensor type.
// Builds a Scalar from a tensor that holds exactly one element:
//   - the scalar's dtype_ is taken from tensor_in.dtype();
//   - a GPU tensor is first copied to CPU, then read;
//   - a CPU tensor is read directly;
//   - any other allocation type raises phi::errors::Unimplemented.
// Raises phi::errors::InvalidArgument when tensor_in.numel() != 1.
template <>
25+
ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
26+
: dtype_(tensor_in.dtype()) { // NOLINT
27+
// Reject anything that is not a single-element tensor before touching data.
PADDLE_ENFORCE_EQ(tensor_in.numel(),
28+
1,
29+
phi::errors::InvalidArgument(
30+
"The Scalar only supports Tensor with 1 element, but "
31+
"now Tensor has `%d` elements",
32+
tensor_in.numel()));
33+
auto tensor_in_place = tensor_in.place().GetType();
34+
if (tensor_in_place == phi::AllocationType::GPU) {
35+
Tensor dst_tensor;
36+
// Third argument is copy()'s `blocking` flag: request a blocking copy to
// CPUPlace so dst_tensor can be read on the next line.
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
37+
// NOTE(review): GetDataFromTensor is declared outside this file; presumably
// it stores the single element into this Scalar -- confirm in scalar.h.
GetDataFromTensor(dst_tensor);
38+
} else if (tensor_in_place == phi::AllocationType::CPU) {
39+
GetDataFromTensor(tensor_in);
40+
} else {
41+
// Only GPU and CPU allocation types are handled here; every other place
// is reported as unimplemented rather than silently mis-read.
PADDLE_THROW(phi::errors::Unimplemented(
42+
"Now, it is not supported to construct Scalar using tensor that its "
43+
"Place is (%s)",
44+
tensor_in.place()));
45+
}
46+
}
47+
48+
} // namespace experimental
49+
} // namespace paddle

paddle/phi/api/lib/tensor_copy.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/api/lib/tensor_copy.h"
16+
#include "paddle/phi/api/lib/api_gen_utils.h"
17+
#include "paddle/phi/api/lib/kernel_dispatch.h"
18+
#include "paddle/phi/api/lib/utils/storage.h"
19+
#include "paddle/phi/core/compat/convert_utils.h"
20+
#include "paddle/phi/core/kernel_registry.h"
21+
#include "paddle/phi/core/meta_tensor.h"
22+
#include "paddle/phi/infermeta/unary.h"
23+
24+
namespace paddle {
25+
namespace experimental {
26+
27+
// Copies `src` into `*dst` on the target `place` by dispatching to the
// registered "copy" kernel.
//   src      - tensor to copy from; converted with TensorToDenseTensor.
//   place    - destination place; its backend is added to the kernel key set.
//   blocking - forwarded unchanged to the kernel.
//   dst      - output tensor; passed to SetKernelOutput to obtain the kernel's
//              output DenseTensor. It is not null-checked in this function.
// NOTE(review): kernel lookup goes through SelectKernelOrThrowError, so a
// missing "copy" kernel for the chosen key presumably throws -- confirm.
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) {
28+
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
29+
// Merge the destination place's backend into the candidate backends so the
// selected kernel key accounts for both the source and the target place.
kernel_key_set.backend_set =
30+
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
31+
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
32+
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
33+
"copy", kernel_key);
34+
35+
VLOG(6) << "copy API kernel key: " << kernel_key;
36+
VLOG(6) << "copy API kernel: " << kernel;
37+
38+
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
39+
40+
auto dense_x = TensorToDenseTensor(src);
41+
42+
auto kernel_out = SetKernelOutput(kernel_key.backend(), dst);
43+
// Output meta is inferred from the input via UnchangedInferMeta before the
// kernel runs.
phi::MetaTensor meta_out(kernel_out);
44+
phi::UnchangedInferMeta(*dense_x, &meta_out);
45+
46+
// Function-pointer type used to retrieve and call the variadic kernel
// function: (device context, input, target place, blocking flag, output).
using kernel_signature = void (*)(const platform::DeviceContext&,
47+
const phi::DenseTensor&,
48+
phi::Place,
49+
bool,
50+
phi::DenseTensor*);
51+
52+
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
53+
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
54+
}
55+
56+
} // namespace experimental
57+
} // namespace paddle

paddle/phi/api/lib/tensor_copy.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/api/include/tensor.h"
18+
19+
namespace paddle {
20+
namespace experimental {
21+
22+
// Copies `src` to `place` through the registered "copy" kernel and writes the
// result into `*dst`. `blocking` is forwarded to the kernel unchanged.
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst);
23+
24+
} // namespace experimental
25+
} // namespace paddle
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS
2-
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits scalar string_tensor)
2+
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits string_tensor scalar)

paddle/phi/common/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
cc_library(phi_place SRCS place.cc)
2-
cc_library(scalar SRCS scalar.cc DEPS phi_enforce)
2+
cc_library(scalar SRCS scalar.cc DEPS phi_enforce tensor)

paddle/phi/common/scalar.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,32 @@ limitations under the License. */
1414

1515
#include "paddle/phi/common/scalar.h"
1616

17+
#include "paddle/phi/common/place.h"
1718
#include "paddle/phi/core/enforce.h"
1819

20+
#include "paddle/fluid/framework/tensor_util.h"
21+
#include "paddle/fluid/platform/place.h"
1922
namespace paddle {
2023
namespace experimental {
2124

22-
// NOTE(xiongkun): why we put definition here?
23-
// test_custom_op can't include enforce.h, because enforce.h includes gflags.
24-
// so we decouple the include dependence of enforce.h by link.
25-
void ThrowTensorConvertError(int num) {
26-
PADDLE_ENFORCE_EQ(num,
25+
// The Tensor must contain exactly one element (numel() == 1)
26+
template <>
27+
ScalarBase<phi::DenseTensor>::ScalarBase(const phi::DenseTensor& tensor_in)
28+
: dtype_(tensor_in.dtype()) { // NOLINT
29+
PADDLE_ENFORCE_EQ(tensor_in.numel(),
2730
1,
2831
phi::errors::InvalidArgument(
2932
"The Scalar only supports Tensor with 1 element, but "
3033
"now Tensor has `%d` elements",
31-
num));
34+
tensor_in.numel()));
35+
auto cpu_place = phi::CPUPlace();
36+
if (!paddle::platform::is_same_place(tensor_in.place(), cpu_place)) {
37+
phi::DenseTensor tensor;
38+
framework::TensorCopySync(tensor_in, cpu_place, &tensor);
39+
GetDataFromTensor(tensor);
40+
} else {
41+
GetDataFromTensor(tensor_in);
42+
}
3243
}
3344

3445
} // namespace experimental

0 commit comments

Comments
 (0)