1+ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ http://www.apache.org/licenses/LICENSE-2.0
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License. */
14+
15+ #include " paddle/operators/strided_memcpy.h"
16+ #include " gtest/gtest.h"
17+ #include " paddle/memory/memory.h"
18+
19+ namespace paddle {
20+ namespace operators {
21+
22+ TEST (StridedMemcpy, CPUCrop) {
23+ // clang-format off
24+ int src[] = {
25+ 0 , 1 , 2 , 0 , 0 ,
26+ 0 , 3 , 4 , 0 , 0 ,
27+ 0 , 0 , 0 , 0 , 0 ,
28+ };
29+ // clang-format on
30+
31+ framework::DDim src_stride ({5 , 1 });
32+
33+ int dst[4 ];
34+ framework::DDim dst_dim ({2 , 2 });
35+ framework::DDim dst_stride ({2 , 1 });
36+
37+ platform::CPUDeviceContext ctx;
38+ StridedMemcpy<int >(ctx, src + 1 , src_stride, dst_dim, dst_stride, dst);
39+
40+ ASSERT_EQ (1 , dst[0 ]);
41+ ASSERT_EQ (2 , dst[1 ]);
42+ ASSERT_EQ (3 , dst[2 ]);
43+ ASSERT_EQ (4 , dst[3 ]);
44+ }
45+
46+ TEST (StridedMemcpy, CPUConcat) {
47+ // clang-format off
48+ int src[] = {
49+ 1 , 2 ,
50+ 3 , 4
51+ };
52+ // clang-format on
53+
54+ int dst[8 ];
55+
56+ framework::DDim src_stride ({2 , 1 });
57+ framework::DDim dst_dim ({2 , 2 });
58+ framework::DDim dst_stride ({4 , 1 });
59+ platform::CPUDeviceContext ctx;
60+
61+ StridedMemcpy<int >(ctx, src, src_stride, dst_dim, dst_stride, dst);
62+ StridedMemcpy<int >(ctx, src, src_stride, dst_dim, dst_stride, dst + 2 );
63+
64+ // clang-format off
65+ int expect_dst[] = {
66+ 1 , 2 , 1 , 2 ,
67+ 3 , 4 , 3 , 4
68+ };
69+ // clang-format on
70+ for (size_t i = 0 ; i < sizeof (expect_dst) / sizeof (int ); ++i) {
71+ ASSERT_EQ (expect_dst[i], dst[i]);
72+ }
73+ }
74+
75+ #ifndef PADDLE_ONLY_CPU
76+ TEST (StridedMemcpy, GPUCrop) {
77+ // clang-format off
78+ int src[] = {
79+ 0 , 1 , 2 , 0 , 0 ,
80+ 0 , 3 , 4 , 0 , 0 ,
81+ 0 , 0 , 0 , 0 , 0 ,
82+ };
83+ // clang-format on
84+
85+ platform::GPUPlace gpu0 (0 );
86+ platform::CPUPlace cpu;
87+
88+ int * gpu_src = reinterpret_cast <int *>(memory::Alloc (gpu0, sizeof (src)));
89+ memory::Copy (gpu0, gpu_src, cpu, src, sizeof (src));
90+
91+ framework::DDim src_stride ({5 , 1 });
92+
93+ int dst[4 ];
94+ int * gpu_dst = reinterpret_cast <int *>(memory::Alloc (gpu0, sizeof (dst)));
95+
96+ framework::DDim dst_dim ({2 , 2 });
97+ framework::DDim dst_stride ({2 , 1 });
98+
99+ platform::CUDADeviceContext ctx (gpu0);
100+ StridedMemcpy<int >(ctx, gpu_src + 1 , src_stride, dst_dim, dst_stride,
101+ gpu_dst);
102+
103+ memory::Copy (cpu, dst, gpu0, gpu_dst, sizeof (dst), ctx.stream ());
104+ ctx.Wait ();
105+
106+ ASSERT_EQ (1 , dst[0 ]);
107+ ASSERT_EQ (2 , dst[1 ]);
108+ ASSERT_EQ (3 , dst[2 ]);
109+ ASSERT_EQ (4 , dst[3 ]);
110+
111+ memory::Free (gpu0, gpu_dst);
112+ memory::Free (gpu0, gpu_src);
113+ }
114+
115+ TEST (StridedMemcpy, GPUConcat) {
116+ // clang-format off
117+ int src[] = {
118+ 1 , 2 ,
119+ 3 , 4
120+ };
121+ // clang-format on
122+
123+ platform::GPUPlace gpu0 (0 );
124+ platform::CPUPlace cpu;
125+
126+ int * gpu_src = reinterpret_cast <int *>(memory::Alloc (gpu0, sizeof (src)));
127+ memory::Copy (gpu0, gpu_src, cpu, src, sizeof (src));
128+
129+ int dst[8 ];
130+ int * gpu_dst = reinterpret_cast <int *>(memory::Alloc (gpu0, sizeof (dst)));
131+
132+ framework::DDim src_stride ({2 , 1 });
133+ framework::DDim dst_dim ({2 , 2 });
134+ framework::DDim dst_stride ({4 , 1 });
135+ platform::CUDADeviceContext ctx (gpu0);
136+
137+ StridedMemcpy<int >(ctx, gpu_src, src_stride, dst_dim, dst_stride, gpu_dst);
138+ StridedMemcpy<int >(ctx, gpu_src, src_stride, dst_dim, dst_stride,
139+ gpu_dst + 2 );
140+
141+ memory::Copy (cpu, dst, gpu0, gpu_dst, sizeof (dst), ctx.stream ());
142+ ctx.Wait ();
143+
144+ // clang-format off
145+ int expect_dst[] = {
146+ 1 , 2 , 1 , 2 ,
147+ 3 , 4 , 3 , 4
148+ };
149+ // clang-format on
150+ for (size_t i = 0 ; i < sizeof (expect_dst) / sizeof (int ); ++i) {
151+ ASSERT_EQ (expect_dst[i], dst[i]);
152+ }
153+
154+ memory::Free (gpu0, gpu_dst);
155+ memory::Free (gpu0, gpu_src);
156+ }
157+
158+ #endif
159+ } // namespace operators
160+ } // namespace paddle
0 commit comments