
Commit 86d8dfa

xin3he authored and yuwenzho committed
add fp8 example and document (#1639)
Signed-off-by: xinhe3 <xinhe3@hababa.ai>
1 parent 853bb8d commit 86d8dfa

File tree: 12 files changed, +775 −53 lines changed


README.md

Lines changed: 32 additions & 46 deletions
@@ -68,67 +68,52 @@ pip install "neural-compressor>=2.3" "transformers>=4.34.0" torch torchvision
 ```
 After successfully installing these packages, try your first quantization program.

-### Weight-Only Quantization (LLMs)
-Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gaudi2 AI Accelerator, Nvidia GPU, best device will be selected automatically.
+### [FP8 Quantization](./examples/3.x_api/pytorch/cv/fp8_quant/)
+Following example code demonstrates FP8 Quantization, it is supported by Intel Gaudi2 AI Accelerator.

 To try on Intel Gaudi2, docker image with Gaudi Software Stack is recommended, please refer to following script for environment setup. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built).
 ```bash
 # Run a container with an interactive shell
-docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
-
-# Install the optimum-habana
-pip install --upgrade-strategy eager optimum[habana]
-
-# Install INC/auto_round
-pip install neural-compressor auto_round
+docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
 ```
 Run the example:
 ```python
-from transformers import AutoModel, AutoTokenizer
-
-from neural_compressor.config import PostTrainingQuantConfig
-from neural_compressor.quantization import fit
-from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader
-
-model_name = "EleutherAI/gpt-neo-125m"
-float_model = AutoModel.from_pretrained(model_name)
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-dataloader = get_dataloader(tokenizer, seqlen=2048)
-
-woq_conf = PostTrainingQuantConfig(
-    approach="weight_only",
-    op_type_dict={
-        ".*": {  # match all ops
-            "weight": {
-                "dtype": "int",
-                "bits": 4,
-                "algorithm": "AUTOROUND",
-            },
-        }
-    },
+from neural_compressor.torch.quantization import (
+    FP8Config,
+    prepare,
+    convert,
 )
-quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader)
+import torchvision.models as models
+
+model = models.resnet18()
+qconfig = FP8Config(fp8_config="E4M3")
+model = prepare(model, qconfig)
+# customer defined calibration
+calib_func(model)
+model = convert(model)
 ```
-**Note:**

-To try INT4 model inference, please directly use [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers), which leverages Intel Neural Compressor for model quantization.
+### [Weight-Only Quantization (LLMs)](./examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/)

-### Static Quantization (Non-LLMs)
+Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gaudi2 AI Accelerator, Nvidia GPU, best device will be selected automatically.

 ```python
-from torchvision import models
+from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig

-from neural_compressor.config import PostTrainingQuantConfig
-from neural_compressor.data import DataLoader, Datasets
-from neural_compressor.quantization import fit
+model_name = "EleutherAI/gpt-neo-125m"
+model = AutoModel.from_pretrained(model_name)

-float_model = models.resnet18()
-dataset = Datasets("pytorch")["dummy"](shape=(1, 3, 224, 224))
-calib_dataloader = DataLoader(framework="pytorch", dataset=dataset)
-static_quant_conf = PostTrainingQuantConfig()
-quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloader=calib_dataloader)
+quant_config = AutoRoundConfig()
+model = prepare(model, quant_config)
+# customer defined calibration
+run_fn(model) # calibration
+model = convert(model)
 ```

+**Note:**
+
+To try INT4 model inference, please directly use [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers), which leverages Intel Neural Compressor for model quantization.
+
 ## Documentation

 <table class="docutils">
@@ -154,12 +139,13 @@ quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloade
 <tbody>
 <tr>
 <td colspan="2" align="center"><a href="./docs/source/3x/PyTorch.md">Overview</a></td>
-<td colspan="2" align="center"><a href="./docs/source/3x/PT_StaticQuant.md">Static Quantization</a></td>
 <td colspan="2" align="center"><a href="./docs/source/3x/PT_DynamicQuant.md">Dynamic Quantization</a></td>
+<td colspan="2" align="center"><a href="./docs/source/3x/PT_StaticQuant.md">Static Quantization</a></td>
 <td colspan="2" align="center"><a href="./docs/source/3x/PT_SmoothQuant.md">Smooth Quantization</a></td>
 </tr>
 <tr>
-<td colspan="4" align="center"><a href="./docs/source/3x/PT_WeightOnlyQuant.md">Weight-Only Quantization</a></td>
+<td colspan="2" align="center"><a href="./docs/source/3x/PT_WeightOnlyQuant.md">Weight-Only Quantization</a></td>
+<td colspan="2" align="center"><a href="./docs/3x/PT_FP8Quant.md">FP8 Quantization</a></td>
 <td colspan="2" align="center"><a href="./docs/source/3x/PT_MXQuant.md">MX Quantization</a></td>
 <td colspan="2" align="center"><a href="./docs/source/3x/PT_MixedPrecision.md">Mixed Precision</a></td>
 </tr>

docs/3x/PT_FP8Quant.md

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
FP8 Quantization
=======

1. [Introduction](#introduction)
2. [Supported Parameters](#supported-parameters)
3. [Get Started with FP8 Quantization](#get-started-with-fp8-quantization)
4. [Examples](#examples)

## Introduction

Floating point 8 (FP8) is a promising data type for low-precision quantization; it provides a data distribution that is completely different from INT8, as shown below.

<div align="center">
    <img src="./imgs/fp8_dtype.png" height="250"/>
</div>

Intel Gaudi2, also known as HPU, provides this data type capability for low-precision quantization, which includes the `E4M3` and `E5M2` formats. For more information about these two data types, please refer to [this paper](https://arxiv.org/abs/2209.05433).

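For a quick feel of the difference between the two formats, here is a minimal host-side sketch using the stock PyTorch FP8 dtypes (which follow the OCP definitions in the paper above); it assumes PyTorch >= 2.1 and is only an illustration, not Gaudi-specific behavior.

```python
import torch

# Compare the numeric ranges of the two FP8 formats (illustration only).
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}")

# E4M3 keeps more mantissa bits (higher precision, smaller range),
# while E5M2 keeps more exponent bits (lower precision, much wider range).
```
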
Intel Neural Compressor provides general quantization APIs to leverage the HPU FP8 capability, so that an 8-bit model can be obtained with lower memory usage and lower compute cost.

## Supported Parameters

<style type="text/css">
.tg {border-collapse:collapse;border-spacing:0;}
.tg td{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  overflow:hidden;padding:10px 5px;word-break:normal;}
.tg th{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}
.tg .tg-fymr{border-color:inherit;font-weight:bold;text-align:left;vertical-align:top}
.tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}
</style>
<table class="tg"><thead>
  <tr>
    <th class="tg-fymr">Attribute</th>
    <th class="tg-fymr">Description</th>
    <th class="tg-fymr">Values</th>
  </tr></thead>
<tbody>
  <tr>
    <td class="tg-0pky">fp8_config</td>
    <td class="tg-0pky">The target data type of FP8 quantization.</td>
    <td class="tg-0pky">E4M3 (default) - As Fig. 2<br>E5M2 - As Fig. 1.</td>
  </tr>
  <tr>
    <td class="tg-0pky">hp_dtype</td>
    <td class="tg-0pky">The high precision data type of non-FP8 operators.</td>
    <td class="tg-0pky">bf16 (default) - torch.bfloat16<br>fp16 - torch.float16.<br>fp32 - torch.float32.</td>
  </tr>
  <tr>
    <td class="tg-0pky">observer</td>
    <td class="tg-0pky">The observer to measure the statistics.</td>
    <td class="tg-0pky">maxabs (default), saves all tensors to files.</td>
  </tr>
  <tr>
    <td class="tg-0pky">allowlist</td>
    <td class="tg-0pky">List of nn.Module names or types to quantize. When setting an empty list, all the supported modules will be quantized by default. See Supported Modules. Not setting the list at all is not recommended as it will set the allowlist to these modules only: torch.nn.Linear, torch.nn.Conv2d, and BMM.</td>
    <td class="tg-0pky">Default = {'names': [], 'types': <span title=["Matmul","Linear","FalconLinear","KVCache","Conv2d","LoRACompatibleLinear","LoRACompatibleConv","Softmax","ModuleFusedSDPA","LinearLayer","LinearAllreduce","ScopedLinearAllReduce","LmHeadLinearAllreduce"]>FP8_WHITE_LIST}</span></td>
  </tr>
  <tr>
    <td class="tg-0pky">blocklist</td>
    <td class="tg-0pky">List of nn.Module names or types not to quantize. Defaults to empty list, so you may omit it from the config file.</td>
    <td class="tg-0pky">Default = {'names': [], 'types': ()}</td>
  </tr>
  <tr>
    <td class="tg-0pky">mode</td>
    <td class="tg-0pky">The mode, measure or quantize, to run HQT with.</td>
    <td class="tg-0pky">MEASURE - Measure statistics of all modules and emit the results to dump_stats_path.<br>QUANTIZE - Quantize and run the model according to the provided measurements.<br>AUTO (default) - Select from [MEASURE, QUANTIZE] automatically.</td>
  </tr>
  <tr>
    <td class="tg-0pky">dump_stats_path</td>
    <td class="tg-0pky">The path to save and load the measurements. The path is created up until the level before last "/". The string after the last / will be used as prefix to all the measurement files that will be created.</td>
    <td class="tg-0pky">Default = "./hqt_output/measure"</td>
  </tr>
  <tr>
    <td class="tg-0pky">scale_method</td>
    <td class="tg-0pky">The method for calculating the scale from the measurement.</td>
    <td class="tg-0pky">- without_scale - Convert to/from FP8 without scaling.<br>- unit_scale - Always use scale of 1.<br>- maxabs_hw (default) - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then aligned to the corresponding HW accelerated scale.<br>- maxabs_pow2 - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then rounded to the power of 2.<br>- maxabs_hw_opt_weight - Scale of model params (weights) is chosen as the scale that provides minimal mean-square-error between quantized and non-quantized weights, from all possible HW accelerated scales. Scale of activations is calculated the same as maxabs_hw.<br>- act_maxabs_pow2_weights_pcs_opt_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_hw_opt_weight. Scale of activations is calculated the same as maxabs_pow2.<br>- act_maxabs_hw_weights_pcs_maxabs_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_pow2. Scale of activations is calculated the same as maxabs_hw.</td>
  </tr>
  <tr>
    <td class="tg-0pky">measure_exclude</td>
    <td class="tg-0pky">If this attribute is not defined, the default is OUTPUT. Since most models do not require measuring output tensors, you can exclude it to speed up the measurement process.</td>
    <td class="tg-0pky">NONE - All tensors are measured.<br>OUTPUT (default) - Excludes measurement of output tensors.</td>
  </tr>
</tbody></table>

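As a rough sketch of how these parameters map onto the config object, the snippet below sets a few of them explicitly. This is an assumption-laden illustration: the keyword names are taken from the attribute names in the table above, and the values shown are simply the documented defaults.

```python
from neural_compressor.torch.quantization import FP8Config

# Hypothetical configuration; keyword names are assumed to mirror the
# attributes documented in the table, and values are the documented defaults.
qconfig = FP8Config(
    fp8_config="E4M3",                       # target FP8 data type
    hp_dtype="bf16",                         # precision kept for non-FP8 operators
    scale_method="maxabs_hw",                # how scales are derived from measurements
    dump_stats_path="./hqt_output/measure",  # prefix for measurement files
    measure_exclude="OUTPUT",                # skip measuring output tensors
)
```
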
## Get Started with FP8 Quantization

### Demo Usage

```python
from neural_compressor.torch.quantization import (
    FP8Config,
    prepare,
    convert,
)
import torchvision.models as models

model = models.resnet18()
qconfig = FP8Config(fp8_config="E4M3")
model = prepare(model, qconfig)
# user-defined calibration
calib_func(model)
model = convert(model)
```

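`calib_func` above is user supplied. A minimal sketch, assuming an ImageNet-style ResNet input shape and synthetic data (in practice you would iterate over a real calibration dataloader), might look like this:

```python
import torch

# Hypothetical calibration routine: run a handful of representative batches
# through the prepared model so that tensor statistics can be observed.
def calib_func(model, num_batches=10):
    model.eval()
    with torch.no_grad():
        for _ in range(num_batches):
            images = torch.randn(8, 3, 224, 224)  # ImageNet-shaped dummy batch
            model(images)
```
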
## Examples

| Task | Example |
|----------------------|---------|
| Computer Vision (CV) | [Link](../../examples/3.x_api/pytorch/cv/fp8_quant/) |
| Large Language Model (LLM) | [Link](https://github.com/HabanaAI/optimum-habana-fork/tree/habana-main/examples/text-generation#running-with-fp8) |

> Note: For LLMs, Optimum-Habana provides higher performance based on modified modeling files, so the LLM link above points to Optimum-Habana, which uses Intel Neural Compressor for FP8 quantization internally.

examples/.config/model_params_pytorch_3x.json

Lines changed: 7 additions & 0 deletions
@@ -140,6 +140,13 @@
     "main_script": "main.py",
     "batch_size": 1
   },
+  "resnet18_fp8_static":{
+    "model_src_dir": "cv/fp8_quant",
+    "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+    "input_model": "",
+    "main_script": "main.py",
+    "batch_size": 1
+  },
   "opt_125m_pt2e_static":{
     "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e",
     "dataset_location": "",
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# ImageNet FP8 Quantization

This implements FP8 quantization of popular model architectures, such as ResNet, on the ImageNet dataset; it is supported by the Intel Gaudi2 AI Accelerator.

## Requirements

To try on Intel Gaudi2, a docker image with the Gaudi Software Stack is recommended; please refer to the following script for environment setup. More details can be found in the [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built).
```bash
# Run a container with an interactive shell
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
```

- Install requirements
- `pip install -r requirements.txt`
- Download the ImageNet dataset from http://www.image-net.org/
- Then, move and extract the training and validation images to labeled subfolders, using [the following shell script](extract_ILSVRC.sh)

## Quantization

To quantize a model and validate its accuracy, run `main.py` with the desired model architecture and the path to the ImageNet dataset:

```bash
python main.py --pretrained -t -a resnet50 -b 30 /path/to/imagenet
```
or
```bash
bash run_quant.sh --input_model=resnet50 --dataset_location=/path/to/imagenet
```
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
#!/bin/bash
#
# script to extract ImageNet dataset
# ILSVRC2012_img_train.tar (about 138 GB)
# ILSVRC2012_img_val.tar (about 6.3 GB)
# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in your current directory
#
# Adapted from:
# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md
# https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4
#
# imagenet/train/
# ├── n01440764
# │   ├── n01440764_10026.JPEG
# │   ├── n01440764_10027.JPEG
# │   ├── ......
# ├── ......
# imagenet/val/
# ├── n01440764
# │   ├── ILSVRC2012_val_00000293.JPEG
# │   ├── ILSVRC2012_val_00002138.JPEG
# │   ├── ......
# ├── ......
#
#
# Make imagenet directory
#
mkdir imagenet
#
# Extract the training data:
#
# Create train directory; move .tar file; change directory
mkdir imagenet/train && mv ILSVRC2012_img_train.tar imagenet/train/ && cd imagenet/train
# Extract training set; remove compressed file
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
#
# At this stage imagenet/train will contain 1000 compressed .tar files, one for each category
#
# For each .tar file:
# 1. create directory with same name as .tar file
# 2. extract and copy contents of .tar file into directory
# 3. remove .tar file
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
#
# This results in a training directory like so:
#
# imagenet/train/
# ├── n01440764
# │   ├── n01440764_10026.JPEG
# │   ├── n01440764_10027.JPEG
# │   ├── ......
# ├── ......
#
# Change back to original directory
cd ../..
#
# Extract the validation data and move images to subfolders:
#
# Create validation directory; move .tar file; change directory; extract validation .tar; remove compressed file
mkdir imagenet/val && mv ILSVRC2012_img_val.tar imagenet/val/ && cd imagenet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar
# get script from soumith and run; this script creates all class directories and moves images into corresponding directories
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
#
# This results in a validation directory like so:
#
# imagenet/val/
# ├── n01440764
# │   ├── ILSVRC2012_val_00000293.JPEG
# │   ├── ILSVRC2012_val_00002138.JPEG
# │   ├── ......
# ├── ......
#
#
# Check total files after extract
#
# $ find train/ -name "*.JPEG" | wc -l
# 1281167
# $ find val/ -name "*.JPEG" | wc -l
# 50000
#
