@@ -28,10 +28,10 @@ class ExperimentConfig:
 @dataclass(frozen=True)
 class ExperimentResult:
     # time
-    naive_us: float
+    torch_us: float
     triton_us: float
     # mem bw
-    naive_gbps: float
+    torch_gbps: float
     triton_gbps: float
 
 
@@ -52,6 +52,10 @@ def get_configs() -> List[ExperimentConfig]:
         (2048, 4096),
         (4096, 4096),
         (8192, 4096),
+        (16384, 4096),
+        (32768, 4096),
+        (65536, 4096),
+        (131_072, 4096),
     ]
 
     configs = []
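The four added shapes scale M from 8192 up to 131,072, so the sweep now covers much larger, bandwidth-bound inputs. As a rough sizing check (assuming a bfloat16 input at 2 bytes per element; the benchmark's input dtype is not visible in this hunk), the largest new shape alone reads about 1 GiB:

```python
# Hypothetical sizing check for the largest added shape (assumes a bf16 input).
M, K = 131_072, 4096
input_bytes = M * K * 2                    # 1,073,741,824 bytes
print(f"{input_bytes / 2**30:.1f} GiB")    # 1.0 GiB
```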
@@ -72,11 +76,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, K = config.input_shape
     block_size = config.block_size
 
-    def naive_fp8_blockwise_quant(
+    def torch_fp8_blockwise_quant(
         x: torch.Tensor, block_size: int = 128
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Naive PyTorch reference implementation for RHS blockwise FP8 quantization.
+        Torch reference implementation for RHS blockwise FP8 quantization.
 
         RHS semantics:
           • Groups are (block_size x 1) along the M dimension (rows).
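The body of the reference implementation is unchanged by this rename and is elided from the hunk above. For orientation only, a minimal sketch of an RHS blockwise quantizer with the documented semantics could look like the following; the dtype, clamping, and scale handling are assumptions rather than the file's actual code, and the real implementation also emits a column-major quantized tensor, which this sketch skips:

```python
import torch

def rhs_blockwise_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # x: (M, K). Each (block_size x 1) group covers block_size rows of one column,
    # so there is one scale per group and the scale tensor is (M // block_size, K).
    M, K = x.shape
    fp8 = torch.float8_e4m3fn                       # assumed target dtype
    fp8_max = torch.finfo(fp8).max
    groups = x.view(M // block_size, block_size, K)
    amax = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                          # one scale per group
    y = (groups / scale).clamp(-fp8_max, fp8_max).to(fp8)
    return y.view(M, K), scale.squeeze(1)           # scales: (M // block_size, K)
```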
@@ -129,33 +133,47 @@ def naive_fp8_blockwise_quant(
         return y, s
 
     def verify_outputs(
-        y_naive: torch.Tensor,
-        s_naive: torch.Tensor,
+        y_torch: torch.Tensor,
+        s_torch: torch.Tensor,
         y_triton: torch.Tensor,
         s_triton: torch.Tensor,
         rtol: float = 1e-2,
         atol: float = 1e-2,
     ):
-        """Verify that Triton and naive implementations produce similar results."""
+        """Verify that Triton and torch implementations produce similar results."""
 
         # Quantized tensors (both are column-major; convert to float to compare)
-        y_naive_float = y_naive.to(torch.float32)
+        y_torch_float = y_torch.to(torch.float32)
         y_triton_float = y_triton.to(torch.float32)
 
+        assert y_torch.shape == y_triton.shape, (
+            f"Output shape mismatch: torch {y_torch.shape} vs triton {y_triton.shape}"
+        )
+        assert y_torch.stride() == y_triton.stride(), (
+            f"Output stride mismatch: torch {y_torch.stride()} vs triton {y_triton.stride()}"
+        )
+
+        assert s_torch.shape == s_triton.shape, (
+            f"Scale shape mismatch: torch {s_torch.shape} vs triton {s_triton.shape}"
+        )
+        assert s_torch.stride() == s_triton.stride(), (
+            f"Scale stride mismatch: torch {s_torch.stride()} vs triton {s_triton.stride()}"
+        )
+
         torch.testing.assert_close(
-            y_naive_float,
+            y_torch_float,
             y_triton_float,
             rtol=rtol,
             atol=atol,
-            msg="Quantized values differ between naive and Triton implementations",
+            msg="Quantized values differ between torch and Triton implementations",
         )
 
         torch.testing.assert_close(
-            s_naive,
+            s_torch,
             s_triton,
             rtol=rtol,
             atol=atol,
-            msg="Scales differ between naive and Triton implementations",
+            msg="Scales differ between torch and Triton implementations",
         )
 
     input_tensor = torch.randn(
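The shape and stride assertions added to verify_outputs above pin down the memory-layout contract in addition to the values: two tensors can match element-wise yet disagree on strides, which torch.testing.assert_close does not check. A standalone illustration (arbitrary dtype and hypothetical sizes):

```python
import torch

M, K = 4, 8
row_major = torch.empty(M, K)
col_major = row_major.t().contiguous().t()   # same shape, column-major memory layout
print(row_major.stride())                    # (8, 1)
print(col_major.stride())                    # (1, 4)
# A kernel that produced the right values in the wrong layout would pass the
# value comparison but trip the new stride assertion.
```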
@@ -166,12 +184,12 @@ def verify_outputs(
     )
 
     # Compile once
-    naive_impl_c = torch.compile(naive_fp8_blockwise_quant)
+    torch_impl_c = torch.compile(torch_fp8_blockwise_quant)
 
-    # Benchmark naive implementation
-    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
-    naive_time_us = benchmark_cuda_function_in_microseconds(
-        naive_impl_c,
+    # Benchmark torch implementation
+    y_torch, s_torch = torch_impl_c(input_tensor, block_size)
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        torch_impl_c,
         input_tensor,
         block_size,
     )
@@ -184,8 +202,8 @@ def verify_outputs(
         block_size,
     )
 
-    # Verify correctness (compare to naive)
-    verify_outputs(y_naive, s_naive, y_triton, s_triton)
+    # Verify correctness (compare to torch)
+    verify_outputs(y_torch, s_torch, y_triton, s_triton)
 
     # Memory bandwidth calculations
     bytes_per_input_el = torch.finfo(input_tensor.dtype).bits / 8
@@ -197,13 +215,13 @@ def verify_outputs(
         y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
     )
 
-    naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
+    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
     triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
 
     return ExperimentResult(
-        naive_us=naive_time_us,
+        torch_us=torch_time_us,
         triton_us=triton_time_us,
-        naive_gbps=naive_gbps,
+        torch_gbps=torch_gbps,
         triton_gbps=triton_gbps,
     )
 
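The bandwidth figures count one read of the input plus one write of the quantized tensor and its scales, divided by the measured time. A worked instance of the formula, with every number here an assumed value for illustration (bf16 input, fp8 output, fp32 scales, a hypothetical 100 µs kernel time):

```python
# Worked example of the GB/s formula above; all numbers are assumed for illustration.
M, K, block_size = 8192, 4096, 128
time_us = 100.0                                        # hypothetical kernel time
read_bytes = M * K * 2                                 # bf16 input: 2 bytes/element
write_bytes = M * K * 1 + (M // block_size) * K * 4    # fp8 output + fp32 scales
gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{gbps:.0f} GB/s")                              # ~1017 GB/s
```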
@@ -212,23 +230,23 @@ def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape (M, K)",
         "block_size",
-        "naive_us",
+        "torch_us",
         "triton_us",
         "speedup",
-        "naive_gbps",
+        "torch_gbps",
         "triton_gbps",
     ]
     rows = []
     for experiment in experiments:
-        speedup = experiment.result.naive_us / experiment.result.triton_us
+        speedup = experiment.result.torch_us / experiment.result.triton_us
         rows.append(
             [
                 f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                 experiment.config.block_size,
-                f"{experiment.result.naive_us:.2f}",
+                f"{experiment.result.torch_us:.2f}",
                 f"{experiment.result.triton_us:.2f}",
                 f"{speedup:.2f}x",
-                f"{experiment.result.naive_gbps:.1f}",
+                f"{experiment.result.torch_gbps:.1f}",
                 f"{experiment.result.triton_gbps:.1f}",
             ]
         )