@@ -23,9 +23,16 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
-  CUDA_KERNEL_LOOP(index, size) {
-    out[index] = static_cast<T>(start + step * index);
+__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
+                               T* out) {
+  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (; index < size; index += blockDim.x * gridDim.x) {
+    if (index < size / 2) {
+      out[index] = static_cast<T>(start + step * index);
+    } else {
+      out[index] = static_cast<T>(stop - step * (size - index - 1));
+    }
   }
 }
 
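Note: the new kernel replaces CUDA_KERNEL_LOOP with an explicit grid-stride loop and fills the lower half forward from `start` and the upper half backward from `stop`, presumably so the last element lands exactly on `stop` rather than on the rounded value of `start + step * (size - 1)`. A minimal host-side sketch of the same indexing scheme (illustrative only; the function name is ours, not part of the patch):

// Host-side reference of the two-sided fill used by the new LinspaceKernel
// (mirrors the num != 1 path only; illustrative, not part of the patch).
// Filling the upper half backward from `stop` makes the final element equal
// to `stop` exactly, whereas `start + step * (size - 1)` can land slightly
// off in floating point.
#include <cstdint>
#include <vector>

template <typename T>
std::vector<T> LinspaceReference(T start, T stop, int64_t size) {
  std::vector<T> out(size);
  double step = (static_cast<double>(stop - start)) / (size - 1);
  for (int64_t index = 0; index < size; ++index) {
    if (index < size / 2) {
      out[index] = static_cast<T>(start + step * index);              // forward from start
    } else {
      out[index] = static_cast<T>(stop - step * (size - index - 1));  // backward from stop
    }
  }
  return out;
}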
@@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
     framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
     framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
 
-    framework::Tensor n;
-    framework::TensorCopy(start_t, platform::CPUPlace(), &n);
-    T start = n.data<T>()[0];
-    framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
-    T stop = n.data<T>()[0];
-    framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
-    int32_t num = n.data<int32_t>()[0];
+    framework::Tensor n_start;
+    framework::Tensor n_stop;
+    framework::Tensor n_num;
+    framework::TensorCopy(start_t, platform::CPUPlace(), &n_start);
+    T start = n_start.data<T>()[0];
+    framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop);
+    T stop = n_stop.data<T>()[0];
+    framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num);
+    int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
 
     PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
                                   "The num of linspace op should be larger "
@@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
 
     double step = 0;
-    if (num != 1) {
-      step = (static_cast<double>(stop - start)) / (num - 1);
-    }
-
     auto stream = context.cuda_device_context().stream();
     int block = 512;
     int grid = (num + block - 1) / block;
-    LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
+    if (num != 1) {
+      step = (static_cast<double>(stop - start)) / (num - 1);
+      LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
+                                                    out_data);
+    } else {
+      LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start, out_data);
+    }
   }
 };
 
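Note: the `num == 1` branch dispatches `LinspaceSpecialKernel`, whose definition lies outside the hunks shown here. A minimal sketch of what a single-element kernel with the called signature could look like (assumed, not taken from the patch; the name below is deliberately different):

// Assumed sketch only: the real LinspaceSpecialKernel is defined outside this
// diff. With num == 1 there is no step to compute, so a kernel taking
// (T start, T* out) only needs to write `start` into the single output slot.
template <typename T>
__global__ void LinspaceSpecialKernelSketch(T start, T* out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    out[0] = static_cast<T>(start);
  }
}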