@@ -1,4 +1,5 @@
 import hashlib
+import itertools
 import json
 import logging
 from typing import Any
@@ -146,60 +147,68 @@ def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "
         )


-def cross_generate_configs(
-    attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
-    compiled_mode: list[str | None],
-    kernelized: list[bool],
-    warmup_iterations: int = 5,
-    measurement_iterations: int = 20,
-    batch_size: int = 1,
-    sequence_length: int = 128,
-    num_tokens_to_generate: int = 128,
-    gpu_monitoring: bool = True,
+def adapt_configs(
+    configs: list[BenchmarkConfig],
+    warmup_iterations: int | list[int] = 5,
+    measurement_iterations: int | list[int] = 20,
+    batch_size: int | list[int] = 1,
+    sequence_length: int | list[int] = 128,
+    num_tokens_to_generate: int | list[int] = 128,
+    gpu_monitoring: bool | list[bool] = True,
 ) -> list[BenchmarkConfig]:
-    # Create kwargs common to all configs
-    kwargs = {
-        "warmup_iterations": warmup_iterations,
-        "measurement_iterations": measurement_iterations,
-        "batch_size": batch_size,
-        "sequence_length": sequence_length,
-        "num_tokens_to_generate": num_tokens_to_generate,
-        "gpu_monitoring": gpu_monitoring,
-    }
-    # Cross-generate all combinations of attn_implementation, compiled_mode, and kernelized
+    parameters = (
+        x if isinstance(x, list) else [x]
+        for x in [
+            warmup_iterations,
+            measurement_iterations,
+            batch_size,
+            sequence_length,
+            num_tokens_to_generate,
+            gpu_monitoring,
+        ]
+    )
+    iterator = itertools.product(*parameters)
+
+    adapted_configs = []
+    for warmup_iters, measurement_iters, bs, seqlen, ntok, monitor in iterator:
+        for config in configs:
+            config = config.to_dict()
+            config["warmup_iterations"] = warmup_iters
+            config["measurement_iterations"] = measurement_iters
+            config["batch_size"] = bs
+            config["sequence_length"] = seqlen
+            config["num_tokens_to_generate"] = ntok
+            config["gpu_monitoring"] = monitor
+            adapted_configs.append(BenchmarkConfig.from_dict(config))
+    return adapted_configs
+
+
+def get_config_by_level(level: int) -> list[BenchmarkConfig]:
     configs = []
-    for attn_implementation, sdpa_backend in list(dict.fromkeys(attn_impl_and_sdpa_backend)):
-        for cm in list(dict.fromkeys(compiled_mode)):
-            for kernelize_on in list(dict.fromkeys(kernelized)):
-                config = BenchmarkConfig(
-                    attn_implementation=attn_implementation,
-                    sdpa_backend=sdpa_backend,
-                    compile_mode=cm,
-                    kernelize=kernelize_on,
-                    **kwargs,
-                )
-                configs.append(config)
+    # Early return for levels 3 and above: we generate all combinations of configs, possibly with all compile modes
+    if level >= 3:
+        for attn_implementation, sdpa_backend in BenchmarkConfig.all_attn_implementations:
+            # Usually there is not much to gain by compiling with other modes, but we allow it for level 4
+            compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
+            for cm in compile_modes:
+                for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
+                    configs.append(
+                        BenchmarkConfig(
+                            attn_implementation=attn_implementation,
+                            sdpa_backend=sdpa_backend,
+                            compile_mode=cm,
+                            kernelize=kernelize_on,
+                        )
+                    )
+        return configs
+    # Otherwise, we add the configs for the given level
+    if level >= 0:
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
+    if level >= 1:
+        configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
+        configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
+    if level >= 2:
+        configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
+        configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
+        configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
     return configs
-
-
-def generate_main_configs(
-    warmup_iterations: int = 5,
-    measurement_iterations: int = 20,
-    batch_size: int = 1,
-    sequence_length: int = 128,
-    num_tokens_to_generate: int = 128,
-) -> list[BenchmarkConfig]:
-    # Create kwargs common to all configs
-    kwargs = {
-        "warmup_iterations": warmup_iterations,
-        "measurement_iterations": measurement_iterations,
-        "batch_size": batch_size,
-        "sequence_length": sequence_length,
-        "num_tokens_to_generate": num_tokens_to_generate,
-    }
-    return [  # TODO: test max-autotune instead of default
-        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=False, **kwargs),
-        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=True, **kwargs),
-        BenchmarkConfig(attn_implementation="eager", compile_mode="default", gpu_monitoring=True, **kwargs),
-        BenchmarkConfig(attn_implementation="flash_attention_2", gpu_monitoring=True, **kwargs),
-    ]
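
Taken together, the two new helpers replace cross_generate_configs and generate_main_configs: get_config_by_level picks which attention/compile/kernelize combinations to benchmark, and adapt_configs fans each of those configs out over the run parameters. A minimal usage sketch, assuming both helpers are importable from the module this diff touches (the import path below is illustrative, not part of the diff):

    # Illustrative import path; adjust to wherever this module lives in the repo.
    from benchmark_config import adapt_configs, get_config_by_level

    # Level 2 yields the curated set: flex_attention, flash_attention_2, eager
    # and sdpa variants, plus the kernelized flavors added at this level.
    configs = get_config_by_level(2)

    # Scalars stay fixed; list-valued arguments are crossed. Two batch sizes
    # times two sequence lengths expand every base config into four runs.
    adapted = adapt_configs(
        configs,
        warmup_iterations=5,
        measurement_iterations=20,
        batch_size=[1, 8],
        sequence_length=[128, 4096],
        num_tokens_to_generate=128,
        gpu_monitoring=False,
    )
    print(f"{len(configs)} base configs -> {len(adapted)} benchmark runs")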
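The scalar-or-list handling in adapt_configs comes down to wrapping every argument in a singleton list before taking itertools.product, so a scalar contributes exactly one value per axis. The same pattern in isolation, as a standalone runnable sketch (the cross helper is hypothetical, for demonstration only):

    import itertools

    def cross(**kwargs):
        # Wrap scalars so itertools.product treats every argument as an axis.
        axes = (v if isinstance(v, list) else [v] for v in kwargs.values())
        return [dict(zip(kwargs, combo)) for combo in itertools.product(*axes)]

    # 1 x 2 x 2 = 4 combinations; the scalar does not multiply the count.
    for combo in cross(warmup_iterations=5, batch_size=[1, 8], sequence_length=[128, 4096]):
        print(combo)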