@@ -92,6 +92,12 @@ class ReduceMin {
 };
 static ReduceMin reduce_min;
 
+__global__ void CudaMemsetAsync(int* dest, int value, size_t size) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid * sizeof(int) >= size) return;
+  dest[tid] = value;
+}
+
 template <typename tensor_t,
           typename index_t,
           typename func_t,
@@ -112,13 +118,6 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data,
                                        int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -267,16 +266,6 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      shared_mem[i] = 0;  // thread_id
-      if (include_self)
-        shared_mem[numel_data + i] = 1;  // reduce size
-      else
-        shared_mem[numel_data + i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;  // The i, j, k here is the index of the 3 layers loop
                     // squeezed from the N layers loop.
   /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */
@@ -384,6 +373,7 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
       ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -405,6 +395,14 @@ struct gpu_gather_scatter_functor {
       int* shared_mem;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      cudaMemsetAsync(shared_mem, 0, sizeof(int) * self_size, stream);
+      if (include_self) {
+        int64_t grid_memset = (self_size * 2 + block - 1) / block;
+        CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+            shared_mem, 1, shared_mem_size);
+      } else {
+        cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+      }
       ScatterMeanGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
           <<<grid, block, 0, stream>>>(self_data,
                                        dim,
@@ -429,6 +427,9 @@ struct gpu_gather_scatter_functor {
       shared_mem_size = sizeof(int) * self_size;
       cudaMallocAsync(
           reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, index_size + 1, shared_mem_size);
     }
     GatherScatterGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
         <<<grid, block, shared_mem_size, stream>>>(self_data,
@@ -640,12 +641,6 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data,
                                              int* thread_ids) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -682,13 +677,6 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data,
                                                 int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -762,6 +750,7 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterMulInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -781,6 +770,9 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    int64_t grid_memset = (grad_size + block - 1) / block;
+    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+        shared_mem, 1, shared_mem_size);
     ScatterMinMaxInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -816,13 +808,6 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
-  if (tid == 0) {
-    for (int i = 0; i < numel_grad; i++) {
-      shared_mem[i] = 0;               // thread_ids
-      shared_mem[numel_grad + i] = 1;  // number of elements
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -879,6 +864,10 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, sizeof(int) * grad_size, stream);
+    int64_t grid_memset = (grad_size + block - 1) / block;
+    CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+        shared_mem + grad_size, 1, sizeof(int) * grad_size);
     ScatterMeanInputGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -910,12 +899,6 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_data; i++) {
-      thread_ids[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -975,6 +958,7 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self,
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
     ScatterValueGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -1005,20 +989,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data,
                                               int64_t outer_dim_size_grad,
                                               int64_t numel,
                                               int64_t numel_self,
-                                              bool include_self,
                                               int* shared_mem) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid >= numel) return;
 
-  if (tid == 0) {
-    for (int i = 0; i < numel_self; i++) {
-      if (include_self)
-        shared_mem[i] = 1;  // number of elements
-      else
-        shared_mem[i] = 0;
-    }
-  }
-  __syncthreads();
   int64_t i, j, k;
   i = tid / (select_dim_size * outer_dim_size);
   int64_t remind = tid % (select_dim_size * outer_dim_size);
@@ -1114,6 +1088,13 @@ void gpu_scatter_add_mean_value_grad_kernel(
     int* shared_mem;
     cudaMallocAsync(
         reinterpret_cast<void**>(&shared_mem), shared_mem_size, stream);
+    if (include_self) {
+      int64_t grid_memset = (self_size + block - 1) / block;
+      CudaMemsetAsync<<<grid_memset, block, 0, stream>>>(
+          shared_mem, 1, shared_mem_size);
+    } else {
+      cudaMemsetAsync(shared_mem, 0, shared_mem_size, stream);
+    }
     ScatterMeanValueGradGPUKernel<tensor_t, index_t>
         <<<grid, block, 0, stream>>>(grad_data,
                                      dim,
@@ -1127,7 +1108,6 @@ void gpu_scatter_add_mean_value_grad_kernel(
                                      outer_dim_size_grad,
                                      index_size,
                                      self_size,
-                                     include_self,
                                      shared_mem);
     cudaFreeAsync(shared_mem, stream);
   } else if (reduce == "add") {
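Note (not part of the patch): the diff above replaces per-kernel initialization loops (guarded by `if (tid == 0)` plus `__syncthreads()`, which only synchronizes threads within one block, so the writes were not guaranteed to be visible to other blocks) with stream-ordered initialization on the host side. `cudaMemsetAsync` fills bytes, so it is only useful for zero-filling; writing `1` or `index_size + 1` into every `int` slot is done with the new element-wise `CudaMemsetAsync` kernel. The standalone sketch below illustrates that pattern; the names `FillInt`, `n`, and `block` and the sizes are illustrative assumptions, not taken from the patch.

// Standalone CUDA sketch of the fill pattern (assumed names, not from the patch).
#include <cstdio>
#include <cuda_runtime.h>

// Element-wise fill: writes `value` into every int slot of a `size`-byte buffer,
// mirroring the CudaMemsetAsync kernel added by the patch.
__global__ void FillInt(int* dest, int value, size_t size) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid * sizeof(int) >= size) return;  // guard the tail of the last block
  dest[tid] = value;
}

int main() {
  const int n = 1024;                  // illustrative element count
  const size_t bytes = n * sizeof(int);
  const int block = 256;               // illustrative block size

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* buf = nullptr;
  cudaMallocAsync(reinterpret_cast<void**>(&buf), bytes, stream);

  // Byte-wise runtime memset is enough for zero-filling ...
  cudaMemsetAsync(buf, 0, bytes, stream);

  // ... but filling every int with a non-zero value needs the kernel.
  const int grid = (n + block - 1) / block;
  FillInt<<<grid, block, 0, stream>>>(buf, 1, bytes);

  int host[4] = {0, 0, 0, 0};
  cudaMemcpyAsync(host, buf, sizeof(host), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  printf("%d %d %d %d\n", host[0], host[1], host[2], host[3]);  // expect: 1 1 1 1

  cudaFreeAsync(buf, stream);
  cudaStreamDestroy(stream);
  return 0;
}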