Skip to content

QOL improvements to float8 gemm benchmark #596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 87 additions & 31 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,39 +48,91 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
return time_sec, tops_sec, pct_top_peak


def get_name_to_shapes_iter(
shape_gen_name: str,
M: Optional[int],
K: Optional[int],
N: Optional[int],
):
if shape_gen_name == 'llama':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
bsz, seq_len = 4, 4096
M = bsz * seq_len
# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (M, 8192, 1280),
"attn.w0": (M, 1024, 8192),
"ffn.w13": (M, 8192, 7168),
"ffn.w2": (M, 3584, 8192),
}
return name_to_shapes_70b.items()

elif shape_gen_name == 'square':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_power_of_2 = 5 # 32
max_power_of_2 = 16 # 65,536
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val = 2 ** power_of_2
name_to_shapes[idx] = val, val, val
return name_to_shapes.items()

elif shape_gen_name == 'sweep':
assert M == K == N == None, \
f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
name_to_shapes = {}
min_p2 = 5 # 32
max_p2 = 16 # 65,536
counter = 0
for M_p2 in range(min_p2, max_p2 + 1):
M = 2 ** M_p2
for K_p2 in range(min_p2, max_p2 + 1):
K = 2 ** K_p2
for N_p2 in range(min_p2, max_p2 + 1):
N = 2 ** N_p2
name_to_shapes[counter] = M, K, N
counter += 1
return name_to_shapes.items()

elif shape_gen_name == 'custom':
assert M is not None and K is not None and N is not None, \
'M, K, N must be specified for custom shape_gen'
name_to_shapes = {
1: (M, K, N),
}
return name_to_shapes.items()

raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')


@torch.inference_mode()
def run(n_limit: Optional[int] = None):
def run(
n_limit: Optional[int] = None,
shape_gen_name: str = 'llama',
out_filename: Optional[str] = None,
M: Optional[int] = None,
K: Optional[int] = None,
N: Optional[int] = None,
):
device = "cuda"

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}

headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
headers = ("fast_accum", "name", "M", "K", "N", "ref_time_s", "fp8_time_s", "fp8_speedup")
results = []

name_to_shapes = name_to_shapes_70b
dtypes = torch.bfloat16, torch.float16
dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]

for idx, (dtype, (name, (K, N))) in enumerate(
itertools.product(dtypes, name_to_shapes.items())
):
for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
if n_limit is not None and idx >= n_limit:
break

# source: Xiao Sun, these are realistic for LLaMa 70B training
bsz, seq_len = 4, 4096

M = bsz * seq_len
print("M, K, N:", M, K, N)
tops = 2 * M * N * K
print(f"tops: {tops:.2E}")
print("M, K, N:", M, K, N, f"tops: {tops:.2E}")

# raw torch.mm
A = torch.randn(M, K, device=device, dtype=dtype)
Expand All @@ -99,12 +151,12 @@ def run(n_limit: Optional[int] = None):
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)

def do_matmul(A, B):
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
Expand All @@ -114,22 +166,26 @@ def do_matmul(A, B):
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
)

del A, B
del A, B, scale_a, scale_b

results.append(
[
fast_accum,
name,
(M, K, N),
dtype,
M,
K,
N,
ref_time_sec,
fp8_time_sec,
ref_time_sec / fp8_time_sec,
]
)

data_pd = pd.DataFrame(results, columns=headers)
print(data_pd)
data_df = pd.DataFrame(results, columns=headers)
print(data_df)

if out_filename is not None:
data_df.to_csv(out_filename)

def main() -> None:
fire.Fire(run)
Expand Down
Loading