Add sparsity to microbenchmark inference #1874


Closed · wants to merge 19 commits
3 changes: 3 additions & 0 deletions .gitignore
@@ -376,3 +376,6 @@ checkpoints/
# Experimental
torchao/experimental/cmake-out
torchao/experimental/deps

# Benchmark outputs
benchmarks/microbenchmarks/test/results/
Empty file added benchmarks/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions benchmarks/microbenchmarks/README.md
@@ -0,0 +1,84 @@
# Microbenchmarks

This directory contains microbenchmarking tools for measuring inference performance across different quantization methods and model architectures.

## Overview

The microbenchmarking system works as follows:

![Microbenchmarks Process Flow](../../docs/static/microbenchmarking_process_diagram.png)

## Components

![Microbenchmarks Flow](../../docs/static/microbenchmarks_code_flow_diagram.png)

- **benchmark_runner.py**: Main entry point that orchestrates the benchmarking process
- **benchmark_inference.py**: Handles model creation and inference benchmarking
- **utils.py**: Contains utility functions and configuration classes
- **test/**: Test files and sample configurations

## Usage

1. Create a configuration YAML file (see example below)
2. Run the benchmark using:

```bash
python -m benchmarks.microbenchmarks.benchmark_runner --config path/to/config.yml
```

### Example Configuration

```yaml
# Sample configuration for inference benchmarks
quantization_config_recipe_names:
  - "baseline"
  - "int8wo"
  - "int4wo-128"
  - "int4wo-128-hqq"

output_dir: "benchmarks/microbenchmarks/results"

model_params:
  matrix_shapes:
    - name: "custom"
      shapes: [
        [1024, 1024, 1024], # [m, k, n]
        [2048, 4096, 1024],
        [4096, 4096, 1024]
      ]
  high_precision_dtype: "torch.bfloat16"
  compile: "max-autotune" # Options: "default", "max-autotune", "false"
  device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
  model_type: "linear" # Options: "linear", "ln_linear_sigmoid"
```
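
For context, a `linear` benchmark with shape `[m, k, n]` is conventionally a single `torch.nn.Linear(k, n)` layer applied to an `(m, k)` input. The actual module construction lives in `create_model_and_input` in `utils.py`, so the snippet below is an illustrative sketch of that assumed convention, not the canonical implementation:

```python
import torch

# Illustrative only: m = batch/token dimension, k = in_features, n = out_features.
m, k, n = 1024, 1024, 1024
layer = torch.nn.Linear(k, n)
x = torch.randn(m, k)
with torch.no_grad():
    y = layer(x)
print(y.shape)  # torch.Size([1024, 1024])
```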

## Configuration Options

### Quantization Methods
Currently, the quantization string uses the same format as the one passed to llama/generate.py; a rough mapping to torchao APIs is sketched after the list.
- `baseline`: No quantization
- `int8wo`: 8-bit weight-only quantization
- `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size
- `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ
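
As a rough illustration of what these recipe names correspond to (the actual parsing happens in `string_to_config` in `utils.py`, so treat this as a sketch rather than the exact code path), the recipes map approximately onto the public `torchao.quantization` helpers:

```python
import torch
from torchao.quantization import int4_weight_only, int8_weight_only, quantize_

# Toy stand-in; the benchmark builds its model via create_model_and_input().
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)

# "baseline"       -> no quantize_() call at all
# "int8wo"         -> quantize_(model, int8_weight_only())
# "int4wo-128"     -> quantize_(model, int4_weight_only(group_size=128))
# "int4wo-128-hqq" -> as above, with HQQ used to compute the int4 weights
quantize_(model, int8_weight_only())
```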

### Model Types
- `linear`: Simple linear layer
- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
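
The module definitions live in `utils.py` (`create_model_and_input`). As a minimal sketch only, assuming `k` and `n` are the in/out features taken from the configured shape, `ln_linear_sigmoid` composes the three layers along these lines:

```python
import torch.nn as nn

# Hypothetical sketch; the real definition is in utils.py.
class LNLinearSigmoid(nn.Module):
    def __init__(self, k: int, n: int):
        super().__init__()
        self.ln = nn.LayerNorm(k)
        self.linear = nn.Linear(k, n)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.linear(self.ln(x)))
```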

### Device Options
- `cuda`: NVIDIA GPU
- `xpu`: Intel GPU
- `mps`: Apple Silicon GPU
- `cpu`: CPU fallback

## Output

Results are saved to a CSV file in the configured `output_dir`.
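
The column layout follows the fields of `BenchmarkResult` in `utils.py` and may change, so a quick way to inspect whatever was written, assuming the sample `output_dir` above, is:

```python
import csv
import glob

# The exact filename is chosen by generate_results_csv() in utils.py,
# so just pick up every CSV that landed in the configured output_dir.
for path in glob.glob("benchmarks/microbenchmarks/results/*.csv"):
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            print(path, row)
```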

## Running Tests

To run the test suite:

```bash
python -m unittest discover benchmarks/microbenchmarks/test
```
Empty file.
70 changes: 70 additions & 0 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -0,0 +1,70 @@
"""
Inference benchmark runner

This script runs inference benchmarks and generates a micro-benchmarking report.
- run() function is the main entry point for running inference benchmarks.
"""

from copy import deepcopy
from pathlib import Path

import torch

from benchmarks.microbenchmarks.utils import (
    BenchmarkConfig,
    BenchmarkResult,
    clean_caches,
    create_model_and_input,
    model_inference_time_in_ms,
    string_to_config,
)
from torchao.quantization import quantize_


def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
quantization_config = string_to_config(
config.quantization, high_precision_dtype=config.high_precision_dtype
)
if quantization_config is not None:
quantize_(m_copy, quantization_config)
if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)

# Run benchmarks
result = BenchmarkResult(config=config)

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# TODO: Benchmark time using profiler
# Profile dtype model evaluation
# prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
# prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details

# TODO: Benchmark gemm time using cuda graph
# gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)

# TODO: Benchmark op with cuda graph
# time = benchmark_op_with_cuda_graph(op, args)

return result
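
For a quick standalone check, `run()` can be invoked directly with a hand-built `BenchmarkConfig`. This is a minimal sketch that mirrors the unit test further below; the `params` keys are the same ones read from the YAML `model_params` block, and nothing here is specific to this PR:

```python
from benchmarks.microbenchmarks.benchmark_inference import run
from benchmarks.microbenchmarks.utils import BenchmarkConfig

config = BenchmarkConfig(
    quantization="baseline",
    params={
        "high_precision_dtype": "torch.float32",
        "use_torch_compile": False,
        "device": "cpu",
        "model_type": "linear",
    },
    shape_name="smoke_test",
    shape=[16, 32, 8],  # [m, k, n]
    output_dir="benchmarks/microbenchmarks/test/test_output/",
)
result = run(config)
print(result.model_inference_time_in_ms)
```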
123 changes: 123 additions & 0 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -0,0 +1,123 @@
"""
Benchmark Runner

This is the main entry point for the benchmarking application. It reads the YAML configuration
file and orchestrates the entire benchmarking process by:
- Loading and validating benchmark configurations
- Executing benchmark scenarios
- Collecting and processing results
- Generating reports

Usage:
    python benchmark_runner.py --config [config.yaml]

The YAML file should contain all necessary configuration parameters for the benchmarks.
"""

from itertools import product
from typing import Any, Dict, List, Tuple

import yaml

from benchmarks.microbenchmarks.utils import (
    BenchmarkConfig,
    generate_results_csv,
    print_results,
)


def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]:
"""Get shapes for a given configuration"""
name = shape_config["name"]
if name == "custom":
return [(name, shape) for shape in shape_config["shapes"]]
else:
raise NotImplementedError(
f"Shape config {name} not supported. Currently only supports custom shapes."
)


def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]:
"""Load benchmark configurations from YAML file"""
with open(config_path, "r") as f:
config_data = yaml.safe_load(f)

quantization_config_recipe_names = config_data["quantization_config_recipe_names"]
params = config_data["model_params"]
output_dir = config_data.get("output_dir", "benchmarks/microbenchmarks/results")

configs = []
# Process each shape configuration
for shape_config in params["matrix_shapes"]:
shapes = get_shapes_for_config(shape_config)
# Generate combinations for each shape
for quant, (shape_name, shape) in product(
quantization_config_recipe_names, shapes
):
configs.append(
BenchmarkConfig(
quantization=quant,
params=params,
shape_name=shape_name,
shape=shape,
output_dir=output_dir,
)
)
return configs


def run_inference_benchmarks_from_config(config_path: str) -> None:
"""Run benchmarks using configurations from YAML file"""
from benchmarks.microbenchmarks.benchmark_inference import run as run_inference

configs = load_benchmark_configs(config_path)
results = []
print("Benchmarking Inference ......")
for config in configs:
print(f"Running: {config.name}")
result = run_inference(config) # Pass the config object directly
results.append(result)

# Add results to csv
generate_results_csv(results, configs[0].output_dir)

# Print results
print_results(results)

# TODO: Process results: Speedups:
# 1. For different shapes for same model and quantization
# 2. For different quantizations for same model and shape
# 3. For different models for same quantization


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run benchmarks from config file")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to benchmark configuration file",
    )
    parser.add_argument(
        "--benchmark_mode",
        "-m",
        type=str,
        default="inference",
        choices=["inference", "training"],
        help="Benchmark mode to run: inference or training",
    )
    args = parser.parse_args()

    # Run benchmarks
    if args.benchmark_mode == "inference":
        run_inference_benchmarks_from_config(args.config)
    elif args.benchmark_mode == "training":
        print("Training mode not implemented yet")
    else:
        raise ValueError(
            f"Invalid benchmark mode: {args.benchmark_mode}, choose from inference or training"
        )

# TODO: Add support for args to override config values and run smaller benchmarks
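
To make the expansion in `load_benchmark_configs` concrete: every quantization recipe is crossed with every shape via `itertools.product`, so 3 recipes and 3 shapes produce 9 `BenchmarkConfig` objects. A small self-contained illustration of that cross product (values chosen to match the test config below):

```python
from itertools import product

recipes = ["baseline", "int8dq", "int4wo-128"]
shapes = [
    ("custom", [1024, 1024, 1024]),
    ("custom", [2048, 4096, 1024]),
    ("custom", [4096, 4096, 1024]),
]

combos = list(product(recipes, shapes))
print(len(combos))  # 9
for quant, (shape_name, shape) in combos[:2]:
    print(quant, shape_name, shape)
```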
Empty file.
20 changes: 20 additions & 0 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -0,0 +1,20 @@
# Sample configuration for inference kernel benchmarks
quantization_config_recipe_names:
  - "baseline"
  - "int8dq"
  - "int4wo-128"
output_dir: "benchmarks/microbenchmarks/test/results" # Directory for results and plots
model_params:
  matrix_shapes:
    - name: "custom"
      shapes: [
        [1024, 1024, 1024], # [m, k, n]
        [2048, 4096, 1024],
        [4096, 4096, 1024]
      ]
  high_precision_dtype: "torch.bfloat16"
  use_torch_compile: true
  torch_compile_mode: "max-autotune"
  device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed
  model_type: "linear"
  sparsity: "2:4"
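
The `sparsity: "2:4"` entry is the new knob this PR exercises; how the benchmark consumes it is defined by the code in this PR and is not reproduced here. As a generic, hedged illustration of what 2:4 (semi-structured) sparsity means, the stock `torch.sparse.to_sparse_semi_structured` API can compress a weight in which 2 out of every 4 contiguous elements are zero (CUDA with fp16/bf16 on Ampere or newer):

```python
import torch
from torch.sparse import to_sparse_semi_structured

# Illustration only; not taken from this PR.
if torch.cuda.is_available():
    linear = torch.nn.Linear(256, 128).half().cuda()
    # Zero out 2 of every 4 contiguous weights to satisfy the 2:4 pattern.
    mask = torch.tensor([0, 0, 1, 1], device="cuda").tile((128, 64)).bool()
    with torch.no_grad():
        linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight * mask))
        out = linear(torch.randn(64, 256, dtype=torch.float16, device="cuda"))
    print(out.shape)  # torch.Size([64, 128])
```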
32 changes: 32 additions & 0 deletions benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -0,0 +1,32 @@
import unittest

from benchmarks.microbenchmarks.benchmark_inference import run
from benchmarks.microbenchmarks.utils import BenchmarkConfig


class TestBenchmarkInference(unittest.TestCase):
    def setUp(self):
        self.params = {
            "high_precision_dtype": "torch.float32",  # Use float32 for testing
            "use_torch_compile": False,
            "device": "cpu",  # Use CPU for testing
            "model_type": "linear",
        }
        self.config = BenchmarkConfig(
            quantization="baseline",
            params=self.params,
            shape_name="test",
            shape=[16, 32, 8],  # Small shape for testing
            output_dir="benchmarks/microbenchmarks/test/test_output/",
        )

    def test_run_inference(self):
        result = run(self.config)

        # Check benchmark result is present and reasonable
        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
        self.assertGreater(result.model_inference_time_in_ms, 0)


if __name__ == "__main__":
    unittest.main()