Commit de4a1fb

QOL improvements to float8 gemm benchmark (#596)
Summary:

1. Add more options for shape generation:
   - square: M == K == N, sweeping through powers of 2
   - sweep: M, K, N each sweeping through powers of 2
   - custom: user specifies a single value of M, K, N
2. Fix a bug when calling `torch._scaled_mm`: create the scales outside the benchmarked function for a less biased result.
3. Add a sweep over the `fast_accum` setting.
4. Add the ability to save results to a file for easy analysis later.

Test Plan:

```
time python benchmarks/float8/bench_matmul.py --out_filename ~/local/tmp/20240803_f8_gemm_sweep_2.csv --shape_gen_name sweep
```

Result: https://gist.github.com/vkuzo/1d82e84ddd8aac8166695d819ebc8883

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent: 1328787 · commit: de4a1fb
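Since the script exposes `run` via `fire.Fire(run)`, the new arguments (`shape_gen_name`, `out_filename`, `M`, `K`, `N`, plus the existing `n_limit`) double as CLI flags, as in the Test Plan above. Below is a minimal sketch of driving the benchmark programmatically with the new options; the import path and output path are illustrative assumptions, not part of the commit.

```python
# Hypothetical driver; assumes bench_matmul.py is importable from the working
# directory (e.g. when run from benchmarks/float8/).
from bench_matmul import run

# 'custom' requires explicit M, K, N; results can be saved to CSV for later analysis.
run(
    shape_gen_name="custom",
    M=8192,
    K=8192,
    N=8192,
    out_filename="/tmp/f8_gemm_custom.csv",  # illustrative path
)

# 'square' and 'sweep' generate shapes themselves, so M/K/N must stay unset;
# n_limit caps how many (fast_accum, shape) combinations are benchmarked.
run(shape_gen_name="square", n_limit=4)
```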

benchmarks/float8/bench_matmul.py

Lines changed: 87 additions & 31 deletions
```diff
@@ -48,39 +48,91 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
     return time_sec, tops_sec, pct_top_peak
 
 
+def get_name_to_shapes_iter(
+    shape_gen_name: str,
+    M: Optional[int],
+    K: Optional[int],
+    N: Optional[int],
+):
+    if shape_gen_name == 'llama':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        bsz, seq_len = 4, 4096
+        M = bsz * seq_len
+        # LLaMa 2 70B single-node weight shapes
+        # assumes fused attn.wqkv and ffn.w13
+        # source: https://fburl.com/gsheet/g8onr7rh
+        name_to_shapes_70b = {
+            "attn.wqkv": (M, 8192, 1280),
+            "attn.w0": (M, 1024, 8192),
+            "ffn.w13": (M, 8192, 7168),
+            "ffn.w2": (M, 3584, 8192),
+        }
+        return name_to_shapes_70b.items()
+
+    elif shape_gen_name == 'square':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_power_of_2 = 5  # 32
+        max_power_of_2 = 16  # 65,536
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val = 2 ** power_of_2
+            name_to_shapes[idx] = val, val, val
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'sweep':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_p2 = 5  # 32
+        max_p2 = 16  # 65,536
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2 ** M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2 ** K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2 ** N_p2
+                    name_to_shapes[counter] = M, K, N
+                    counter += 1
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'custom':
+        assert M is not None and K is not None and N is not None, \
+            'M, K, N must be specified for custom shape_gen'
+        name_to_shapes = {
+            1: (M, K, N),
+        }
+        return name_to_shapes.items()
+
+    raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')
+
+
 @torch.inference_mode()
-def run(n_limit: Optional[int] = None):
+def run(
+    n_limit: Optional[int] = None,
+    shape_gen_name: str = 'llama',
+    out_filename: Optional[str] = None,
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
+):
     device = "cuda"
 
-    # LLaMa 2 70B single-node weight shapes
-    # assumes fused attn.wqkv and ffn.w13
-    # source: https://fburl.com/gsheet/g8onr7rh
-    name_to_shapes_70b = {
-        "attn.wqkv": (8192, 1280),
-        "attn.w0": (1024, 8192),
-        "ffn.w13": (8192, 7168),
-        "ffn.w2": (3584, 8192),
-    }
-
-    headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
+    headers = ("fast_accum", "name", "M", "K", "N", "ref_time_s", "fp8_time_s", "fp8_speedup")
     results = []
 
-    name_to_shapes = name_to_shapes_70b
-    dtypes = torch.bfloat16, torch.float16
+    dtype = torch.bfloat16
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
+    fast_accum_vals = [True, False]
 
-    for idx, (dtype, (name, (K, N))) in enumerate(
-        itertools.product(dtypes, name_to_shapes.items())
-    ):
+    for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
             break
 
-        # source: Xiao Sun, these are realistic for LLaMa 70B training
-        bsz, seq_len = 4, 4096
-
-        M = bsz * seq_len
-        print("M, K, N:", M, K, N)
         tops = 2 * M * N * K
-        print(f"tops: {tops:.2E}")
+        print("M, K, N:", M, K, N, f"tops: {tops:.2E}")
 
         # raw torch.mm
         A = torch.randn(M, K, device=device, dtype=dtype)
@@ -99,12 +151,12 @@ def run(n_limit: Optional[int] = None):
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
         A = torch.zeros(M, K, device=device, dtype=d1)
         B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
 
         def do_matmul(A, B):
-            scale_a = torch.tensor([1.0], device=device)
-            scale_b = torch.tensor([1.0], device=device)
             return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
             )
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
@@ -114,22 +166,26 @@ def do_matmul(A, B):
             f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )
 
-        del A, B
+        del A, B, scale_a, scale_b
 
         results.append(
             [
+                fast_accum,
                 name,
-                (M, K, N),
-                dtype,
+                M,
+                K,
+                N,
                 ref_time_sec,
                 fp8_time_sec,
                 ref_time_sec / fp8_time_sec,
             ]
        )
 
-    data_pd = pd.DataFrame(results, columns=headers)
-    print(data_pd)
+    data_df = pd.DataFrame(results, columns=headers)
+    print(data_df)
 
+    if out_filename is not None:
+        data_df.to_csv(out_filename)
 
 def main() -> None:
     fire.Fire(run)
```
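When `out_filename` is set, the results table is written with `data_df.to_csv(...)` using the columns defined in `headers`. A short sketch of how such a file might be inspected afterwards; the file path and the specific aggregations are illustrative, not part of the commit.

```python
# Illustrative post-processing of a saved run; assumes a CSV written via
# out_filename with columns: fast_accum, name, M, K, N, ref_time_s, fp8_time_s, fp8_speedup.
import pandas as pd

df = pd.read_csv("/tmp/f8_gemm_sweep.csv")  # hypothetical path

# Median fp8 speedup over the bf16 baseline, split by the fast_accum setting.
print(df.groupby("fast_accum")["fp8_speedup"].median())

# The five shapes with the highest measured speedup.
top5 = df.sort_values("fp8_speedup", ascending=False).head(5)
print(top5[["fast_accum", "M", "K", "N", "fp8_speedup"]])
```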
