Skip to content

Commit 2e38044

Browse files
authored
Merge pull request #4205 from reyoung/feature/tensor_copy
Add StridedCopy method
2 parents 4400284 + 07915c9 commit 2e38044

File tree

4 files changed

+299
-0
lines changed

4 files changed

+299
-0
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,4 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
9696
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
9797
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
9898
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
99+
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/ddim.h"
17+
#include "paddle/memory/memcpy.h"
18+
#include "paddle/platform/device_context.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace detail {
23+
24+
template <typename T, int Rank>
25+
struct StridedMemcpyFunctor;
26+
27+
template <typename T>
28+
struct StridedMemcpyFunctor<T, 1> {
29+
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
30+
framework::Dim<1> src_stride, framework::Dim<1> dst_dim,
31+
framework::Dim<1> dst_stride, T* dst) const {
32+
auto place = dev_ctx.GetPlace();
33+
if (platform::is_cpu_place(place)) {
34+
auto& cpu_place = boost::get<platform::CPUPlace>(place);
35+
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
36+
} else {
37+
#ifndef PADDLE_ONLY_CPU
38+
auto& gpu_place = boost::get<platform::GPUPlace>(place);
39+
auto& cuda_ctx =
40+
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
41+
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
42+
cuda_ctx.stream());
43+
#else
44+
PADDLE_THROW("Paddle is not compiled with GPU");
45+
#endif
46+
}
47+
}
48+
};
49+
50+
template <typename T, int Rank>
51+
struct StridedMemcpyFunctor {
52+
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
53+
framework::Dim<Rank> src_stride, framework::Dim<Rank> dst_dim,
54+
framework::Dim<Rank> dst_stride, T* dst) const {
55+
for (int64_t i = 0; i < dst_dim.head; ++i) {
56+
StridedMemcpyFunctor<T, Rank - 1> func;
57+
func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst);
58+
src += src_stride.head;
59+
dst += dst_stride.head;
60+
}
61+
}
62+
};
63+
64+
template <typename T>
65+
struct StridedCopyDimVisitor : public boost::static_visitor<void> {
66+
StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src,
67+
const framework::DDim& src_stride,
68+
const framework::DDim& dst_stride, T* dst)
69+
: dev_ctx_(dev_ctx),
70+
src_(src),
71+
src_stride_(src_stride),
72+
dst_stride_(dst_stride),
73+
dst_(dst) {}
74+
75+
template <typename Dim>
76+
void operator()(Dim dst_dim) const {
77+
Dim src_stride = boost::get<Dim>(src_stride_);
78+
Dim dst_stride = boost::get<Dim>(dst_stride_);
79+
constexpr int dim = Dim::dimensions;
80+
StridedMemcpyFunctor<T, dim> functor;
81+
functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_);
82+
}
83+
84+
const platform::DeviceContext& dev_ctx_;
85+
const T* src_;
86+
const framework::DDim& src_stride_;
87+
const framework::DDim& dst_stride_;
88+
T* dst_;
89+
};
90+
91+
} // namespace detail
92+
} // namespace operators
93+
} // namespace paddle

paddle/operators/strided_memcpy.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/operators/detail/strided_memcpy.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
// Strided memory copy from src to dst.
22+
//
23+
// The src and dst should be both on dev_ctx.GetPlace(), otherwise, there will
24+
// be a segment fault.
25+
//
26+
// The stride of an array (also referred to as increment, pitch or step size) is
27+
// the number of locations in memory between beginnings of successive array
28+
// elements
29+
//
30+
// For example, for tensor like [1, 3, 300, 300]. If there is no padding, the
31+
// stride is [270000, 90000, 300, 1].
32+
//
33+
// NOTE: When use GPU, the memcpy is async. To sync memcpy, please invoke
34+
// `dev_ctx.Wait()`.
35+
template <typename T>
36+
inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
37+
const framework::DDim& src_stride,
38+
const framework::DDim& dst_dim,
39+
const framework::DDim& dst_stride, T* dst) {
40+
using namespace detail;
41+
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
42+
boost::apply_visitor(func, dst_dim);
43+
}
44+
} // namespace operators
45+
} // namespace paddle
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/strided_memcpy.h"
16+
#include "gtest/gtest.h"
17+
#include "paddle/memory/memory.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
// Crop the 2x2 interior block (offset 1) out of a zero-padded 3x5 source.
TEST(StridedMemcpy, CPUCrop) {
  // clang-format off
  int src[] = {
      0, 1, 2, 0, 0,
      0, 3, 4, 0, 0,
      0, 0, 0, 0, 0,
  };
  // clang-format on

  framework::DDim src_stride({5, 1});

  int dst[4];
  framework::DDim dst_dim({2, 2});
  framework::DDim dst_stride({2, 1});

  platform::CPUDeviceContext ctx;
  StridedMemcpy<int>(ctx, src + 1, src_stride, dst_dim, dst_stride, dst);

  const int expect_dst[] = {1, 2, 3, 4};
  for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
    ASSERT_EQ(expect_dst[i], dst[i]);
  }
}
45+
46+
// Concatenate the same 2x2 source twice along the inner axis of a 2x4 dst.
TEST(StridedMemcpy, CPUConcat) {
  // clang-format off
  int src[] = {
      1, 2,
      3, 4
  };
  // clang-format on

  int dst[8];

  framework::DDim src_stride({2, 1});
  framework::DDim dst_dim({2, 2});
  framework::DDim dst_stride({4, 1});
  platform::CPUDeviceContext ctx;

  // Two strided writes into disjoint halves of each dst row.
  StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst);
  StridedMemcpy<int>(ctx, src, src_stride, dst_dim, dst_stride, dst + 2);

  // clang-format off
  const int expect_dst[] = {
      1, 2, 1, 2,
      3, 4, 3, 4
  };
  // clang-format on
  for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
    ASSERT_EQ(expect_dst[i], dst[i]);
  }
}
74+
75+
#ifndef PADDLE_ONLY_CPU
76+
// Same crop as CPUCrop, but staged through GPU memory on device 0.
TEST(StridedMemcpy, GPUCrop) {
  // clang-format off
  int src[] = {
      0, 1, 2, 0, 0,
      0, 3, 4, 0, 0,
      0, 0, 0, 0, 0,
  };
  // clang-format on

  platform::GPUPlace gpu0(0);
  platform::CPUPlace cpu;

  // Upload the source to the device.
  int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
  memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));

  int dst[4];
  int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));

  framework::DDim src_stride({5, 1});
  framework::DDim dst_dim({2, 2});
  framework::DDim dst_stride({2, 1});

  platform::CUDADeviceContext ctx(gpu0);
  StridedMemcpy<int>(ctx, gpu_src + 1, src_stride, dst_dim, dst_stride,
                     gpu_dst);

  // Download the result; the copies are async, so sync before checking.
  memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
  ctx.Wait();

  const int expect_dst[] = {1, 2, 3, 4};
  for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
    ASSERT_EQ(expect_dst[i], dst[i]);
  }

  memory::Free(gpu0, gpu_dst);
  memory::Free(gpu0, gpu_src);
}
114+
115+
// Same concat as CPUConcat, but staged through GPU memory on device 0.
TEST(StridedMemcpy, GPUConcat) {
  // clang-format off
  int src[] = {
      1, 2,
      3, 4
  };
  // clang-format on

  platform::GPUPlace gpu0(0);
  platform::CPUPlace cpu;

  // Upload the source to the device.
  int* gpu_src = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(src)));
  memory::Copy(gpu0, gpu_src, cpu, src, sizeof(src));

  int dst[8];
  int* gpu_dst = reinterpret_cast<int*>(memory::Alloc(gpu0, sizeof(dst)));

  framework::DDim src_stride({2, 1});
  framework::DDim dst_dim({2, 2});
  framework::DDim dst_stride({4, 1});
  platform::CUDADeviceContext ctx(gpu0);

  // Two strided writes into disjoint halves of each dst row.
  StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst);
  StridedMemcpy<int>(ctx, gpu_src, src_stride, dst_dim, dst_stride,
                     gpu_dst + 2);

  // Download the result; the copies are async, so sync before checking.
  memory::Copy(cpu, dst, gpu0, gpu_dst, sizeof(dst), ctx.stream());
  ctx.Wait();

  // clang-format off
  const int expect_dst[] = {
      1, 2, 1, 2,
      3, 4, 3, 4
  };
  // clang-format on
  for (size_t i = 0; i < sizeof(expect_dst) / sizeof(int); ++i) {
    ASSERT_EQ(expect_dst[i], dst[i]);
  }

  memory::Free(gpu0, gpu_dst);
  memory::Free(gpu0, gpu_src);
}
157+
158+
#endif
159+
} // namespace operators
160+
} // namespace paddle

0 commit comments

Comments
 (0)