Skip to content

Commit 3ead98d

Browse files
committed
updated tests
slightly more accurate layer_norm, added deterministic mode
1 parent c2e5def commit 3ead98d

20 files changed

+1119
-468
lines changed

blocksparse/lstm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def fused_lstm_gates_grad(op, ec, eh):
6262
# compute bias grad
6363
#db = ew_db_dzb_op(dh, op.inputs[2], op=BIASADD_OP)
6464
# db = bias_grad_op(dh, op.inputs[2])
65-
db, _ = bias_grad_op(dh, op.inputs[2])
65+
db, _ = bias_grad_op(dh, op.inputs[2], axis=1)
6666

6767
return dc, dh, db
6868

blocksparse/norms.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222

23-
def layer_norm(x, g, b, axis=1, segments=1, epsilon=1e-6, relu=False, bench=0, use_tf=False):
23+
def layer_norm(x, g, b, axis=1, segments=1, epsilon=1e-6, relu=False, atomics=True, bench=0, use_tf=False):
2424

2525
dev = g.op.device.lower()
2626
if use_tf or not dev or "cpu" in dev:
@@ -51,7 +51,7 @@ def layer_norm(x, g, b, axis=1, segments=1, epsilon=1e-6, relu=False, bench=0, u
5151
if relu:
5252
y = tf.nn.relu(y)
5353
else:
54-
y, m, v, _, _ = layer_norm_op(x, g, b, S=segments, axis=axis, epsilon=epsilon, relu=relu, bench=bench)
54+
y, m, v, _, _ = layer_norm_op(x, g, b, S=segments, axis=axis, epsilon=epsilon, relu=relu, atomics=atomics, bench=bench)
5555

5656
return y
5757

@@ -61,8 +61,9 @@ def layer_norm_grad(op, dy, mean, rstd, p1, p2):
6161
epsilon = op.get_attr("epsilon")
6262
relu = op.get_attr("relu")
6363
axis = op.get_attr("axis")
64+
atomics = op.get_attr("atomics")
6465
bench = op.get_attr("bench")
65-
dx, dg, db, _, _ = layer_norm_grad_op(dy, op.inputs[0], op.inputs[1], op.inputs[2], op.outputs[1], op.outputs[2], S=S, axis=axis, epsilon=epsilon, relu=relu, bench=bench)
66+
dx, dg, db, _, _ = layer_norm_grad_op(dy, op.inputs[0], op.inputs[1], op.inputs[2], op.outputs[1], op.outputs[2], S=S, axis=axis, epsilon=epsilon, relu=relu, atomics=atomics, bench=bench)
6667
return dx, dg, db
6768

6869
def batch_norm_inference(x, g, b, m, v, epsilon=1e-6):
@@ -168,8 +169,8 @@ def layer_norm_grad_test(dy, x, g, b, axis=1, segments=1, epsilon=1e-6, relu=Fal
168169

169170
#print("x:%.2f, mean:%.2f, rstd:%.2f, xhat:%.2f, dy:%.2f\n" % (x[0,0], mean[0,0], xstdr[0,0], xhat[0,0], dy[0,0]));
170171

171-
dg[seg] = np.sum(dy[seg] * xhat, axis=1-axis)
172-
db[seg] = np.sum(dy[seg], axis=1-axis)
172+
dg[seg] = np.sum(dy[seg] * xhat, axis=1-axis, keepdims=True)
173+
db[seg] = np.sum(dy[seg], axis=1-axis, keepdims=True)
173174
dy[seg] = dy[seg] * g[seg]
174175

175176
sum1 = np.sum(xhat * dy[seg], axis=axis, keepdims=True)

src/bst_op.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class BlocksparseTransformerOp : public OpKernel {
223223

224224
if (a.dtype() == DT_HALF)
225225
{
226+
OP_REQUIRES(ctx, major_ >= 7, errors::InvalidArgument("Tensorcore GPU required"));
227+
226228
const ehalf* a_ptr = (const ehalf*)a.tensor_data().data();
227229
const ehalf* b_ptr = (const ehalf*)b.tensor_data().data();
228230

@@ -293,6 +295,8 @@ class BlocksparseTransformerOp : public OpKernel {
293295

294296
if (a.dtype() == DT_HALF)
295297
{
298+
OP_REQUIRES(ctx, major_ >= 7, errors::InvalidArgument("Tensorcore GPU required"));
299+
296300
const ehalf* a_ptr = (const ehalf*)a.tensor_data().data();
297301
const ehalf* b_ptr = (const ehalf*)b.tensor_data().data();
298302
ehalf* c_ptr = ( ehalf*)c->tensor_data().data();

src/ew_op_gpu.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,20 @@ __device__ __forceinline__ void ew_set(bhalf8 &a, uint val) { a.x = a.y = a.z =
849849
__device__ __forceinline__ void ew_zero(vhalf &a) { a.x = 0; }
850850
__device__ __forceinline__ void ew_zero(mhalf &a) { a.x = 0; }
851851

852-
852+
// minimize catastrophic cancellation: https://en.wikipedia.org/wiki/Loss_of_significance
853+
// Probably unnecessary, but GPU supports it at no cost (when used sparingly)
854+
__device__ __forceinline__ float precise_sub(float a, float b)
855+
{
856+
float r;
857+
asm("{\n\t"
858+
".reg .f64 a, b, c;\n\t"
859+
"cvt.f64.f32 a, %1;\n\t"
860+
"cvt.f64.f32 b, %2;\n\t"
861+
"sub.f64 c, a, b;\n\t"
862+
"cvt.rn.f32.f64 %0, c;\n\t"
863+
"}" : "=f"(r) : "f"(a), "f"(b));
864+
return r;
865+
}
853866

854867
__device__ __forceinline__ float _ex2_approx(float x)
855868
{
@@ -1139,6 +1152,8 @@ MATH_Z_XY(ew_mul, _mul)
11391152
MATH_Z_XY(ew_div, _div)
11401153
MATH_Z_XY(ew_maximum, fmaxf)
11411154
MATH_Z_XY(ew_minimum, fminf)
1155+
MATH_Z_XY(ew_precise_sub, precise_sub)
1156+
11421157

11431158
MATH_Z_X(ew_abs, fabsf)
11441159
MATH_Z_X(ew_neg, _neg)

src/layer_norm_cn_op_gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ __global__ void __launch_bounds__(256) layer_norm_moments2_CN(
136136
// rstd = 1/sqrt(var)
137137
mean1 *= rcpK;
138138
mean2 *= rcpK;
139-
float rstd = rsqrtf((mean2 - ew_sqr(mean1)) + epsilon);
139+
float rstd = rsqrtf(precise_sub(mean2, ew_sqr(mean1)) + epsilon);
140140
store(add_ptr_u(Mean, n), mean1);
141141
store(add_ptr_u(Rstd, n), rstd);
142142
}

src/layer_norm_nc_op_gpu.cu

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,52 +37,50 @@ __global__ void __launch_bounds__(THREADS) layer_norm_NC(
3737
v_mean1 = ew_add(v_mean1, x);
3838
v_mean2 = ew_add(v_mean2, ew_sqr(x));
3939
}
40-
float2 mean;
41-
mean.x = ew_sum(v_mean1) * rcpK;
42-
mean.y = ew_sum(v_mean2) * rcpK;
40+
float2 stats;
41+
stats.x = ew_sum(v_mean1) * rcpK;
42+
stats.y = ew_sum(v_mean2) * rcpK;
4343

4444
// reduce within warp
4545
for (int i = 16; i > 0; i >>= 1)
46-
{
47-
mean.x += shfl_xor(mean.x, i);
48-
mean.y += shfl_xor(mean.y, i);
49-
}
46+
stats = ew_warp_sum(stats, i);
47+
5048
// if using more than 1 warp, further reduced with shared memory
5149
if (THREADS > 32)
5250
{
5351
__shared__ float2 Share[32];
5452

5553
// first thread of each warp store to shared
5654
if ((tid & 31) == 0)
57-
Share[tid/32] = mean;
55+
Share[tid/32] = stats;
5856

5957
__syncthreads();
6058

6159
if (tid < 32)
6260
{
6361
// first warp loads all prior reductions
64-
mean = Share[tid];
62+
stats = Share[tid];
6563

6664
// reduce within this first warp
6765
for (int i = THREADS/64; i > 0; i >>= 1)
68-
{
69-
mean.x += shfl_xor(mean.x, i);
70-
mean.y += shfl_xor(mean.y, i);
71-
}
72-
// outputs final reduction to shared
73-
Share[tid] = mean;
66+
stats = ew_warp_sum(stats, i);
67+
68+
// final reduction to shared
69+
Share[tid] = stats;
7470
}
7571
__syncthreads();
7672

7773
// broadcast result to all threads
78-
mean = Share[0];
74+
stats = Share[0];
7975
}
8076
// var = avg(x**2) - avg(x)**2
8177
// rstd = 1/sqrt(var)
82-
float rstd = rsqrtf((mean.y - ew_sqr(mean.x)) + epsilon);
78+
float mean = stats.x;
79+
float rstd = rsqrtf(precise_sub(stats.y, ew_sqr(mean)) + epsilon);
80+
8381
if (tid == 0)
8482
{
85-
Mean[n] = mean.x;
83+
Mean[n] = mean;
8684
Rstd[n] = rstd;
8785
}
8886

@@ -94,7 +92,7 @@ __global__ void __launch_bounds__(THREADS) layer_norm_NC(
9492
V g = load(G, k);
9593
V b = load(B, k);
9694

97-
V xhat = ew_mul(ew_sub(x, mean.x), rstd);
95+
V xhat = ew_mul(ew_sub(x, mean), rstd);
9896
V y = ew_add(ew_mul(xhat, g), b);
9997

10098
if (relu)
@@ -513,17 +511,17 @@ __global__ void layer_norm_segmented_nc(
513511
#pragma unroll 1
514512
for (int i = thread2/64; i > 0; i >>= 1)
515513
stats = ew_warp_sum(stats, i);
514+
516515
// final reduction to shared
517516
Share[tid] = stats;
518517
}
519518
__syncthreads();
520519
stats = Share[0];
521520
}
522-
523521
// var = avg(x**2) - avg(x)**2
524522
// rstd = 1/sqrt(var)
525523
float mean = stats.x;
526-
float rstd = rsqrtf((stats.y - mean*mean) + epsilon);
524+
float rstd = rsqrtf(precise_sub(stats.y, ew_sqr(mean)) + epsilon);
527525
if (tid == 0)
528526
{
529527
__stg(add_ptr_u(Mean, m), mean);
@@ -808,17 +806,20 @@ bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs,
808806
const float* b,
809807
const float* mean,
810808
const float* rstd,
811-
float epsilon, uint N, uint S, uint K, float rcpK, int relu)
809+
float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics)
812810
{
813811
uint gridK = CEIL_DIV(K, 32);
814812
uint gridN = 1;
815-
uint blocksK = gridK * S;
816-
while (gridN < (N>>3) && gridN * blocksK < 32*SMs) gridN += 1;
817-
if (gridN * blocksK > 32*SMs && gridN > 1) gridN -= 1;
818-
if (gridN > 1)
813+
if (atomics)
819814
{
820-
cuMemsetD32Async((CUdeviceptr)dg, 0, S*K, stream);
821-
cuMemsetD32Async((CUdeviceptr)db, 0, S*K, stream);
815+
uint blocksK = gridK * S;
816+
while (gridN < (N>>3) && gridN * blocksK < 32*SMs) gridN += 1;
817+
if (gridN * blocksK > 32*SMs && gridN > 1) gridN -= 1;
818+
if (gridN > 1)
819+
{
820+
cuMemsetD32Async((CUdeviceptr)dg, 0, S*K, stream);
821+
cuMemsetD32Async((CUdeviceptr)db, 0, S*K, stream);
822+
}
822823
}
823824
layer_norm_segmented_dg_db_nc<T><<<dim3(gridN,gridK,S),32,0,stream>>>(dg, db, dy, x, g, b, mean, rstd, N, S*K, S*K*gridN, K, relu);
824825

@@ -869,9 +870,9 @@ bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs,
869870
}
870871
return true; // TODO
871872
}
872-
template bool LayerNormSegmentedBackward_NC<float,float4>(CUstream stream, int SMs, float* dx, float* dg, float* db, const float* dy, const float* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
873-
template bool LayerNormSegmentedBackward_NC<ehalf,ehalf4>(CUstream stream, int SMs, ehalf* dx, float* dg, float* db, const ehalf* dy, const ehalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
874-
template bool LayerNormSegmentedBackward_NC<bhalf,bhalf4>(CUstream stream, int SMs, bhalf* dx, float* dg, float* db, const bhalf* dy, const bhalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
873+
template bool LayerNormSegmentedBackward_NC<float,float4>(CUstream stream, int SMs, float* dx, float* dg, float* db, const float* dy, const float* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
874+
template bool LayerNormSegmentedBackward_NC<ehalf,ehalf4>(CUstream stream, int SMs, ehalf* dx, float* dg, float* db, const ehalf* dy, const ehalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
875+
template bool LayerNormSegmentedBackward_NC<bhalf,bhalf4>(CUstream stream, int SMs, bhalf* dx, float* dg, float* db, const bhalf* dy, const bhalf* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
875876

876877

877878
#endif // GOOGLE_CUDA

src/layer_norm_op.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ REGISTER_OP("LayerNorm")
3535
.Attr("axis: int")
3636
.Attr("epsilon: float")
3737
.Attr("relu: bool")
38+
.Attr("atomics: bool = true")
3839
.Attr("bench: int = 0")
3940
.SetShapeFn([](InferenceContext* ctx) {
4041

@@ -80,6 +81,7 @@ class LayerNormOp : public OpKernel {
8081
OP_REQUIRES_OK(ctx, ctx->GetAttr("S", &S_ ));
8182
OP_REQUIRES_OK(ctx, ctx->GetAttr("relu", &relu_ ));
8283
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_ ));
84+
OP_REQUIRES_OK(ctx, ctx->GetAttr("atomics", &atomics_));
8385
OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_ ));
8486
repeat_ = bench_ ? bench_ : 1;
8587
}
@@ -104,6 +106,8 @@ class LayerNormOp : public OpKernel {
104106
N *= x.dim_size(i);
105107
}
106108
}
109+
OP_REQUIRES(ctx, axis_ != 0 || (N & 3) == 0, errors::InvalidArgument("Sum of non-feature axis dims needs to be multiple of 4 for feature axis=0."));
110+
107111
if (K_ == 0)
108112
{
109113
OP_REQUIRES(ctx, K == g.shape().num_elements(), errors::InvalidArgument("Bad Gain Shape"));
@@ -166,15 +170,15 @@ class LayerNormOp : public OpKernel {
166170
}
167171
float epsilon_, rcpK_;
168172
int S_, K_, axis_, SMs_, bench_, repeat_;
169-
bool relu_;
173+
bool relu_, atomics_;
170174
};
171175
REGISTER_KERNEL_BUILDER(Name("LayerNorm").Device(DEVICE_GPU).TypeConstraint<FLOAT>("T"), LayerNormOp<FLOAT,float,float4>);
172176
REGISTER_KERNEL_BUILDER(Name("LayerNorm").Device(DEVICE_GPU).TypeConstraint<EHALF>("T"), LayerNormOp<EHALF,ehalf,ehalf4>);
173177
REGISTER_KERNEL_BUILDER(Name("LayerNorm").Device(DEVICE_GPU).TypeConstraint<BHALF>("T"), LayerNormOp<BHALF,bhalf,bhalf4>);
174178

175179
template <typename T, typename V> bool LayerNormBackward_NC(CUstream stream, int SMs, T* dx, float* dg, float* db, const T* dy, const T* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, int K, int N, float rcpK, int relu);
176180
template <typename T, typename V> bool LayerNormBackward_CN(CUstream stream, int SMs, T* dx, float* dg, float* db, float* sum1, float* sum2, const T* dy, const T* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, int K, int N, float rcpK, int relu);
177-
template <typename T, typename V> bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs, T* dx, float* dg, float* db, const T* dy, const T* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu);
181+
template <typename T, typename V> bool LayerNormSegmentedBackward_NC(CUstream stream, int SMs, T* dx, float* dg, float* db, const T* dy, const T* x, const float* g, const float* b, const float* mean, const float* rstd, float epsilon, uint N, uint S, uint K, float rcpK, int relu, int atomics);
178182

179183
REGISTER_OP("LayerNormGrad")
180184
.Input("dy: T")
@@ -193,6 +197,7 @@ REGISTER_OP("LayerNormGrad")
193197
.Attr("axis: int")
194198
.Attr("epsilon: float")
195199
.Attr("relu: bool")
200+
.Attr("atomics: bool = true")
196201
.Attr("bench: int = 0")
197202
.SetShapeFn([](InferenceContext* ctx) {
198203
ctx->set_output(0, ctx->input(1));
@@ -215,6 +220,7 @@ class LayerNormGradOp : public OpKernel {
215220
OP_REQUIRES_OK(ctx, ctx->GetAttr("S", &S_ ));
216221
OP_REQUIRES_OK(ctx, ctx->GetAttr("relu", &relu_ ));
217222
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_ ));
223+
OP_REQUIRES_OK(ctx, ctx->GetAttr("atomics", &atomics_));
218224
OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_ ));
219225
repeat_ = bench_ ? bench_ : 1;
220226
}
@@ -289,7 +295,7 @@ class LayerNormGradOp : public OpKernel {
289295
else
290296
{
291297
if (S_ > 1 || K_ <= 1024*8)
292-
LayerNormSegmentedBackward_NC<V1,V4>(stream, SMs_, dx_ptr, dg_ptr, db_ptr, dy_ptr, x_ptr, g_ptr, b_ptr, mean_ptr, rstd_ptr, epsilon_, N, S_, K_, rcpK_, relu_);
298+
LayerNormSegmentedBackward_NC<V1,V4>(stream, SMs_, dx_ptr, dg_ptr, db_ptr, dy_ptr, x_ptr, g_ptr, b_ptr, mean_ptr, rstd_ptr, epsilon_, N, S_, K_, rcpK_, relu_, atomics_);
293299
else
294300
LayerNormBackward_NC<V1,V4>(stream, SMs_, dx_ptr, dg_ptr, db_ptr, dy_ptr, x_ptr, g_ptr, b_ptr, mean_ptr, rstd_ptr, epsilon_, K_, N, rcpK_, relu_);
295301
}
@@ -298,7 +304,7 @@ class LayerNormGradOp : public OpKernel {
298304
}
299305
float epsilon_, rcpK_;
300306
int S_, K_, axis_, SMs_, bench_, repeat_;
301-
bool relu_;
307+
bool relu_, atomics_;
302308
};
303309
REGISTER_KERNEL_BUILDER(Name("LayerNormGrad").Device(DEVICE_GPU).TypeConstraint<FLOAT>("T"), LayerNormGradOp<FLOAT,float,float4>);
304310
REGISTER_KERNEL_BUILDER(Name("LayerNormGrad").Device(DEVICE_GPU).TypeConstraint<EHALF>("T"), LayerNormGradOp<EHALF,ehalf,ehalf4>);

src/lstm_op_gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ __global__ void __launch_bounds__(THREADS) sparse_relu_forward(
617617
}
618618
// var = avg(x**2) - avg(x)**2
619619
// std = sqrt(var)
620-
float std = sqrtf(mean.y - mean.x*mean.x);
620+
float std = sqrtf(precise_sub(mean.y, mean.x*mean.x));
621621

622622
// Norm/Gain/Bias
623623
X += offset;

0 commit comments

Comments
 (0)