[perf][refactor] Fused linear rewrite (facebookresearch#42)
* removing the config_matmul file, not that useful
cleaning up dimension handling

cleaning up all the pointers
typos and docstrings
flattening the batch dim, preparing a fused kernel for grad bias and grad weight

fusing the grad weight and grad bias

deactivating the bias computation in the second fused bw kernel for now

optional masking, perf still a mixed bag, squarely because of the bw kernels

dedicated summation kernel

simplicity wins

complete rewrite, ditching the semi-fusion, kernel too big and not useful

bugfix + better flop computations for the bw

* new iteration, saving the inputs but cleaner back processing

* trying to improve on the comments

* mask all the loads

* fixing a type issue with squared relu

* cannot repro the crash locally on an Ampere machine
blefaudeux authored Nov 9, 2021
1 parent a56c542 commit 03d58e1
Showing 21 changed files with 430 additions and 486 deletions.
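
The commit message above describes flattening the batch dimension and fusing the grad_weight and grad_bias computations. As a point of reference, here is a minimal plain-PyTorch sketch of the math those fused kernels implement (eager code for illustration only, with ReLU standing in for the configurable activation; this is not the Triton implementation added by this commit):

```python
# Plain-PyTorch reference for what the fused linear kernels compute.
# Illustration only: ReLU stands in for the configurable activation.
import torch

B, M, K, N = 2, 512, 1024, 1024      # batch, sequence, in_features, out_features (example sizes)
x = torch.randn(B, M, K)
weight = torch.randn(N, K)
bias = torch.randn(N)

# Forward: flatten the leading dimensions so the kernel sees a single 2D matmul
x_flat = x.reshape(-1, K)                               # (B*M, K)
y = torch.relu(x_flat @ weight.t() + bias)              # (B*M, N), bias and activation fused in

# Backward, given an incoming gradient of the same shape as y
grad_out = torch.randn_like(y)
grad_act = grad_out * (y > 0).to(grad_out.dtype)        # gradient through the ReLU
grad_input = (grad_act @ weight).reshape(B, M, K)       # gradient w.r.t. the inputs
grad_weight = grad_act.t() @ x_flat                     # grad_weight: one (N, B*M) x (B*M, K) matmul
grad_bias = grad_act.sum(dim=0)                         # grad_bias: a reduction over the flattened batch
```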
23 changes: 18 additions & 5 deletions BENCHMARKS.md
@@ -53,15 +53,28 @@ Note that in the Triton case the slowdowns at extreme sizes are because of regis
### Fused linear layer

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for an NVIDIA V100, Triton 1.1 and PyTorch 1.9.
**As of October 2021, these Triton kernels are only competitive with PyTorch for float16 inference; this is a work in progress**.

![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW.png)
**As of October 2021, these Triton kernels are only competitive with PyTorch in float16; this is a work in progress**.

![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW.png)
![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW_gelu.png)

![Fused linear layers throughput in fp32 - inference](docs/plots/FusedLinear_fp32_FW.png)
![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW_gelu.png)

![Fused linear layers throughput in fp32 - training](docs/plots/FusedLinear_fp32_FW_BW.png)
![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW_relu.png)

![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW_relu.png)

![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW_leaky_relu.png)

![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW_leaky_relu.png)

![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW_squared_relu.png)

![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW_squared_relu.png)

![Fused linear layers throughput in fp16 - inference](docs/plots/FusedLinear_fp16_FW_none.png)

![Fused linear layers throughput in fp16 - training](docs/plots/FusedLinear_fp16_FW_BW_none.png)
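
For a quick sanity check of throughput numbers like these, a hedged sketch of how a TFlops/s figure can be derived for the plain PyTorch baseline (CUDA-event timing; the sizes, activation and flop formula below are illustrative and not the exact ones used by the benchmark script):

```python
# Illustration only: time an eager PyTorch linear + GELU in fp16 and convert to TFlops/s.
# The plots above come from the xformers benchmark script, not from this snippet.
import torch

B, M, K, N = 8, 512, 1024, 1024                      # example batch, sequence, in/out features
x = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
linear = torch.nn.Linear(K, N).to("cuda", torch.float16)

def step(t):
    return torch.nn.functional.gelu(linear(t))

for _ in range(10):                                  # warmup
    step(x)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 100
start.record()
for _ in range(iters):
    step(x)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / iters

flop = 2 * B * M * K * N + B * M * N                 # matmul plus bias add, forward only
print(f"{flop / (ms * 1e-3) / 1e12:.1f} TFlops/s")
```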

### Fused layer norm

Binary file added docs/plots/FusedLinear_fp16_FW_BW_gelu.png
Binary file added docs/plots/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file added docs/plots/FusedLinear_fp16_FW_BW_none.png
Binary file added docs/plots/FusedLinear_fp16_FW_BW_relu.png
Binary file added docs/plots/FusedLinear_fp16_FW_gelu.png
Binary file added docs/plots/FusedLinear_fp16_FW_leaky_relu.png
Binary file added docs/plots/FusedLinear_fp16_FW_none.png
Binary file added docs/plots/FusedLinear_fp16_FW_relu.png
Binary file added docs/plots/FusedLinear_fp16_FW_squared_relu.png
Binary file removed docs/plots/FusedLinear_fp32_FW.png
Binary file not shown.
Binary file removed docs/plots/FusedLinear_fp32_FW_BW.png
Binary file not shown.
10 changes: 5 additions & 5 deletions tests/test_triton_fused_linear.py
@@ -16,7 +16,7 @@
try:
from xformers.triton import FusedLinear
from xformers.triton.activations import get_triton_activation_kernel
from xformers.triton.k_fused_matmul import fused_matmul
from xformers.triton.k_fused_matmul_fw import fused_matmul
from xformers.triton.utils import gpu_capabilities_older_than_70

except ImportError:
@@ -129,24 +129,24 @@ def test_fused_linear_parity(shape, activation: Activation, bias: bool, amp: boo
loss_triton = torch.norm(y_triton)
loss_triton.backward()

assert torch.allclose(X, X_, atol=tolerance), f"{X[:,0,0]} vs. {X_[:,0,0]}"
assert torch.allclose(X, X_, atol=tolerance), f"{X} vs. {X_}"

# Input grad being correct checks both the loss + some of the backward pass
assert torch.allclose(
X.grad, X_.grad, atol=tolerance
), f"{X.grad[:,0,0]} vs. {X_.grad[:,0,0]}"
), f"{X.grad} vs. {X_.grad}"

# Check that the linear layer bias are also properly trainable
if bias:
assert triton_fused_linear.bias is not None
assert triton_fused_linear.bias.grad is not None
assert torch.allclose(
torch_linear.bias.grad, triton_fused_linear.bias.grad, atol=tolerance
)
), f"{torch_linear.bias.grad} vs. {triton_fused_linear.bias.grad}"

# Check that the linear layer weights are also properly trainable
assert torch.allclose(
torch_linear.weight.grad,
triton_fused_linear.weight.grad,
atol=tolerance,
)
), f"{torch_linear.weight.grad} vs. {triton_fused_linear.weight.grad}"
50 changes: 30 additions & 20 deletions xformers/benchmarks/benchmark_triton_fused_linear.py
@@ -14,11 +14,12 @@
from xformers.triton.fused_linear_layer import FusedLinear

SHAPES = [
(8, 256, 512),
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(2, 4096, 4096),
(8, 512, 256), # Batch x Seq x Embedding
(8, 512, 512),
(4, 512, 1024),
(2, 512, 2048),
(2, 512, 4096),
(2, 512, 8192),
]


@@ -40,7 +41,15 @@ def get_metrics_transform(
if backward:
flop *= 2

# optional weight on top
# the backward pass also outputs a gradient with respect to the bias,
# which is a reduction over all the activation gradients
flop += a.shape[0] * a.shape[1] * w.shape[1]

# the backward pass also outputs a gradient with respect to the weight,
# which is another matmul, this time between grad_out and the inputs
flop += a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1)

# optional bias on top
if b is not None:
flop += b.numel()
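
A standalone recap of the flop accounting above, using the usual convention that a matmul between (m, k) and (k, n) costs about 2*m*k*n flops (the sizes below are illustrative, and the formulas are the textbook ones rather than the exact expressions in the script):

```python
# Illustrative flop accounting for a linear layer step with inputs of shape (B, M, K)
# and an output dimension N, mirroring the terms above.
B, M, K, N = 8, 512, 256, 256

fw_matmul = 2 * B * M * K * N        # forward: x @ W^T
bias_add = B * M * N                 # optional bias add
grad_input = 2 * B * M * K * N       # backward: grad_out @ W
grad_weight = 2 * B * M * K * N      # backward: grad_out^T @ x over the flattened batch
grad_bias = B * M * N                # backward: summing grad_out over the B * M rows

fw = fw_matmul + bias_add
fw_bw = fw + grad_input + grad_weight + grad_bias
print(f"forward: {fw / 1e9:.2f} GFlop, forward+backward: {fw_bw / 1e9:.2f} GFlop")
```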

Expand All @@ -58,11 +67,11 @@ def bench_linear(activations: List[Optional[Activation]]):
torch.float16,
torch.float32,
]:
for backward in [False, True]:

results: Dict[str, Any] = {}
for backward in [True, False]:

for activation in activations:
results: Dict[str, Any] = {}

for bias in [False, True]:
for B, M, K in SHAPES:
a = torch.rand(
@@ -129,17 +138,18 @@ def triton_step(x):
metric = metrics_transform(time)
results[key][testcase.name] = f"{metric:.1f}"

pretty_print(
results,
title="\n --- Type: {} ---".format(dtype),
units="TFlops/s",
)

_type = "_fp16" if dtype == torch.float16 else "_fp32"
title = "FusedLinear" + _type + "_FW"
if backward:
title += "_BW"
pretty_plot(results, title, "TFlops/s", dash_key="pytorch")
pretty_print(
results,
title="\n --- Type: {} ---".format(dtype),
units="TFlops/s",
)

_type = "_fp16" if dtype == torch.float16 else "_fp32"
title = "FusedLinear" + _type + "_FW"
if backward:
title += "_BW"
title += "_" + activation.value if activation else "_none"
pretty_plot(results, title, "TFlops/s", dash_key="pytorch")


activations = [ac for ac in Activation] + [None] # type: ignore
34 changes: 20 additions & 14 deletions xformers/triton/activations.py
@@ -68,7 +68,12 @@ def relu(x):

@triton.jit
def relu_grad(x):
return tl.where(x >= 0, 1.0, 0.0)
# ReLU is different from the other activations:
# it does not need the original input to be saved to compute its gradient.
# Here the input is the downstream gradient, and we return the upstream gradient directly
zero = 0.0
zero = zero.to(x.dtype)
return tl.where(x >= 0, x, zero)


@triton.jit
@@ -96,12 +101,19 @@ def leaky_relu(x):
.. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
"""
scale = 0.01 + 0.0
scale = scale.to(x.dtype)
return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
return tl.where(x >= 0, 1.0, 0.01)
min_grad = 0.01
max_grad = 1

min_grad = min_grad.to(x.dtype)
max_grad = max_grad.to(x.dtype)

return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
@@ -116,15 +128,9 @@ def gelu(x):

@triton.jit
def gelu_grad(x):
# Normal computation, just try to maximize reuse
x_3 = x * x * x
_a = 0.0356774 * x_3 + _kAlpha * x

# (hoping that a division is cheaper than an exponential..)
exp_a = tl.exp(_a)
exp_m_a = 1.0 / exp_a

_cos_h = exp_a + exp_m_a
_tan_h = (exp_a - exp_m_a) / _cos_h
_cos_h *= 0.5
return 0.5 + 0.5 * _tan_h + (0.0535161 * x_3 + 0.398942 * x) / (_cos_h * _cos_h)
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
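
As a side note, the closed-form gradient above can be checked against autograd applied to the tanh approximation of GELU; a small plain-PyTorch sketch (not part of the commit):

```python
# Plain-PyTorch check (not part of the commit): the closed-form gradient above
# should match autograd applied to the tanh approximation of GELU.
import torch

def gelu_tanh(x):
    return 0.5 * x * (1 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

def gelu_grad_closed_form(x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
gelu_tanh(x).sum().backward()
assert torch.allclose(x.grad, gelu_grad_closed_form(x.detach()), atol=1e-6)
```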
46 changes: 0 additions & 46 deletions xformers/triton/configs_matmul.py

This file was deleted.

27 changes: 13 additions & 14 deletions xformers/triton/fused_linear_layer.py
@@ -15,12 +15,14 @@
get_triton_activation_bwd_kernel,
get_triton_activation_kernel,
)
from xformers.triton.k_fused_matmul import fused_matmul, fused_matmul_backward
from xformers.triton.k_fused_matmul_bw import fused_matmul_backward
from xformers.triton.k_fused_matmul_fw import fused_matmul

# The following activations require their inputs to be saved in order to compute their gradients
_requires_bwd_inputs = [
Activation.GeLU,
Activation.SquaredReLU,
Activation.LeakyReLU,
]


@@ -34,9 +36,9 @@ def forward(
bias,
activation,
act_grad_kernel,
save_activation_inputs,
trainable_weight,
trainable_bias,
save_activation_inputs,
):

# Kick the fused Triton kernel, handling bias and activation in one go
@@ -47,14 +49,12 @@
ctx.activation_grad_kernel = act_grad_kernel
ctx.trainable_weight = trainable_weight
ctx.trainable_bias = trainable_bias

ctx.save_activation_inputs = save_activation_inputs

# Micro-optimization: saving these is not always needed (?)
if x.requires_grad or ctx.trainable_weight or ctx.trainable_bias:
if ctx.trainable_weight:
ctx.save_for_backward(weight, activation_inputs, x)
else:
ctx.save_for_backward(weight, None, None)
ctx.save_for_backward(weight, activation_inputs, x)

return y

@@ -64,21 +64,20 @@ def backward(ctx: Any, grad_out: torch.Tensor) -> Any:  # type: ignore
"""
Compute the gradients with respect to the inputs, and to the weight and bias when they are trainable.
"""
(weight, activation_inputs, inputs) = ctx.saved_tensors
(weight, activation_inputs, x) = ctx.saved_tensors

# Kick the fused Triton kernel, handling transpose and activation gradient in one go
grad_input, grad_weight, grad_bias = fused_matmul_backward(
grad_out=grad_out,
inputs=inputs,
inputs=x,
act_in=activation_inputs,
weight=weight,
trainable_weight=ctx.trainable_weight,
trainable_bias=ctx.trainable_bias,
activation_inputs=activation_inputs,
activation_grad=ctx.activation_grad_kernel,
activation_grad_req_inputs=ctx.save_activation_inputs,
act_requires_input=ctx.save_activation_inputs,
)

return grad_input, grad_weight, grad_bias, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class FusedLinear(nn.Module):
@@ -112,10 +111,10 @@ def __init__(

self._activation_kernel = get_triton_activation_kernel(activation)
self._activation_grad_kernel = get_triton_activation_bwd_kernel(activation)

self._save_activation_inputs = (
activation in _requires_bwd_inputs if activation is not None else False
)

self.reset_parameters()

def reset_parameters(self) -> None:
@@ -132,7 +131,7 @@ def forward(self, x):
self.bias,
self._activation_kernel,
self._activation_grad_kernel,
self._save_activation_inputs,
self.weight.requires_grad,
self.bias.requires_grad if self.bias is not None else False,
self._save_activation_inputs,
)
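
For readers unfamiliar with the `torch.autograd.Function` pattern used here, a minimal sketch with plain PyTorch matmuls standing in for the fused Triton kernels; it shows why `backward` returns one entry per `forward` argument, with `None` for the non-tensor flags (this is an illustration, not the xformers implementation):

```python
# Minimal sketch of the autograd.Function pattern, with plain PyTorch matmuls
# standing in for the fused Triton kernels (no activation handling here).
import torch

class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, trainable_weight, trainable_bias):
        ctx.save_for_backward(weight, x)
        ctx.trainable_weight = trainable_weight
        ctx.trainable_bias = trainable_bias
        y = x @ weight.t()
        return y + bias if bias is not None else y

    @staticmethod
    def backward(ctx, grad_out):
        weight, x = ctx.saved_tensors
        grad_input = grad_out @ weight
        grad_weight = grad_out.t() @ x if ctx.trainable_weight else None
        grad_bias = grad_out.sum(dim=0) if ctx.trainable_bias else None
        # One entry per forward() argument: the non-tensor flags get None
        return grad_input, grad_weight, grad_bias, None, None

x = torch.randn(128, 256, requires_grad=True)
w = torch.randn(512, 256, requires_grad=True)
b = torch.randn(512, requires_grad=True)
_LinearSketch.apply(x, w, b, True, True).sum().backward()
```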