
Commit a22d36f ("Fixes to comments")
Parent: 83af1d7

5 files changed, +230 -144 lines

benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_lhs.py

Lines changed: 42 additions & 24 deletions
@@ -33,10 +33,10 @@ class ExperimentConfig:
 @dataclass(frozen=True)
 class ExperimentResult:
     # time
-    naive_us: float
+    torch_us: float
     triton_us: float
     # mem bw
-    naive_gbps: float
+    torch_gbps: float
     triton_gbps: float
 
 
@@ -58,6 +58,10 @@ def get_configs() -> List[ExperimentConfig]:
         (2048, 4096),
         (4096, 4096),
         (8192, 4096),
+        (16384, 4096),
+        (32768, 4096),
+        (65536, 4096),
+        (131_072, 4096),
     ]
 
     configs = []
@@ -79,35 +83,49 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     block_size = config.block_size
 
     def verify_outputs(
-        y_naive: torch.Tensor,
-        s_naive: torch.Tensor,
+        y_torch: torch.Tensor,
+        s_torch: torch.Tensor,
         y_triton: torch.Tensor,
         s_triton: torch.Tensor,
         rtol: float = 1e-2,
         atol: float = 1e-2,
     ):
-        """Verify that Triton and naive implementations produce similar results."""
+        """Verify that Triton and torch implementations produce similar results."""
 
         # Convert FP8 back to float for comparison
-        y_naive_float = y_naive.to(torch.float32)
+        y_torch_float = y_torch.to(torch.float32)
         y_triton_float = y_triton.to(torch.float32)
 
+        assert y_torch.shape == y_triton.shape, (
+            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
+        )
+        assert y_torch.stride() == y_triton.stride(), (
+            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
+        )
+
+        assert s_torch.shape == s_triton.shape, (
+            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
+        )
+        assert s_torch.stride() == s_triton.stride(), (
+            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
+        )
+
         # Check quantized values are close
 
         torch.testing.assert_close(
-            y_naive_float,
+            y_torch_float,
             y_triton_float,
             rtol=rtol,
            atol=atol,
-            msg="Quantized values differ between naive and Triton implementations",
+            msg="Quantized values differ between torch and Triton implementations",
        )
 
        torch.testing.assert_close(
-            s_naive,
+            s_torch,
            s_triton,
            rtol=rtol,
            atol=atol,
-            msg="Scales differ between naive and Triton implementations",
+            msg="Scales differ between torch and Triton implementations",
        )
 
    input_tensor = torch.randn(
@@ -117,11 +135,11 @@ def verify_outputs(
         device=device,
     )
 
-    # Benchmark naive implementation
-    naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
-    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
-    naive_time_us = benchmark_cuda_function_in_microseconds(
-        naive_impl_c,
+    # Benchmark torch implementation
+    torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
+    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        torch_impl_c,
         input_tensor,
         block_size,
     )
@@ -135,7 +153,7 @@ def verify_outputs(
     )
 
     # Verify correctness (optional, can comment out for pure benchmarking)
-    verify_outputs(y_naive, s_naive, y_triton, s_triton)
+    verify_outputs(y_torch, s_torch, y_triton, s_triton)
 
     # Memory bandwidth calculations
     bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
@@ -147,13 +165,13 @@ def verify_outputs(
         y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
     )
 
-    naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
+    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
     triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
 
     return ExperimentResult(
-        naive_us=naive_time_us,
+        torch_us=torch_time_us,
         triton_us=triton_time_us,
-        naive_gbps=naive_gbps,
+        torch_gbps=torch_gbps,
         triton_gbps=triton_gbps,
     )
 
@@ -162,23 +180,23 @@ def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape (M, K)",
         "block_size",
-        "naive_us",
+        "torch_us",
         "triton_us",
         "speedup",
-        "naive_gbps",
+        "torch_gbps",
         "triton_gbps",
     ]
     rows = []
     for experiment in experiments:
-        speedup = experiment.result.naive_us / experiment.result.triton_us
+        speedup = experiment.result.torch_us / experiment.result.triton_us
         rows.append(
             [
                 f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                 experiment.config.block_size,
-                f"{experiment.result.naive_us:.2f}",
+                f"{experiment.result.torch_us:.2f}",
                 f"{experiment.result.triton_us:.2f}",
                 f"{speedup:.2f}x",
-                f"{experiment.result.naive_gbps:.1f}",
+                f"{experiment.result.torch_gbps:.1f}",
                 f"{experiment.result.triton_gbps:.1f}",
             ]
         )

benchmarks/prototype/blockwise_fp8_training/bench_triton_fp8_blockwise_act_quant_rhs.py

Lines changed: 45 additions & 27 deletions
@@ -28,10 +28,10 @@ class ExperimentConfig:
 @dataclass(frozen=True)
 class ExperimentResult:
     # time
-    naive_us: float
+    torch_us: float
     triton_us: float
     # mem bw
-    naive_gbps: float
+    torch_gbps: float
     triton_gbps: float
 
 
@@ -52,6 +52,10 @@ def get_configs() -> List[ExperimentConfig]:
         (2048, 4096),
         (4096, 4096),
         (8192, 4096),
+        (16384, 4096),
+        (32768, 4096),
+        (65536, 4096),
+        (131_072, 4096),
     ]
 
     configs = []
@@ -72,11 +76,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, K = config.input_shape
     block_size = config.block_size
 
-    def naive_fp8_blockwise_quant(
+    def torch_fp8_blockwise_quant(
         x: torch.Tensor, block_size: int = 128
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Naive PyTorch reference implementation for RHS blockwise FP8 quantization.
+        Torch reference implementation for RHS blockwise FP8 quantization.
 
         RHS semantics:
           • Groups are (block_size x 1) along the M dimension (rows).
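The body of torch_fp8_blockwise_quant is outside this hunk, so only its docstring is visible here. A minimal sketch of the (block_size x 1) row-grouped quantization that docstring describes, assuming float8_e4m3fn outputs and a column-major quantized tensor (both assumptions; the actual reference implementation may differ in details):

import torch

def rhs_blockwise_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # Hypothetical sketch only: quantize (block_size x 1) row groups to float8_e4m3fn
    # with one scale per group, returning the quantized tensor with column-major strides.
    M, K = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_blocks = x.reshape(M // block_size, block_size, K)             # group rows into blocks
    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                                           # one scale per (block, column)
    y = (x_blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    y = y.reshape(M, K).t().contiguous().t()                         # force column-major layout
    return y, scale.squeeze(1)                                       # scales: (M // block_size, K)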
@@ -129,33 +133,47 @@ def naive_fp8_blockwise_quant(
         return y, s
 
     def verify_outputs(
-        y_naive: torch.Tensor,
-        s_naive: torch.Tensor,
+        y_torch: torch.Tensor,
+        s_torch: torch.Tensor,
         y_triton: torch.Tensor,
         s_triton: torch.Tensor,
         rtol: float = 1e-2,
         atol: float = 1e-2,
     ):
-        """Verify that Triton and naive implementations produce similar results."""
+        """Verify that Triton and torch implementations produce similar results."""
 
         # Quantized tensors (both are column-major; convert to float to compare)
-        y_naive_float = y_naive.to(torch.float32)
+        y_torch_float = y_torch.to(torch.float32)
         y_triton_float = y_triton.to(torch.float32)
 
+        assert y_torch.shape == y_triton.shape, (
+            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
+        )
+        assert y_torch.stride() == y_triton.stride(), (
+            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
+        )
+
+        assert s_torch.shape == s_triton.shape, (
+            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
+        )
+        assert s_torch.stride() == s_triton.stride(), (
+            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
+        )
+
         torch.testing.assert_close(
-            y_naive_float,
+            y_torch_float,
             y_triton_float,
             rtol=rtol,
             atol=atol,
-            msg="Quantized values differ between naive and Triton implementations",
+            msg="Quantized values differ between torch and Triton implementations",
         )
 
         torch.testing.assert_close(
-            s_naive,
+            s_torch,
             s_triton,
             rtol=rtol,
             atol=atol,
-            msg="Scales differ between naive and Triton implementations",
+            msg="Scales differ between torch and Triton implementations",
         )
 
     input_tensor = torch.randn(
@@ -166,12 +184,12 @@ def verify_outputs(
     )
 
     # Compile once
-    naive_impl_c = torch.compile(naive_fp8_blockwise_quant)
+    torch_impl_c = torch.compile(torch_fp8_blockwise_quant)
 
-    # Benchmark naive implementation
-    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
-    naive_time_us = benchmark_cuda_function_in_microseconds(
-        naive_impl_c,
+    # Benchmark torch implementation
+    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        torch_impl_c,
         input_tensor,
         block_size,
     )
@@ -184,8 +202,8 @@ def verify_outputs(
         block_size,
     )
 
-    # Verify correctness (compare to naive)
-    verify_outputs(y_naive, s_naive, y_triton, s_triton)
+    # Verify correctness (compare to torch)
+    verify_outputs(y_torch, s_torch, y_triton, s_triton)
 
     # Memory bandwidth calculations
     bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
@@ -197,13 +215,13 @@ def verify_outputs(
         y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
     )
 
-    naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
+    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
     triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
 
     return ExperimentResult(
-        naive_us=naive_time_us,
+        torch_us=torch_time_us,
         triton_us=triton_time_us,
-        naive_gbps=naive_gbps,
+        torch_gbps=torch_gbps,
         triton_gbps=triton_gbps,
     )
 
@@ -212,23 +230,23 @@ def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape (M, K)",
         "block_size",
-        "naive_us",
+        "torch_us",
         "triton_us",
         "speedup",
-        "naive_gbps",
+        "torch_gbps",
         "triton_gbps",
     ]
     rows = []
     for experiment in experiments:
-        speedup = experiment.result.naive_us / experiment.result.triton_us
+        speedup = experiment.result.torch_us / experiment.result.triton_us
         rows.append(
             [
                 f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                 experiment.config.block_size,
-                f"{experiment.result.naive_us:.2f}",
+                f"{experiment.result.torch_us:.2f}",
                 f"{experiment.result.triton_us:.2f}",
                 f"{speedup:.2f}x",
-                f"{experiment.result.naive_gbps:.1f}",
+                f"{experiment.result.torch_gbps:.1f}",
                 f"{experiment.result.triton_gbps:.1f}",
             ]
         )
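The print_results hunks only rebuild headers and rows; the call that renders them is outside the diff. If the benchmark prints the table with tabulate (an assumption about which helper is used), the usage would look like:

from tabulate import tabulate  # assumption: tabulate renders the results table

print(tabulate(rows, headers=headers))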
