benchmark fix (#229)
* benchmark fix

* add seven new testing parameters

* move shapes info to yaml file

* Added the BenchmarkMetrics & BenchmarkResult abstraction
kiddyjinjin authored Oct 30, 2024
1 parent 4bcb3ea commit 4e6cb3b
Showing 18 changed files with 2,093 additions and 1,806 deletions.
28 changes: 28 additions & 0 deletions CONTRIBUTING.md
@@ -55,6 +55,34 @@ tools/code_coverage/coverage.sh PR_ID

Currently, the pipeline does not check the performance of operators. You can write performance tests in the `benchmark` directory to evaluate your optimization results.

### 2.5 Operator Performance Benchmarking

`Op Benchmark` is used to evaluate the performance of operators. If you are adding a new operator, you need to add corresponding test cases in the appropriate file under the `benchmark` directory. It is recommended to follow the steps below to add test cases for the new operator:

1. **Select the appropriate test file**
Based on the type of operator, choose the corresponding file in the `benchmark` directory:
- For reduction operators, add the test case to `test_reduction_perf.py`.
- For tensor constructor operators, add the test case to `test_tensor_constructor_perf.py`.
- If the operator doesn't fit into an existing category, you can add it to `test_special_perf.py` or create a new file for the new operator category.

2. **Check existing benchmark classes**
   Once you've identified the correct file, review the existing classes that inherit from `Benchmark` to see whether any of them fit your operator's test scenario, specifically considering:
- Whether the **metric collection** is suitable.
- Whether the **input generation function** (`input_generator` or `input_fn`) is appropriate.

3. **Add test cases**
Depending on the test scenario, follow one of the approaches below to add the test case:

3.1 **Using an existing metric and input generator**
   If the existing metric collection and input generation function meet your operator's requirements, you can add a `pytest.mark.parametrize` entry directly, following the code organization in the file. For example, see the operators in `test_binary_pointwise_perf.py`, and the sketch after this list.

3.2 **Custom input generator**
   If the metric collection is suitable but the input generation function does not meet the operator's requirements, you can implement a custom `input_generator`. Refer to the `topk_input_fn` function in `test_special_perf.py`, a custom input function written for the `topk` operator; the sketch after this list also illustrates this pattern.

3.3 **Custom metric and input generator**
If neither the existing metric collection nor the input generation function meets the operator's needs, you can create a new class. This class should define operator-specific metric collection logic and a custom input generator. You can refer to various `Benchmark` subclasses across the `benchmark` directory for examples.
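
For cases 3.1 and 3.2, the sketch below shows roughly what such additions look like. It is only an illustration: the operator name `dummy_op`, the parametrize argument names, and the `dummy_op_input_fn` helper are hypothetical, so check the target `test_*_perf.py` file for the exact benchmark class and signature it uses.

```python
import pytest
import torch


# Case 3.1 (hypothetical): reuse an existing benchmark class by adding a
# single parametrize entry; the argument names must match the target file.
@pytest.mark.parametrize(
    "op_name, torch_op, dtypes",
    [("dummy_op", torch.add, [torch.float16, torch.float32])],
)
def test_dummy_op_perf(op_name, torch_op, dtypes):
    ...  # delegate to the shared benchmark runner used in that file


# Case 3.2 (hypothetical): a custom input generator in the style of
# `topk_input_fn`, yielding the tensors and keyword arguments for the op.
def dummy_op_input_fn(shape, dtype, device):
    inp = torch.randn(shape, dtype=dtype, device=device)
    yield inp, {"k": 5, "dim": -1}
```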


## 3. Project Structure

```
28 changes: 28 additions & 0 deletions CONTRIBUTING_cn.md
@@ -54,6 +54,34 @@ tools/code_coverage/coverage.sh PR_ID

Currently, the pipeline does not check operator performance; you can write performance tests in the `benchmark` directory to see the effect of your optimizations.


### 2.5 Operator Performance Benchmarking

`Op Benchmark` is used to evaluate operator performance. If you add a new operator, you need to add corresponding test cases to the appropriate file under the `benchmark` directory. It is recommended to follow the steps below:

1. **Choose the appropriate test file**
   Based on the operator's category, choose the corresponding file in the `benchmark` directory:
   - Reduction operators can be added to `test_reduction_perf.py`.
   - Tensor constructor operators can be added to `test_tensor_constructor_perf.py`.
   - If the operator is hard to categorize, place it in `test_special_perf.py`, or create a new file for the new operator category.

2. **Check the existing test classes**
   After identifying the target file, review the existing classes in it that inherit from `Benchmark` and check whether any of them fit your operator's test scenario, mainly considering two points:
   - whether the **metric collection** is suitable;
   - whether the **input constructor function** (`input_generator` or `input_fn`) is suitable.

3. **Add the test case**
   Depending on the test scenario, add the test case in one of the following ways:

   3.1 **Use an existing metric and input constructor**
   If the existing metric collection and input constructor meet the operator's requirements, you can directly add a `pytest.mark.parametrize` line, following the code organization of the file. For example, refer to the operator cases in `test_binary_pointwise_perf.py`.

   3.2 **Custom input constructor**
   If the existing metric collection is suitable but the input constructor does not meet the operator's needs, you can implement a custom `input_generator`. See the `topk_input_fn` function in `test_special_perf.py`, which is an input constructor written for the `topk` operator.

   3.3 **Custom metric and input constructor**
   If neither the existing metric collection nor the input constructor meets the requirements, you can create a new class with operator-specific metric collection logic and input constructor. For such cases, refer to the various `Benchmark` subclasses under the `benchmark` directory.

## 3. 项目结构

```
231 changes: 231 additions & 0 deletions benchmark/attri_util.py
@@ -0,0 +1,231 @@
import itertools
from dataclasses import asdict, dataclass, fields
from enum import Enum
from typing import List, Optional, Tuple

import torch

FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16]
INT_DTYPES = [torch.int16, torch.int32]
BOOL_DTYPES = [
    torch.bool,
]

DEFAULT_WARMUP_COUNT = 1000
DEFAULT_ITER_COUNT = 100

# LEGACY_SHAPES are maintained for legacy benchmark SIZE settings and may be removed in the future.
# Do not reference this elsewhere.
LEGACY_SHAPES = [i * 64 for i in range(1, 22, 5)]
LEGACY_NON_BLAS_SHAPES = [(1024, shape) for shape in LEGACY_SHAPES]
LEGACY_BLAS_SHAPES = [(16, shape, shape, shape) for shape in LEGACY_SHAPES]

# Default shapes settings
DEFAULT_SHAPES = [
    (1024 * 1024 * 1024,),  # from perf
    (64, 64),
    (4096, 4096),
    (64, 512, 512),
    (1024, 1024, 1024),  # from perf
]


# This function is adapted from: https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/utils/triton_op.py
def llama_shapes():
    # batch sizes * seq lengths
    BS = [2**i for i in range(0, 17)]
    # attn: wqkv, wo; ffn: w13, w2
    KN = [
        (4096, 12288),
        (4096, 4096),
        (4096, 22016),
        (11008, 4096),
        (8192, 1280),
        (1024, 8192),
        (8192, 7168),
        (3584, 8192),
        (16384, 2304),
        (2048, 16384),
        (16384, 13312),
        (6656, 16384),
    ]
    return [(bs, n, k, None) for bs, (k, n) in itertools.product(BS, KN)]


@dataclass
class BenchmarkMetrics:
    # Legacy shape information for backward compatibility
    # This field corresponds to the 'size' field in the previous version's benchmark.
    legacy_shape: Optional[int] = None
    # Detailed size info
    shape_detail: Optional[Tuple[int, ...]] = None
    # Latency base in ms
    latency_base: Optional[float] = None
    # Latency in ms
    latency: Optional[float] = None
    # Speedup over baseline
    speedup: Optional[float] = None
    # Accuracy over baseline (not implemented yet)
    accuracy: Optional[float] = None
    # TFLOPS (not implemented yet)
    tflops: Optional[float] = None
    # Utilization (not implemented yet)
    utilization: Optional[float] = None


ALL_AVAILABLE_METRICS = set(map(lambda x: x.name, fields(BenchmarkMetrics))) - {
    "legacy_shape",
    "shape_detail",
}

DEFAULT_METRICS = [
    metric
    for metric in ["latency_base", "latency", "speedup"]
    if metric in ALL_AVAILABLE_METRICS
]


def check_metric_dependencies(
    requested_metrics: Optional[List[str]],
) -> List[str]:
    """
    Checks whether the requested metrics satisfy their dependencies.
    Returns the list of metrics whose dependencies are not satisfied;
    the list is empty when all dependencies are met.
    """
    # Predefined dependencies between metrics
    buildin_dependencies = {
        "speedup": ["latency", "latency_base"],
        "utilization": ["latency", "tflops"],
    }
    unsatisfied_metrics = []
    if requested_metrics is None:
        return unsatisfied_metrics

    satisfied_metrics = set()
    for metric in requested_metrics:
        if metric not in buildin_dependencies:
            # If the metric has no dependencies, it's automatically satisfied
            satisfied_metrics.add(metric)
        else:
            required_metrics = buildin_dependencies[metric]
            # Check if all dependencies are in the satisfied metrics list
            if not all(req in satisfied_metrics for req in required_metrics):
                unsatisfied_metrics.append(metric)
            else:
                satisfied_metrics.add(metric)
    return unsatisfied_metrics


def get_recommended_shapes(
    op_name: str, op_specified_shapes: Optional[List[Tuple[int, ...]]]
):
    def _shapes_sort(shapes):
        shapes = [shape if isinstance(shape, tuple) else (shape,) for shape in shapes]
        return sorted(shapes, key=lambda x: torch.tensor(x).prod().item())

    if op_specified_shapes:
        # TODO: handle the case where a list is the basic element of a shape.
        return _shapes_sort(op_specified_shapes)
    return _shapes_sort(DEFAULT_SHAPES)


class BenchLevel(Enum):
    COMPREHENSIVE = "comprehensive"
    CORE = "core"


@dataclass
class OperationAttribute:
    op_name: str
    # Recommended core benchmark shapes for the given operation
    recommended_core_shapes: List[Tuple[int, ...]]
    shape_desc: str

    def __str__(self) -> str:
        return (
            f"{'Operator name':<40} | {self.op_name}\n"
            f"{'Recommended Core Shapes[' + self.shape_desc + ']':<40} | {self.recommended_core_shapes}\n"
        )

    def to_dict(self) -> dict:
        return self.__dict__


def custom_json_encoder(obj):
    if isinstance(obj, torch.dtype):
        return str(obj)
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


@dataclass
class BenchmarkResult:
    """Record the benchmark result for each operator."""

    # Unique name of the operator
    op_name: str
    dtype: str
    mode: str
    level: str
    # Benchmark results
    result: List[BenchmarkMetrics]

    def __str__(self) -> str:
        header = (
            f"\nOperator: {self.op_name} Performance Test (dtype={self.dtype}, mode={self.mode}, level={self.level})\n"
            f"{'Size':<10} {'Torch Latency (ms)':>20} {'Gems Latency (ms)':>20} {'Gems Speedup':>20}"
            f"{'Size Detail':>20}\n"
            f"{'-' * 90}\n"
        )
        metrics_lines = "".join(self._format_metrics(ele) for ele in self.result)
        return header + metrics_lines

    def _format_metrics(self, metrics: BenchmarkMetrics) -> str:
        self.gen_legacy_shape(metrics)
        legacy_shape_str = (
            metrics.legacy_shape if metrics.legacy_shape is not None else "N/A"
        )
        latency_base_str = (
            f"{metrics.latency_base:.6f}" if metrics.latency_base is not None else "N/A"
        )
        latency_str = f"{metrics.latency:.6f}" if metrics.latency is not None else "N/A"
        speedup_str = f"{metrics.speedup:.3f}" if metrics.speedup is not None else "N/A"
        shape_detail_str = (
            metrics.shape_detail if metrics.shape_detail is not None else "N/A"
        )
        return (
            f"{legacy_shape_str:<10}"
            f"{latency_base_str:>20}"
            f"{latency_str:>20}"
            f"{speedup_str:>20}"
            f"{' ' * 10}"
            f"{shape_detail_str}\n"
        )

    def gen_legacy_shape(self, metrics: BenchmarkMetrics) -> Optional[int]:
        first_shape = (
            metrics.shape_detail[0] if isinstance(metrics.shape_detail, list) else None
        )
        to_record_shape = (
            tuple(first_shape) if isinstance(first_shape, torch.Size) else None
        )

        if to_record_shape in LEGACY_NON_BLAS_SHAPES:
            metrics.legacy_shape = to_record_shape[-1]
        elif (
            isinstance(to_record_shape, tuple)
            and len(to_record_shape) == 2
            and to_record_shape[0] == 1024
        ):
            metrics.legacy_shape = to_record_shape[-1]
        else:
            metrics.legacy_shape = None

    def to_json(self) -> str:
        import json

        # Convert to dict and handle tuple serialization for shape_detail
        result_dict = asdict(self)
        return json.dumps(result_dict, default=custom_json_encoder)

    def to_dict(self) -> dict:
        return self.__dict__
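
As a quick illustration of how the new abstractions fit together, the following is a minimal sketch, assuming the module is importable as `benchmark.attri_util` from the repository root; the operator name, latency numbers, shapes, and the `mode` value are fabricated purely to exercise the dependency check, `__str__`, and `to_json` paths.

```python
import torch

from benchmark.attri_util import (
    BenchmarkMetrics,
    BenchmarkResult,
    check_metric_dependencies,
)

# Dependency check: "speedup" requires "latency" and "latency_base" to appear
# earlier in the requested list, so this ordering leaves nothing unsatisfied.
assert check_metric_dependencies(["latency_base", "latency", "speedup"]) == []

# Fabricated numbers purely to exercise the formatting and JSON paths.
metrics = BenchmarkMetrics(
    shape_detail=(1024, 1024),
    latency_base=0.123456,  # reference (torch) latency in ms
    latency=0.061728,       # gems latency in ms
    speedup=2.0,
)
result = BenchmarkResult(
    op_name="dummy_op",        # hypothetical operator name
    dtype=str(torch.float16),
    mode="cuda",               # illustrative value
    level="core",
    result=[metrics],
)

print(result)            # human-readable table via __str__
print(result.to_json())  # JSON record via custom_json_encoder
```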