This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ec8b46c

vkuzo authored and facebook-github-bot committed
fixes to matmul and linear benchmarks (#320)
Summary:
Pull Request resolved: #320

For matmul benchmarks, unbreaks them - we need the scales to be fp32, not integers.
For linear benchmarks, aligns default settings to the current best supported path (compile on, dynamic scaling).

Reviewed By: awgu

Differential Revision: D59877198

fbshipit-source-id: 092daaffeb0096f9fbd12ca407701bc3aa80c97c
1 parent: e6bb1eb

2 files changed: +8, -8 lines


benchmarks/bench_linear_float8.py

Lines changed: 6 additions & 6 deletions
@@ -91,13 +91,13 @@ def float8_pct_top_peak(self):
 
 def main(
     sweep_path: Optional[Path] = None,
-    compile: bool = False,
+    compile: bool = True,
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    scaling_type_x: str = "delayed",
-    scaling_type_w: str = "delayed",
-    scaling_type_dL_dY: str = "delayed",
+    scaling_type_x: str = "dynamic",
+    scaling_type_w: str = "dynamic",
+    scaling_type_dL_dY: str = "dynamic",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
@@ -274,7 +274,7 @@ def wrapper(*args, **kwargs):
 def invoke_main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output_path", type=str, required=False)
-    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--disable_compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
@@ -292,7 +292,7 @@ def invoke_main() -> None:
     kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
     main(
         output_path,
-        args.compile,
+        not args.disable_compile,
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
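
The flag inversion makes compilation opt-out: running the benchmark with no flags now uses torch.compile, and passing --disable_compile falls back to eager mode. A minimal, self-contained sketch of the resulting flag logic (the argument name comes from the diff above; the surrounding script is illustrative only):

import argparse

# Compile is enabled unless the user opts out, mirroring the new
# default of compile: bool = True in main().
parser = argparse.ArgumentParser()
parser.add_argument("--disable_compile", action="store_true")

args = parser.parse_args([])                     # no flags passed
assert (not args.disable_compile) is True        # compile on by default

args = parser.parse_args(["--disable_compile"])  # explicit opt-out
assert (not args.disable_compile) is False       # compile turned off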

benchmarks/bench_matmul.py

Lines changed: 2 additions & 2 deletions
@@ -101,8 +101,8 @@ def run(n_limit: Optional[int] = None):
     B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
 
     def do_matmul(A, B):
-        scale_a = torch.tensor([1], device=device)
-        scale_b = torch.tensor([1], device=device)
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
         return torch._scaled_mm(
             A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
         )
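
The matmul fix is needed because torch._scaled_mm requires floating-point scale tensors: torch.tensor([1]) creates an int64 tensor, which the op rejects. A minimal sketch of the corrected call pattern, assuming a CUDA device with fp8 support (e.g. H100) and a recent PyTorch; the shapes, dtypes, and names here are illustrative rather than the benchmark's sweep values:

import torch

device = "cuda"
M, K, N = 128, 256, 64

# fp8 gemm expects a row-major A and a column-major B.
A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
B = torch.randn(K, N, device=device).t().contiguous().t().to(torch.float8_e4m3fn)

# Scales must be fp32 tensors; an integer tensor like torch.tensor([1])
# is rejected, which is what broke the benchmark before this commit.
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)

# On recent PyTorch builds _scaled_mm returns the output tensor directly
# (older builds returned an (out, amax) tuple).
out = torch._scaled_mm(
    A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False
)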
