
Commit d97ae74

training acceleration via runtime semi-structured sparsity (#184)
This PR adds support for training acceleration using the runtime semi-structured sparsity kernels that landed in core earlier: pytorch/pytorch#122350

It collects the autograd functions needed to support training and packages them into a replacement `nn.Linear` module, `SemiSparseLinear`, as well as a user API for swapping out modules, `swap_linear_with_semi_sparse_linear_`. It also adds some benchmarking code from xformers to measure the speedup of this module when applied to DINO shapes. We have a blog post coming out with more details about how this works.

Testing:
```
python test/sparsity/test_fast_sparse_training.py
```

Benchmarking:
```
python benchmarks/benchmark_semi_sparse.py
```

For ViT-L MLP shapes we see the following results:
```
[------------------------------------------------ mlpfwbw -------------------------------------------------]
                                 |  act24   |  dense   |   w24   | s24_inp_sparsify24 | s24_inp_clone
1 threads: ---------------------------------------------------------------------------------------------
    f16 (44160,1024,4096,1024)   |  11881.0 |  11534.3 |  9204.7 |       255.1        |     125.8

Times are in microseconds (us).
```
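For readers who want to try the swap API, here is a minimal usage sketch. The config-dict form of the second argument, the module names in it, and the exact import of `swap_linear_with_semi_sparse_linear_` are assumptions based on the description above, not taken from this diff; check the function's docstring for the real signature.

```
import torch
from torch import nn

# SemiSparseLinear is added by this PR; the import path and signature of the
# swap helper below are assumptions for illustration only.
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear_

# Illustrative ViT-style MLP block (hypothetical model; shapes mirror the benchmark).
model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).cuda().half()

# Assumed usage: pass the model and a mapping from module names to the
# replacement class, so only the listed linears are swapped.
swap_linear_with_semi_sparse_linear_(model, {"0": SemiSparseLinear, "2": SemiSparseLinear})

# Training then proceeds as usual; the swapped layers use the runtime 2:4 kernels.
x = torch.randn(4096, 1024, device="cuda", dtype=torch.half, requires_grad=True)
model(x).sum().backward()
```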
1 parent f3f2ea8 commit d97ae74

File tree: 7 files changed, +1363 -0 lines changed


benchmarks/benchmark_semi_sparse.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn
from xformers_benchmark_utils import DTYPE2STR, benchmark_main_helper2, product_dict

from torchao.sparsity.training import SemiSparseLinear
from torchao.sparsity.training.autograd import semi_structured_sparsify

min_run_time = 0.5
device = torch.device("cuda")

CASES = list(
    product_dict(
        B_in_hidden_out_ft=[
            # DINO ViT-L: lg + sm crops (patch16)
            (64 * 2 * (14 * 14 + 1) + 64 * 8 * (6 * 6 + 1), 1024, 1024 * 4, 1024),
        ],
        dtype=[torch.half],
        bias=[False],
    )
)


# Dense baseline: fc1 -> GELU -> fc2 built from plain nn.Linear layers.
class Mlp(nn.Module):
    LINEAR_CLS = nn.Linear

    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.fc1 = self.LINEAR_CLS(in_ft, hid_ft, bias=bias)
        self.act = nn.GELU()
        self.fc2 = self.LINEAR_CLS(hid_ft, out_ft, bias=bias)
        self.grad = torch.randn([B, out_ft], device="cuda", dtype=dtype)
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        self.out = self.input
        self.to("cuda").to(dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def fw(self):
        self.out = self.forward(self.input)

    def bw(self):
        self.out.backward(self.grad, retain_graph=True)


# Activation 2:4 sparsity: sparsify the fc1 output at runtime.
class MlpAct24(Mlp):
    def fw(self):
        x = self.input
        x = self.fc1(x)
        x = semi_structured_sparsify(x)
        x = self.act(x)
        x = self.fc2(x)
        self.out = x


# Weight 2:4 sparsity: swap nn.Linear for SemiSparseLinear.
class MlpW24(Mlp):
    LINEAR_CLS = SemiSparseLinear


# Microbenchmarks isolating per-op overhead on the MLP input.
class MicrobenchmarkBase:
    def __init__(
        self, B_in_hidden_out_ft: Tuple[int, int, int, int], dtype, bias: bool, bw: bool
    ) -> None:
        B, in_ft, hid_ft, out_ft = B_in_hidden_out_ft
        super().__init__()
        self.label = "mlp"
        self.sub_label = (
            f"{DTYPE2STR[dtype]} ({B},{in_ft},{hid_ft},{out_ft}){' b' if bias else ''}"
        )
        self.input = torch.randn(
            [B, in_ft], device="cuda", dtype=dtype, requires_grad=True
        )
        self.input_colMajor = self.input.t().contiguous().t()
        self.input_sp = semi_structured_sparsify(self.input)

    def bw(self) -> None:
        return None


# Cost of a single runtime 2:4 sparsification of the input.
class MicrobenchmarkSparsify24(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        semi_structured_sparsify(self.input)
        return self.input


# Cost of a dense clone of the input, for comparison.
class MicrobenchmarkInputClone(MicrobenchmarkBase):
    def fw(self) -> torch.Tensor:
        self.input.clone()
        return self.input


functions = {
    "act24": MlpAct24,
    "dense": Mlp,
    "w24": MlpW24,
    "s24_inp_sparsify24": MicrobenchmarkSparsify24,
    "s24_inp_clone": MicrobenchmarkInputClone,
}
benchmark_main_helper2(
    "sp24_fwbw",
    fw=True,
    bw=True,
    cases=CASES,
    functions=functions,
    min_run_time=min_run_time,
)
