@@ -37,52 +37,50 @@ __global__ void __launch_bounds__(THREADS) layer_norm_NC(
         v_mean1 = ew_add(v_mean1, x);
         v_mean2 = ew_add(v_mean2, ew_sqr(x));
     }
-    float2 mean;
-    mean.x = ew_sum(v_mean1) * rcpK;
-    mean.y = ew_sum(v_mean2) * rcpK;
+    float2 stats;
+    stats.x = ew_sum(v_mean1) * rcpK;
+    stats.y = ew_sum(v_mean2) * rcpK;

     // reduce within warp
     for (int i = 16; i > 0; i >>= 1)
-    {
-        mean.x += shfl_xor(mean.x, i);
-        mean.y += shfl_xor(mean.y, i);
-    }
+        stats = ew_warp_sum(stats, i);
+
     // if using more than 1 warp, further reduce with shared memory
     if (THREADS > 32)
     {
         __shared__ float2 Share[32];

         // first thread of each warp stores to shared
         if ((tid & 31) == 0)
-            Share[tid/32] = mean;
+            Share[tid/32] = stats;

         __syncthreads();

         if (tid < 32)
         {
             // first warp loads all prior reductions
-            mean = Share[tid];
+            stats = Share[tid];

             // reduce within this first warp
             for (int i = THREADS/64; i > 0; i >>= 1)
-            {
-                mean.x += shfl_xor(mean.x, i);
-                mean.y += shfl_xor(mean.y, i);
-            }
-            // outputs final reduction to shared
-            Share[tid] = mean;
+                stats = ew_warp_sum(stats, i);
+
+            // final reduction to shared
+            Share[tid] = stats;
         }
         __syncthreads();

         // broadcast result to all threads
-        mean = Share[0];
+        stats = Share[0];
     }
     // var = avg(x**2) - avg(x)**2
     // rstd = 1/sqrt(var)
-    float rstd = rsqrtf((mean.y - ew_sqr(mean.x)) + epsilon);
+    float mean = stats.x;
+    float rstd = rsqrtf(precise_sub(stats.y, ew_sqr(mean)) + epsilon);
+
     if (tid == 0)
     {
-        Mean[n] = mean.x;
+        Mean[n] = mean;
         Rstd[n] = rstd;
     }

@@ -94,7 +92,7 @@ __global__ void __launch_bounds__(THREADS) layer_norm_NC(
         V g = load(G, k);
         V b = load(B, k);

-        V xhat = ew_mul(ew_sub(x, mean.x), rstd);
+        V xhat = ew_mul(ew_sub(x, mean), rstd);
         V y = ew_add(ew_mul(xhat, g), b);

         if (relu)
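
The hunk above folds the hand-rolled shfl_xor pair into a single ew_warp_sum helper, so one butterfly step covers both running moments at once. As a rough sketch only (the real definition lives with the other ew_* element-wise device helpers and may wrap the repo's own shfl_xor rather than the raw intrinsic), such a helper could look like this:

// Sketch only, not the library's definition: one butterfly step of a
// warp reduction that sums both fields of the float2 across lanes i apart.
__device__ __forceinline__ float2 ew_warp_sum(float2 v, int i)
{
    v.x += __shfl_xor_sync(0xffffffff, v.x, i);  // running sum of x
    v.y += __shfl_xor_sync(0xffffffff, v.y, i);  // running sum of x**2
    return v;
}

Looping i over 16, 8, 4, 2, 1 leaves every lane holding the full 32-lane sums, which is why the caller no longer needs the braces or the per-field adds.
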
@@ -513,17 +511,17 @@ __global__ void layer_norm_segmented_nc(
             #pragma unroll 1
             for (int i = thread2/64; i > 0; i >>= 1)
                 stats = ew_warp_sum(stats, i);
+
             // final reduction to shared
             Share[tid] = stats;
         }
         __syncthreads();
         stats = Share[0];
     }
-
     // var = avg(x**2) - avg(x)**2
     // rstd = 1/sqrt(var)
     float mean = stats.x;
-    float rstd = rsqrtf((stats.y - mean*mean) + epsilon);
+    float rstd = rsqrtf(precise_sub(stats.y, ew_sqr(mean)) + epsilon);
     if (tid == 0)
     {
         __stg(add_ptr_u(Mean, m), mean);
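
Both kernels now compute the variance with precise_sub(stats.y, ew_sqr(mean)) instead of a plain subtraction. The likely motivation: in E[x**2] - E[x]**2 form the compiler is free to contract the square and the subtract into an FMA, so different kernels (or forward and backward passes) could round the variance differently. A hypothetical definition in that spirit, assuming the helper exists purely to pin down the rounding (the actual one ships with the other ew_* device functions and may differ):

// Hypothetical stand-in for precise_sub: inline PTX keeps this a plain
// IEEE round-to-nearest subtract that ptxas cannot fuse into an FMA.
__device__ __forceinline__ float precise_sub(float a, float b)
{
    float r;
    asm("sub.rn.f32 %0, %1, %2;" : "=f"(r) : "f"(a), "f"(b));
    return r;
}

The + epsilon term still guards rsqrtf against the slightly negative variance this formula can produce under rounding.
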
@@ -808,17 +806,20 @@ bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs,
     const float* b,
     const float* mean,
     const float* rstd,
-    float epsilon, uint N, uint S, uint K, float rcpK, int relu)
+    float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics)
 {
     uint gridK = CEIL_DIV(K, 32);
     uint gridN = 1;
-    uint blocksK = gridK * S;
-    while (gridN < (N>>3) && gridN * blocksK < 32*SMs) gridN += 1;
-    if (gridN * blocksK > 32*SMs && gridN > 1) gridN -= 1;
-    if (gridN > 1)
+    if (atomics)
     {
-        cuMemsetD32Async((CUdeviceptr)dg, 0, S*K, stream);
-        cuMemsetD32Async((CUdeviceptr)db, 0, S*K, stream);
+        uint blocksK = gridK * S;
+        while (gridN < (N>>3) && gridN * blocksK < 32*SMs) gridN += 1;
+        if (gridN * blocksK > 32*SMs && gridN > 1) gridN -= 1;
+        if (gridN > 1)
+        {
+            cuMemsetD32Async((CUdeviceptr)dg, 0, S*K, stream);
+            cuMemsetD32Async((CUdeviceptr)db, 0, S*K, stream);
+        }
     }
     layer_norm_segmented_dg_db_nc<T><<<dim3(gridN,gridK,S),32,0,stream>>>(dg, db, dy, x, g, b, mean, rstd, N, S*K, S*K*gridN, K, relu);

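
In the wrapper above, splitting the reduction over N (gridN > 1) is now gated on the new atomics flag: when the caller disallows atomics (e.g. for deterministic gradients), gridN stays 1 and each dg/db column is written by exactly one block. With gridN > 1, the partial sums from different blocks presumably combine via atomicAdd into the dg/db buffers, which is why the cuMemsetD32Async zeroing only runs on that path. An illustrative device-side pattern of the choice this enables (names invented for the sketch, not the kernel's actual code):

// Illustrative only: accumulate-vs-store on the output column k.
// Several blocks per column combine with atomicAdd into the zero-filled
// buffer; a single owning block can use a plain store instead.
__device__ void store_dg_db(float* dg, float* db, uint k,
                            float dg_part, float db_part, uint gridN)
{
    if (gridN > 1)
    {
        atomicAdd(dg + k, dg_part);
        atomicAdd(db + k, db_part);
    }
    else
    {
        dg[k] = dg_part;
        db[k] = db_part;
    }
}
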
@@ -869,9 +870,9 @@ bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs,
     }
     return true; // TODO
 }
-template bool LayerNormSegmentedBackward_NC<float,float4>(CUstream stream, int SMs, float* dx, float* dg, float* db, const float* dy, const float* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
-template bool LayerNormSegmentedBackward_NC<ehalf,ehalf4>(CUstream stream, int SMs, ehalf* dx, float* dg, float* db, const ehalf* dy, const ehalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
-template bool LayerNormSegmentedBackward_NC<bhalf,bhalf4>(CUstream stream, int SMs, bhalf* dx, float* dg, float* db, const bhalf* dy, const bhalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
+template bool LayerNormSegmentedBackward_NC<float,float4>(CUstream stream, int SMs, float* dx, float* dg, float* db, const float* dy, const float* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
+template bool LayerNormSegmentedBackward_NC<ehalf,ehalf4>(CUstream stream, int SMs, ehalf* dx, float* dg, float* db, const ehalf* dy, const ehalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
+template bool LayerNormSegmentedBackward_NC<bhalf,bhalf4>(CUstream stream, int SMs, bhalf* dx, float* dg, float* db, const bhalf* dy, const bhalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);


 #endif // GOOGLE_CUDA