Skip to content

Commit 0a862fd

Browse files
refine the precious of linspace Op using half way (PaddlePaddle#27452)
1 parent fda54c0 commit 0a862fd

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

paddle/fluid/operators/linspace_op.cu

+26-15
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ namespace operators {
2323
using Tensor = framework::Tensor;
2424

2525
template <typename T>
26-
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
27-
CUDA_KERNEL_LOOP(index, size) {
28-
out[index] = static_cast<T>(start + step * index);
26+
__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
27+
T* out) {
28+
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
29+
30+
for (; index < size; index += blockDim.x * gridDim.x) {
31+
if (index < size / 2) {
32+
out[index] = static_cast<T>(start + step * index);
33+
} else {
34+
out[index] = static_cast<T>(stop - step * (size - index - 1));
35+
}
2936
}
3037
}
3138

@@ -55,13 +62,15 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
5562
framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
5663
framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
5764

58-
framework::Tensor n;
59-
framework::TensorCopy(start_t, platform::CPUPlace(), &n);
60-
T start = n.data<T>()[0];
61-
framework::TensorCopy(stop_t, platform::CPUPlace(), &n);
62-
T stop = n.data<T>()[0];
63-
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
64-
int32_t num = n.data<int32_t>()[0];
65+
framework::Tensor n_start;
66+
framework::Tensor n_stop;
67+
framework::Tensor n_num;
68+
framework::TensorCopy(start_t, platform::CPUPlace(), &n_start);
69+
T start = n_start.data<T>()[0];
70+
framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop);
71+
T stop = n_stop.data<T>()[0];
72+
framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num);
73+
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
6574

6675
PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
6776
"The num of linspace op should be larger "
@@ -72,14 +81,16 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
7281
T* out_data = out->mutable_data<T>(context.GetPlace());
7382

7483
double step = 0;
75-
if (num != 1) {
76-
step = (static_cast<double>(stop - start)) / (num - 1);
77-
}
78-
7984
auto stream = context.cuda_device_context().stream();
8085
int block = 512;
8186
int grid = (num + block - 1) / block;
82-
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
87+
if (num != 1) {
88+
step = (static_cast<double>(stop - start)) / (num - 1);
89+
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
90+
out_data);
91+
} else {
92+
LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start, out_data);
93+
}
8394
}
8495
};
8596

paddle/fluid/operators/linspace_op.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,15 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
5656
T* out_data = out->mutable_data<T>(context.GetPlace());
5757

5858
if (num > 1) {
59+
// step should be of double type for all types
5960
double step = (static_cast<double>(stop - start)) / (num - 1);
61+
int half_num = num / 2;
6062
for (int i = 0; i < num; ++i) {
61-
out_data[i] = static_cast<T>(start + step * i);
63+
if (i < half_num) {
64+
out_data[i] = static_cast<T>(start + step * i);
65+
} else {
66+
out_data[i] = static_cast<T>(stop - step * (num - i - 1));
67+
}
6268
}
6369
} else {
6470
out_data[0] = static_cast<T>(start);

python/paddle/fluid/layers/tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,7 @@ def linspace(start, stop, num, dtype=None, name=None):
14241424
stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \
14251425
or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
14261426
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
1427-
or a Tensor of shape [1] with data type int32 or int64.
1427+
or a Tensor of shape [1] with data type int32.
14281428
dtype(np.dtype|str, optional): The data type of output tensor, it could be
14291429
int32, int64, float32 and float64. Default: if None, the data type is float32.
14301430
name(str, optional): Normally there is no need for user to set this property.

0 commit comments

Comments
 (0)