Commit 10a11d8

Improve the performance of put_along_axis (#60618)

* fix bug of put_along_axis
* improve performance of put_along_axis
1 parent b578350 commit 10a11d8
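
The whole change applies one pattern: each scatter kernel used to have thread 0 initialize the counter buffer in a serial loop and then call __syncthreads(), which only synchronizes threads within a single block; the initialization is now issued on the stream before the kernel launch, with cudaMemsetAsync for zero fills and the new CudaMemsetAsync kernel for non-zero int fills. The sketch below is a minimal self-contained illustration of that pattern, not the Paddle code itself; FillIntAsync, buf and n are made-up names, and cudaMallocAsync/cudaFreeAsync assume a CUDA 11.2+ toolkit with stream-ordered allocation.

// fill_pattern_sketch.cu -- illustrative only, not Paddle code.
#include <cuda_runtime.h>

// One thread writes one int; the bound check is against the byte size,
// mirroring the CudaMemsetAsync kernel added in this commit.
__global__ void FillIntAsync(int* dest, int value, size_t size) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid * sizeof(int) >= size) return;
  dest[tid] = value;
}

int main() {
  const int64_t n = 1 << 20;            // illustrative element count
  const size_t bytes = n * sizeof(int);
  const int block = 512;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* buf = nullptr;
  cudaMallocAsync(reinterpret_cast<void**>(&buf), bytes, stream);

  // Zero fill: cudaMemsetAsync is enough, since it sets every byte to 0.
  cudaMemsetAsync(buf, 0, bytes, stream);

  // Non-zero per-int fill (e.g. 1): memset works byte-wise, so launch the
  // fill kernel on the same stream instead.
  int64_t grid = (n + block - 1) / block;
  FillIntAsync<<<grid, block, 0, stream>>>(buf, 1, bytes);

  cudaFreeAsync(buf, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}

Rounding the grid up with (n + block - 1) / block and bound-checking inside the kernel keeps the launch valid whether or not n is a multiple of the block size.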

File tree

1 file changed (+34, -54 lines)

paddle/phi/kernels/funcs/gather_scatter_functor.cu

Lines changed: 34 additions & 54 deletions
@@ -92,6 +92,12 @@ class ReduceMin {
 };
 static ReduceMin reduce_min;
 
+__global__ void CudaMemsetAsync(int* dest, int value, size_t size) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid * sizeof(int) >= size) return;
+  dest[tid] = value;
+}
+
 template <typename tensor_t,
           typename index_t,
           typename func_t,
@@ -112,13 +118,6 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data,
                                        int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -267,16 +266,6 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      shared_mem[i] = 0;  // thread_id
-      if (include_self)
-        shared_mem[numel_data + i] = 1;  // reduce size
-      else
-        shared_mem[numel_data + i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -384,6 +373,7 @@ struct gpu_gather_scatter_functor {
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
         <<<grid, block, 0, stream>>>(self_data,
                                      dim,
@@ -405,6 +395,14 @@ struct gpu_gather_scatter_functor {
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, sizeof(int) * self_size, stream);
+    if (include_self) {
+      int64_t grid_memset = (self_size * 2 + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, 1, shared_mem_size);
+    } else {
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+    }
     ScatterMeanGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
         <<<grid, block, 0, stream>>>(self_data,
                                      dim,
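
The 0/1 split above exists because cudaMemsetAsync sets bytes, so it can only produce per-int values whose four bytes are identical; zeroing works directly, while writing the integer 1 (and, in a later hunk, index_size + 1) into each counter needs the dedicated fill kernel. A small self-contained check of the difference, with FillInt standing in for the kernel added by this commit (the name and sizes are illustrative):

// memset_vs_fill.cu -- illustrative check of byte-wise memset vs. per-int fill.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void FillInt(int* dest, int value, size_t size) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid * sizeof(int) >= size) return;
  dest[tid] = value;
}

int main() {
  int* buf = nullptr;
  int host = 0;
  cudaMalloc(&buf, sizeof(int));

  cudaMemset(buf, 1, sizeof(int));             // sets every byte to 1
  cudaMemcpy(&host, buf, sizeof(int), cudaMemcpyDeviceToHost);
  printf("memset(1) -> %d\n", host);           // prints 16843009 (0x01010101)

  FillInt<<<1, 1>>>(buf, 1, sizeof(int));      // writes the int value 1
  cudaMemcpy(&host, buf, sizeof(int), cudaMemcpyDeviceToHost);
  printf("fill(1)   -> %d\n", host);           // prints 1

  cudaFree(buf);
  return 0;
}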
@@ -429,6 +427,9 @@ struct gpu_gather_scatter_functor {
       shared_mem_size = sizeof(int) * self_size;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, index_size + 1, shared_mem_size);
     }
     GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
         <<<grid, block, shared_mem_size, stream>>>(self_data,
@@ -640,12 +641,6 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
                                              int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -682,13 +677,6 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
                                                 int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -762,6 +750,7 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterMulInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -781,6 +770,9 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    int64_t grid_memset = (grad_size + block - 1) / block;
+    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+        shared_mem, 1, shared_mem_size);
     ScatterMinMaxInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -816,13 +808,6 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 0;               // thread_ids
-      shared_mem[numel_grad + i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -879,6 +864,10 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, sizeof(int) * grad_size, stream);
+  int64_t grid_memset = (grad_size + block - 1) / block;
+  CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+      shared_mem + grad_size, 1, sizeof(int) * grad_size);
   ScatterMeanInputGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -910,12 +899,6 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -975,6 +958,7 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self,
   int* shared_mem;
   cudaMallocAsync(
       reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+  cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
   ScatterValueGradGPUKernel<tensor_t, index_t>
       <<<grid, block, 0, stream>>>(grad_data,
                                    dim,
@@ -1005,20 +989,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
                                               int64_t outer_dim_size_grad,
                                               int64_t numel,
                                               int64_t numel_self,
-                                              bool include_self,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_self; i++) {
-      if (include_self)
-        shared_mem[i] = 1;  // number of elements
-      else
-        shared_mem[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -1114,6 +1088,13 @@ void gpu_scatter_add_mean_value_grad_kernel(
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    if (include_self) {
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, 1, shared_mem_size);
+    } else {
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+    }
     ScatterMeanValueGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -1127,7 +1108,6 @@ void gpu_scatter_add_mean_value_grad_kernel(
                                      outer_dim_size_grad,
                                      index_size,
                                      self_size,
-                                     include_self,
                                      shared_mem);
     cudaFreeAsync(shared_mem, stream);
   } else if (reduce == "add") {
