Commit 87756f9

Update

[ghstack-poisoned]

2 parents 8691bd4 + 3766ed7 · commit 87756f9

File tree: 23 files changed (+1244, −706 lines)

.github/workflows/torchao_experimental_test.yml
Lines changed: 1 addition & 0 deletions

@@ -103,6 +103,7 @@ jobs:
         pip install parameterized
         pip install pyyaml
         pip install numpy
+        pip install importlib-metadata
     - name: Print pip freeze
       run: |
         pip freeze

benchmarks/float8/training/README.md
Lines changed: 1 addition & 0 deletions

@@ -14,5 +14,6 @@ Training parameters can be configured via environment variables.
 - `FLOAT8_RECIPE_WITH_BEST_SETTINGS`: "rowwise" or "tensorwise". Applies float8 training with the specified scaling recipe, as well as additional training configs which are optimal for that scaling recipe. See `float8_training_benchmark.sh` for more details.
 - `BATCH_SIZE`: Defaults to 1.
 - `STEPS`: Defaults to 100.
+- `EXTRA_ARGS`: Extra arguments to pass to torchtitan training script. See [torchtitan](https://github.com/pytorch/torchtitan) docs for the full list of options.

 **NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.

benchmarks/float8/training/float8_training_benchmark.sh
Lines changed: 2 additions & 1 deletion

@@ -22,6 +22,7 @@ if [ -z "${TORCHTITAN_ROOT}" ]; then
     echo " * FLOAT8_RECIPE_WITH_BEST_SETTINGS: "rowwise" or "tensorwise". if set, use float8 training in torchtitan with the specified recipe, including the additional settings which are optimal for that recipe. otherwise, use bf16 mixed precision training."
     echo " * BATCH_SIZE: defaults to 1."
     echo " * STEPS: defaults to 100."
+    echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script."
     exit 1
 fi

@@ -44,7 +45,7 @@ cd ${TORCHTITAN_ROOT}
 echo "float8 args: ${FLOAT8_ARGS}"

 # run the command with the specified arguments
-CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} 2>&1 | tee ${LOG_FILE}
+CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE}

 # return to original working directory
 cd $original_dir
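A minimal sketch of how these environment variables (including the new EXTRA_ARGS) can be fed to the benchmark script from Python; the torchtitan path and the extra flag shown are illustrative assumptions, not part of this commit:

import os
import subprocess

env = dict(
    os.environ,
    TORCHTITAN_ROOT="/path/to/torchtitan",  # required; the script exits early if unset
    FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise",
    BATCH_SIZE="4",
    STEPS="50",
    EXTRA_ARGS="--metrics.log_freq=10",  # forwarded verbatim to run_train.sh (flag is illustrative)
)
subprocess.run(
    ["bash", "benchmarks/float8/training/float8_training_benchmark.sh"],
    env=env,
    check=True,
)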

benchmarks/microbenchmarks/benchmark_inference.py
Lines changed: 27 additions & 4 deletions

@@ -24,6 +24,7 @@
     string_to_config,
 )
 from torchao.quantization import quantize_
+from torchao.sparsity.sparse_api import sparsify_


 def run(config: BenchmarkConfig) -> BenchmarkResult:
@@ -44,11 +45,33 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:

     # Use quantize_ to apply each quantization function to the model
     m_copy = deepcopy(base_model).eval().to(config.device)
-    quantization_config = string_to_config(
-        config.quantization, high_precision_dtype=config.high_precision_dtype
+    ao_base_config = string_to_config(
+        config.quantization,
+        config.sparsity,
+        high_precision_dtype=config.high_precision_dtype,
     )
-    if quantization_config is not None:
-        quantize_(m_copy, quantization_config)
+
+    # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
+    is_cuda = config.device == "cuda" and torch.cuda.is_available()
+
+    if config.sparsity is not None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
+        if is_cuda:
+            print(f"Applying {config.sparsity} sparsity to model")
+            sparsify_(m_copy, ao_base_config)
+        else:
+            print(
+                f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+            )
+    elif config.sparsity is None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
+        pass  # No quantization or sparsity specified, do nothing
+    else:
+        print("Quantizing model....")
+        quantize_(m_copy, ao_base_config)

     if config.use_torch_compile:
         print("Compiling model....")
         m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
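For readers skimming the hunk above, here is a condensed, standalone sketch of the new quantize-vs-sparsify dispatch; it assumes the torchao APIs used elsewhere in this commit, and the toy model and recipe strings are illustrative only:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.sparsity.sparse_api import SemiSparseWeightConfig, sparsify_

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantization, sparsity = "baseline", "semi-sparse"  # one example combination
ao_config = SemiSparseWeightConfig()  # stand-in for what string_to_config returns here

if sparsity is not None and (quantization is None or "baseline" in quantization):
    # Sparsity-only path: sparsify_ swaps weights for sparse tensors (CUDA only).
    if torch.cuda.is_available():
        sparsify_(model.to("cuda"), ao_config)
elif sparsity is None and (quantization is None or "baseline" in quantization):
    pass  # pure baseline: leave the model in high precision
else:
    # Quantization path; a sparse layout can be baked into the config,
    # e.g. Int4WeightOnlyConfig(layout=MarlinSparseLayout()) for "marlin".
    quantize_(model, Int4WeightOnlyConfig())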

benchmarks/microbenchmarks/benchmark_runner.py
Lines changed: 63 additions & 8 deletions

@@ -21,7 +21,7 @@

 import argparse
 from itertools import product
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 import yaml

@@ -68,6 +68,53 @@ def get_param_combinations(model_param):
     return shapes, base_params


+def get_quantization_sparsity_recipes(
+    quantization_recipes: List[str], sparsity_recipes: List[str]
+) -> Set[Tuple[str, Optional[str]]]:
+    """Generate valid quantization and sparsity recipes.
+
+    Args:
+        quantization_recipes: List of quantization recipes
+        sparsity_recipes: List of sparsity recipes
+
+    Returns:
+        Set of tuples containing (quantization_recipe, sparsity_recipe)
+        For block sparsity, quantization is always "baseline"
+        All quantization techniques are also run without sparsity
+    """
+    config_recipes = set()
+
+    # Always include baseline without sparsity
+    config_recipes.add(("baseline", None))
+
+    # Add all quantization techniques without sparsity
+    for quant_config in quantization_recipes:
+        config_recipes.add((quant_config, None))
+
+    # Process combinations of quantization and sparsity
+    for sparse_config in sparsity_recipes:
+        if sparse_config is None:
+            # Skip None sparsity as we've already added all quantization techniques without sparsity
+            continue
+        elif "block" in sparse_config:
+            # For block sparsity, only pair with baseline quantization
+            config_recipes.add(("baseline", sparse_config))
+        elif "semi" in sparse_config or "2:4" in sparse_config:
+            # For semi-sparse, only pair with compatible quantization methods
+            for quant_config in quantization_recipes:
+                if (
+                    "marlin" in quant_config
+                    or "int8dq" in quant_config
+                    or "float8dq" in quant_config
+                    or quant_config == "baseline"
+                ):
+                    config_recipes.add((quant_config, sparse_config))
+        else:
+            raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
+
+    return config_recipes
+
+
 def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig]:
     """Load benchmark configurations from CLI arguments and YAML file."""
     with open(cli_args.config, "r") as f:
@@ -78,24 +125,29 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig

     # Create all possible combinations
     configs = []
+    quantization_sparsity_recipes = get_quantization_sparsity_recipes(
+        config.get("quantization_config_recipe_names", []),
+        config.get("sparsity_config_recipe_names", []),
+    )
     for model_param in config["model_params"]:
         shapes, params = get_param_combinations(model_param)

         # Create configs for all combinations
-        for quant_config, (shape_name, shape) in product(
-            config.get("quantization_config_recipe_names", ["baseline"]), shapes
+        for (quant_config, sparse_config), (shape_name, shape) in product(
+            quantization_sparsity_recipes,
+            shapes,
         ):
             configs.append(
                 BenchmarkConfig(
                     quantization=quant_config,
+                    sparsity=sparse_config,
                     params=params,
                     shape_name=shape_name,
                     shape=shape,
                     output_dir=output_dir,
                     benchmark_mode=benchmark_mode,
                 )
             )
-
     return configs


@@ -104,14 +156,17 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
     from benchmarks.microbenchmarks.benchmark_inference import run as run_inference

     results = []
-    print("Benchmarking Inference ......")
+    print("----------------- RUNNING BENCHMARKS FOR INFERENCE -----------------------")
     for config in configs:
+        print("----------------------------------------")
         try:
-            print(f"Running: {config.name}")
+            print(
+                f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
+            )
             result = run_inference(config)  # Pass the config object directly
             results.append(result)
-        except Exception as e:
-            print(f"Error running benchmark {config.name}: {e}")
+        except Exception:
+            print(f"Error running benchmark {config.name}")
             continue

     # Add results to csv
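To make the pairing rules concrete, here is a small worked example, hand-traced against get_quantization_sparsity_recipes above and using the same recipe names as the sample YAML config further down; the import path follows this repo's file layout:

from benchmarks.microbenchmarks.benchmark_runner import (
    get_quantization_sparsity_recipes,
)

pairs = get_quantization_sparsity_recipes(
    ["int4wo-32", "marlin"], ["semi-sparse", "block"]
)
assert pairs == {
    ("baseline", None),         # baseline with no sparsity is always included
    ("int4wo-32", None),        # every quantization recipe also runs dense
    ("marlin", None),
    ("marlin", "semi-sparse"),  # semi-sparse pairs only with marlin/int8dq/float8dq/baseline
    ("baseline", "block"),      # block sparsity pairs only with baseline quantization
}

Note that ("baseline", "semi-sparse") does not appear here because "baseline" is not in the input quantization list; only the implicit dense baseline is always added.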

benchmarks/microbenchmarks/test/benchmark_config.yml
Lines changed: 6 additions & 2 deletions

@@ -1,9 +1,13 @@
 # Sample configuration for inference benchmarks
 benchmark_mode: "inference"
 quantization_config_recipe_names:
-  - "baseline"
+  # Will run a baseline inference for model by default, without quantization for comparison
   - "int4wo-32"
-  - "int4wo-128"
+  - "marlin"
+sparsity_config_recipe_names:
+  # Will run a baseline inference for model by default, without sparsity for comparison
+  - "semi-sparse"
+  - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
   - name: "small_bf16_linear"
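A short sketch of how the runner expands the sample YAML above into benchmark configs; argparse.Namespace stands in for real CLI parsing, and the config path is the file shown in this diff:

import argparse

from benchmarks.microbenchmarks.benchmark_runner import load_benchmark_configs

args = argparse.Namespace(
    config="benchmarks/microbenchmarks/test/benchmark_config.yml"
)
for cfg in load_benchmark_configs(args):
    # each config carries one (quantization, sparsity) pair plus one shape
    print(cfg.name, cfg.quantization, cfg.sparsity)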

benchmarks/microbenchmarks/test/test_benchmark_inference.py
Lines changed: 66 additions & 1 deletion

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import tempfile
 import unittest
+from unittest.mock import patch

 from benchmarks.microbenchmarks.benchmark_inference import run
 from benchmarks.microbenchmarks.utils import BenchmarkConfig, BenchmarkResult
@@ -17,6 +18,7 @@ def setUp(self):

         self.config = BenchmarkConfig(
             quantization="baseline",
+            sparsity="semi-sparse",
             params={
                 "high_precision_dtype": "torch.float32",
                 "use_torch_compile": False,
@@ -35,11 +37,74 @@ def tearDown(self):

         shutil.rmtree(self.temp_dir)

-    def test_run_inference(self):
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference(self, mock_string_to_config):
+        # Mock string_to_config to return a valid config
+        from torchao.sparsity.sparse_api import SemiSparseWeightConfig
+
+        mock_string_to_config.return_value = SemiSparseWeightConfig()
+
         result = run(self.config)
         self.assertIsInstance(result, BenchmarkResult)
         self.assertTrue(hasattr(result, "model_inference_time_in_ms"))

+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
+        """Test running inference with sparsity configurations"""
+        # Mock string_to_config to return valid configs
+        from torchao.dtypes import MarlinSparseLayout
+        from torchao.quantization import Int4WeightOnlyConfig
+
+        # Test with semi-sparse config
+        mock_string_to_config.return_value = Int4WeightOnlyConfig(
+            layout=MarlinSparseLayout()
+        )
+        config = BenchmarkConfig(
+            quantization="marlin",
+            sparsity="semi-sparse",
+            params={
+                "high_precision_dtype": "torch.float32",
+                "use_torch_compile": False,
+                "device": "cpu",
+                "model_type": "linear",
+            },
+            shape_name="custom",
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
+            output_dir=self.temp_dir,
+            benchmark_mode="inference",
+        )
+        result = run(config)
+        self.assertIsInstance(result, BenchmarkResult)
+        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference_with_block_sparsity(self, mock_string_to_config):
+        """Test running inference with sparsity configurations"""
+        # Mock string_to_config to return valid configs
+        from torchao.sparsity.sparse_api import (
+            BlockSparseWeightConfig,
+        )
+
+        # Test with block sparsity
+        mock_string_to_config.return_value = BlockSparseWeightConfig()
+        config = BenchmarkConfig(
+            quantization="baseline",
+            sparsity="block",
+            params={
+                "high_precision_dtype": "torch.float32",
+                "use_torch_compile": False,
+                "device": "cpu",
+                "model_type": "linear",
+            },
+            shape_name="custom",
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
+            output_dir=self.temp_dir,
+            benchmark_mode="inference",
+        )
+        result = run(config)
+        self.assertIsInstance(result, BenchmarkResult)
+        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+

 if __name__ == "__main__":
     unittest.main()
