fix: layernorm stride (#190)
* fix: layernorm stride

* fix: remove unused
gaetansnl authored Nov 29, 2022
1 parent ae1d7b5 commit d488f34
Showing 2 changed files with 43 additions and 11 deletions.
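
Background for the fix: a PyTorch tensor addresses element (i, j) at storage offset i * stride(0) + j * stride(1). Before this commit the kernels hard-coded a column stride of 1 and reused the input's row stride for the output, so non-contiguous inputs such as transposed views were read and written at the wrong addresses. A minimal sketch of the addressing rule (plain PyTorch, not part of the diff):

    import torch

    x = torch.randn(4, 8)   # contiguous: strides (8, 1), column stride 1
    xt = x.t()              # transposed view over the same storage: strides (1, 8)

    # element (i, j) lives at storage offset i * stride(0) + j * stride(1);
    # assuming a column stride of 1 would make us read the wrong element of xt
    i, j = 2, 3
    offset = i * xt.stride(0) + j * xt.stride(1)
    assert xt[i, j] == x.view(-1)[offset]
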
src/kernl/implementations/layer_norm.py (27 additions & 11 deletions)
@@ -65,7 +65,10 @@ def layer_norm_xformers(
     Bias,
     Mean,
     Rstd,
+    stride_row_out,
+    stride_col_out,
     stride_row_a,
+    stride_col_a,
     N,
     eps,
     HAS_BIAS: tl.constexpr,  # not used, just to make the signature similar to single pass
@@ -85,7 +88,7 @@ def layer_norm_xformers(
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < N
 
-    x_ptrs = A + row * stride_row_a + cols
+    x_ptrs = A + row * stride_row_a + cols * stride_col_a
 
     x = tl.load(x_ptrs, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32)
     w = tl.load(Weight + cols, mask=mask, other=1.0)
@@ -104,7 +107,7 @@ def layer_norm_xformers(
     tl.store(Rstd + row, rstd)
 
     y = y * w + b
-    y_ptrs = Out + row * stride_row_a + cols
+    y_ptrs = Out + row * stride_row_out + cols * stride_col_out
     tl.store(y_ptrs, y, mask=mask)
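
Both the load and the store above now scale the column offsets by the corresponding column stride. A quick host-side check of the corrected pointer arithmetic (illustrative only; tensor names are made up here):

    import torch

    base = torch.arange(24.0)   # the flat storage
    a = base.view(4, 6).t()     # non-contiguous (6, 4) view, strides (1, 6)

    row = 2
    cols = torch.arange(a.size(1))
    offsets = row * a.stride(0) + cols * a.stride(1)  # the fixed addressing
    assert torch.equal(base[offsets], a[row])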


@@ -116,7 +119,10 @@ def _layer_norm_fwd_fused_single_pass(
     Bias,
     Mean,
     Rstd,
+    stride_row_out,
+    stride_col_out,
     stride_row_a,
+    stride_col_a,
     N,
     eps,
     HAS_BIAS: tl.constexpr,
@@ -145,7 +151,7 @@ def _layer_norm_fwd_fused_single_pass(
     # position of elements processed by this program
     _idx = tl.program_id(0)
     a_ptr = A + _idx * stride_row_a
-    out_ptr = Out + _idx * stride_row_a
+    out_ptr = Out + _idx * stride_row_out
     # compute mean
     mean = 0.0
     var = 0.0
@@ -156,7 +162,9 @@ def _layer_norm_fwd_fused_single_pass(
         mask = column_offset < N
         # eviction policy below have little impact now because of new implementation. Kept as is.
         # float32 is used to avoid overflow because of the square operation
-        a = tl.load(a_ptr + column_offset, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
+        a = tl.load(a_ptr + column_offset * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_last").to(
+            tl.float32
+        )
         if IS_RMSNORM:
             var += tl.sum(a * a, axis=0)
         else:
@@ -184,14 +192,16 @@ def _layer_norm_fwd_fused_single_pass(
         weight = tl.load(Weight + column_offset, mask=mask)
 
         # eviction policy helps to keep weights in cache (reused by other threads)
-        a = tl.load(a_ptr + column_offset, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32)
+        a = tl.load(a_ptr + column_offset * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_first").to(
+            tl.float32
+        )
         a_hat = (a - mean) * rstd
         out = a_hat * weight
         if HAS_BIAS:
             bias = tl.load(Bias + column_offset, mask=mask)
             out = out + bias
         # write-back
-        tl.store(out_ptr + column_offset, out, mask=mask)
+        tl.store(out_ptr + column_offset * stride_col_out, out, mask=mask)
 
 
 @triton.jit
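
Each program instance of the single-pass kernel normalizes one row of A and writes one row of Out, now walking both rows with their own strides. What it computes per row, sketched in plain PyTorch (float32 accumulation as in the kernel; the IS_RMSNORM branch would skip the mean subtraction):

    import torch

    def layer_norm_row(a_row, weight, bias, eps):
        a32 = a_row.float()
        mean = a32.mean()
        var = ((a32 - mean) ** 2).mean()  # biased variance, i.e. var / N
        rstd = 1.0 / torch.sqrt(var + eps)
        out = (a32 - mean) * rstd * weight
        return out + bias if bias is not None else out
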
@@ -202,7 +212,10 @@ def _layer_norm_fwd_fused_multi_pass(
     Bias,
     Mean,
     Rstd,
+    stride_row_out,
+    stride_col_out,
     stride_row_a,
+    stride_col_a,
     N,
     eps,
     IS_RMSNORM: tl.constexpr,  # not used, just to have the same signature than the single pass
@@ -217,21 +230,21 @@ def _layer_norm_fwd_fused_multi_pass(
     """
     # position of elements processed by this program
     row = tl.program_id(0)
-    Out += row * stride_row_a
+    Out += row * stride_row_out
     A += row * stride_row_a
     # compute mean
     mean = 0
     _mean = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32)
+        a = tl.load(A + cols * stride_col_a, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32)
         _mean += a
     mean = tl.sum(_mean, axis=0) / N
     # compute variance
     _var = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
     for off in range(0, N, BLOCK_SIZE):
         cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(A + cols, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32)
+        a = tl.load(A + cols * stride_col_a, mask=cols < N, other=0.0, eviction_policy="evict_last").to(tl.float32)
         a = tl.where(cols < N, a - mean, 0.0)
         _var += a * a
     var = tl.sum(_var, axis=0) / N
@@ -245,11 +258,11 @@ def _layer_norm_fwd_fused_multi_pass(
         mask = cols < N
         weight = tl.load(Weight + cols, mask=mask)
         bias = tl.load(Bias + cols, mask=mask)
-        a = tl.load(A + cols, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32)
+        a = tl.load(A + cols * stride_col_a, mask=mask, other=0.0, eviction_policy="evict_first").to(tl.float32)
         a_hat = (a - mean) * rstd
         out = a_hat * weight + bias
         # write-back
-        tl.store(Out + cols, out, mask=mask)
+        tl.store(Out + cols * stride_col_out, out, mask=mask)
 
 
 class LayerNorm(torch.autograd.Function):
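
The multi-pass variant sweeps each row three times in BLOCK_SIZE chunks: one pass for the mean, one for the variance, one to normalize and write back, with every load and store now scaled by the column stride. The control flow, sketched in plain Python (this kernel always loads Bias, so the sketch assumes bias is present):

    import torch

    def layer_norm_row_multi_pass(a_row, weight, bias, eps, block_size=128):
        n = a_row.numel()
        chunks = [a_row[off:off + block_size].float() for off in range(0, n, block_size)]
        mean = sum(c.sum() for c in chunks) / n                  # pass 1
        var = sum(((c - mean) ** 2).sum() for c in chunks) / n   # pass 2
        rstd = 1.0 / torch.sqrt(var + eps)
        return (a_row.float() - mean) * rstd * weight + bias     # pass 3, un-chunked for brevity
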
@@ -294,7 +307,10 @@ def forward(
             Bias=bias if bias is not None else a_arg,
             Mean=mean,
             Rstd=std,
+            stride_row_out=out.stride(0),
+            stride_col_out=out.stride(1),
             stride_row_a=a_arg.stride(0),
+            stride_col_a=a_arg.stride(1),
             N=N,
             eps=eps,
             HAS_BIAS=bias is not None,
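
The launch site now queries each tensor for its own strides instead of assuming Out shares A's layout. The pattern, as a small sketch (the helper name is invented here for illustration):

    import torch

    def stride_kwargs(out: torch.Tensor, a: torch.Tensor) -> dict:
        # mirrors the four stride arguments added in this commit
        return {
            "stride_row_out": out.stride(0),
            "stride_col_out": out.stride(1),
            "stride_row_a": a.stride(0),
            "stride_col_a": a.stride(1),
        }

    a = torch.randn(8, 16).t()   # strides (1, 16)
    out = torch.empty(16, 8)     # strides (8, 1)
    print(stride_kwargs(out, a))
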
test/test_layer_norm.py (16 additions & 0 deletions)
@@ -105,3 +105,19 @@ def test_benchmark_rms_norm(benchmark, shape: int, dtype, cuda_graphs: bool, imp
 
     value = benchmark(fn, x)
     assert_all_close(value.float(), expected, atol=1e-1)
+
+
+@pytest.mark.parametrize("implementation", implementations_layer_norm.keys())
+def test_stride(implementation):
+    M = N = 250
+    eps = 1e-5
+    factory_kwargs = {"device": "cuda", "dtype": torch.float32, "requires_grad": False}
+    layer_weight = torch.rand((N,), **factory_kwargs)
+    layer_bias = torch.randn_like(layer_weight)
+    x = -20 + 0.5 * torch.randn((M, N), **factory_kwargs)
+    x = x.transpose(-1, -2)
+
+    expected = torch.nn.functional.layer_norm(x, layer_weight.shape, layer_weight, layer_bias, eps)
+    fn = implementations_layer_norm[implementation](layer_weight, layer_bias, eps)
+    value = fn(x)
+    assert_all_close(value.float(), expected, atol=1e-1)
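
The new test feeds every implementation a transposed square input: the shape stays (250, 250), so only the strides change, which is exactly what the old kernels ignored. The property it relies on, checked in isolation:

    import torch

    x = torch.randn(250, 250)
    xt = x.transpose(-1, -2)
    print(xt.shape)             # torch.Size([250, 250]), same shape as x
    print(xt.is_contiguous())   # False
    print(xt.stride())          # (1, 250): column stride is no longer 1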
