implement incbench command for ease-of-use benchmark #1884

Merged
merged 30 commits into from
Jul 11, 2024
Changes from all commits
Commits
30 commits
9054793
move prettytable into inc.common
xin3he Jun 24, 2024
fd510db
add benchmark
xin3he Jun 25, 2024
29f974c
support windows
xin3he Jun 26, 2024
dabe436
fix bug
xin3he Jun 26, 2024
4f7cb7c
enable subprocess running
xin3he Jun 26, 2024
4a3b6cd
fix bug in windows
xin3he Jun 26, 2024
3cc3885
enhance log
xin3he Jun 26, 2024
b3c1091
add document
xin3he Jun 26, 2024
ca1f3b6
update platform status
xin3he Jun 26, 2024
29ebf1a
add incbench dlrm example
xin3he Jun 27, 2024
5960bb7
add more docstring
xin3he Jun 27, 2024
4c15bda
add performance test for sq opt-125m
xin3he Jun 28, 2024
60340f2
enhance pre-commit for max-line-length check
xin3he Jun 28, 2024
5f02407
add Multiple Instance Benchmark Summary
xin3he Jun 28, 2024
cc014af
Dump Throughput and Latency Summary
xin3he Jun 28, 2024
c3de633
change log folder and add UTs
xin3he Jul 1, 2024
9757779
add requirement
xin3he Jul 1, 2024
6ca810f
Merge branch 'master' into xinhe/benchmark
xin3he Jul 2, 2024
8549e92
improve UT coverage
xin3he Jul 3, 2024
0f6e057
fix pylint
xin3he Jul 3, 2024
7f3aff5
remove previous useless code
xin3he Jul 8, 2024
eeb56f6
fix bug
xin3he Jul 8, 2024
24ec333
fix pylint
xin3he Jul 8, 2024
18ca594
fix bug
xin3he Jul 8, 2024
b55b22b
Merge branch 'master' into xinhe/benchmark
chensuyue Jul 9, 2024
a524d9c
update summary format per suyue's request
xin3he Jul 9, 2024
245c75a
fdsa
xin3he Jul 9, 2024
81687bd
revert pre-commit change
xin3he Jul 9, 2024
d681fc7
Merge branch 'master' into xinhe/benchmark
xin3he Jul 10, 2024
7e73d1a
update UT
xin3he Jul 10, 2024
61 changes: 61 additions & 0 deletions docs/3x/benchmark.md
@@ -0,0 +1,61 @@
Benchmark
---

1. [Introduction](#introduction)

2. [Supported Matrix](#supported-matrix)

3. [Usage](#usage)

## Introduction

Intel Neural Compressor provides the `incbench` command to launch performance benchmarks on Intel CPUs.

To get peak performance on Intel Xeon CPUs, a single instance should avoid crossing NUMA nodes.
Therefore, by default, `incbench` launches 1 instance on the first NUMA node.
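
For reference (a minimal sketch, not part of `incbench` itself), the CPUs that NUMA node 0 spans can be listed from sysfs on Linux:

```python
# Minimal sketch: print the logical CPU list of NUMA node 0 on Linux
# (may include hyperthread siblings), i.e. roughly the core range the
# default incbench instance is bound to.
with open("/sys/devices/system/node/node0/cpulist") as f:
    print("NUMA node 0 CPUs:", f.read().strip())  # e.g. "0-27,56-83"
```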

## Supported Matrix

| Platform | Status |
|:---:|:---:|
| Linux | ✔ |
| Windows | ✔ |

## Usage

| Parameter               | Default                  | Comments                                     |
|:-----------------------:|:------------------------:|:--------------------------------------------:|
| num_instances           | 1                        | Number of instances                          |
| num_cores_per_instance  | None                     | Number of cores in each instance             |
| C, cores                | 0-${num_cores_on_NUMA-1} | Decides the visible core range               |
| cross_memory            | False                    | Whether to allocate memory across NUMA nodes |

> Note: set `cross_memory` to True only when the memory within one NUMA node is insufficient.

### General Use Cases

1. `incbench main.py`: run 1 instance on NUMA:0.
2. `incbench --num_i 2 main.py`: run 2 instances on NUMA:0.
3. `incbench --num_c 2 main.py`: run multiple instances with 2 cores per instance on NUMA:0.
4. `incbench -C 24-47 main.py`: run 1 instance on cores 24-47.
5. `incbench -C 24-47 --num_c 4 main.py`: run multiple instances with 4 cores per instance on cores 24-47.

> Note:
> - `num_i` works the same as `num_instances`
> - `num_c` works the same as `num_cores_per_instance`

### Dump Throughput and Latency Summary

To merge benchmark results from multiple instances, `incbench` automatically scans the log files for "throughput" and "latency" messages that match the following patterns.

```python
throughput_pattern = r"[Tt]hroughput:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)"
latency_pattern = r"[Ll]atency:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)"
```

#### Demo Usage

The benchmarked script should print throughput and latency in a format that matches the patterns above, for example:

```python
print("Throughput: {:.3f} samples/sec".format(throughput))
print("Latency: {:.3f} ms".format(latency * 10**3))
```
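
For illustration only (not part of this PR), a minimal sketch of how such a printed line could be picked out of an instance log with the patterns above; `log_line` is a made-up example:

```python
import re

# Throughput pattern from the documentation above.
throughput_pattern = r"[Tt]hroughput:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)"

log_line = "Throughput: 35.482 samples/sec"  # hypothetical line from an instance log
match = re.search(throughput_pattern, log_line)
if match:
    value, unit = float(match.group(1)), match.group(2)
    print(value, unit)  # -> 35.482 samples/sec
```
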
@@ -75,22 +75,34 @@ function run_benchmark {

if [ "${topology}" = "opt_125m_ipex_sq" ]; then
model_name_or_path="facebook/opt-125m"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "llama2_7b_ipex_sq" ]; then
model_name_or_path="meta-llama/Llama-2-7b-hf"
extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8"
extra_cmd=$extra_cmd" --ipex"
elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
extra_cmd=$extra_cmd" --ipex"
fi

python -u run_clm_no_trainer.py \
--model ${model_name_or_path} \
--approach ${approach} \
--output_dir ${tuned_checkpoint} \
--task ${task} \
--batch_size ${batch_size} \
${extra_cmd} ${mode_cmd}
if [[ ${mode} == "accuracy" ]]; then
python -u run_clm_no_trainer.py \
--model ${model_name_or_path} \
--approach ${approach} \
--output_dir ${tuned_checkpoint} \
--task ${task} \
--batch_size ${batch_size} \
${extra_cmd} ${mode_cmd}
elif [[ ${mode} == "performance" ]]; then
incbench --num_cores_per_instance 4 run_clm_no_trainer.py \
--model ${model_name_or_path} \
--approach ${approach} \
--batch_size ${batch_size} \
--output_dir ${tuned_checkpoint} \
${extra_cmd} ${mode_cmd}
else
echo "Error: No such mode: ${mode}"
exit 1
fi
}

main "$@"
@@ -2,7 +2,7 @@
import os
import sys

sys.path.append('./')
sys.path.append("./")
import time
import re
import torch
@@ -12,15 +12,11 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", nargs="?", default="EleutherAI/gpt-j-6b")
parser.add_argument("--trust_remote_code", default=True, help="Transformers parameter: use the external repo")
parser.add_argument(
"--model", nargs="?", default="EleutherAI/gpt-j-6b"
"--revision", default=None, help="Transformers parameter: set the model hub commit number"
)
parser.add_argument(
"--trust_remote_code", default=True,
help="Transformers parameter: use the external repo")
parser.add_argument(
"--revision", default=None,
help="Transformers parameter: set the model hub commit number")
parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k")
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
parser.add_argument("--quantize", action="store_true")
@@ -29,29 +25,26 @@
action="store_true",
help="By default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument("--seed", type=int, default=42, help="Seed for sampling the calibration data.")
parser.add_argument(
'--seed',
type=int, default=42, help='Seed for sampling the calibration data.'
"--approach", type=str, default="static", help="Select from ['dynamic', 'static', 'weight-only']"
)
parser.add_argument("--approach", type=str, default='static',
help="Select from ['dynamic', 'static', 'weight-only']")
parser.add_argument("--int8", action="store_true")
parser.add_argument("--ipex", action="store_true", help="Use intel extension for pytorch.")
parser.add_argument("--load", action="store_true", help="Load quantized model.")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--performance", action="store_true")
parser.add_argument("--iters", default=100, type=int,
help="For accuracy measurement only.")
parser.add_argument("--batch_size", default=1, type=int,
help="For accuracy measurement only.")
parser.add_argument("--save_accuracy_path", default=None,
help="Save accuracy results path.")
parser.add_argument("--pad_max_length", default=512, type=int,
help="Pad input ids to max length.")
parser.add_argument("--calib_iters", default=512, type=int,
help="calibration iters.")
parser.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,wikitext",
type=str, help="tasks for accuracy validation")
parser.add_argument("--iters", default=100, type=int, help="For accuracy measurement only.")
parser.add_argument("--batch_size", default=1, type=int, help="For accuracy measurement only.")
parser.add_argument("--save_accuracy_path", default=None, help="Save accuracy results path.")
parser.add_argument("--pad_max_length", default=512, type=int, help="Pad input ids to max length.")
parser.add_argument("--calib_iters", default=512, type=int, help="calibration iters.")
parser.add_argument(
"--tasks",
default="lambada_openai,hellaswag,winogrande,piqa,wikitext",
type=str,
help="tasks for accuracy validation",
)
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
# ============SmoothQuant configs==============
parser.add_argument("--sq", action="store_true")
@@ -91,7 +84,7 @@ def collate_batch(self, batch):
pad_len = self.pad_max - input_ids.shape[0]
last_ind.append(input_ids.shape[0] - 1)
if self.is_calib:
input_ids = input_ids[:self.pad_max] if len(input_ids) > self.pad_max else input_ids
input_ids = input_ids[: self.pad_max] if len(input_ids) > self.pad_max else input_ids
else:
input_ids = pad(input_ids, (0, pad_len), value=self.pad_val)
input_ids_padded.append(input_ids)
@@ -144,6 +137,7 @@ def get_user_model():

if args.peft_model_id is not None:
from peft import PeftModel

user_model = PeftModel.from_pretrained(user_model, args.peft_model_id)

# to channels last
@@ -158,7 +152,9 @@ def get_user_model():
calib_dataset = load_dataset(args.dataset, split="train")
# calib_dataset = datasets.load_from_disk('/your/local/dataset/pile-10k/') # use this if trouble with connecting to HF
calib_dataset = calib_dataset.shuffle(seed=args.seed)
calib_evaluator = Evaluator(calib_dataset, tokenizer, args.batch_size, pad_max=args.pad_max_length, is_calib=True)
calib_evaluator = Evaluator(
calib_dataset, tokenizer, args.batch_size, pad_max=args.pad_max_length, is_calib=True
)
calib_dataloader = DataLoader(
calib_evaluator.dataset,
batch_size=calib_size,
@@ -167,6 +163,7 @@ def get_user_model():
)

from neural_compressor.torch.quantization import SmoothQuantConfig

args.alpha = eval(args.alpha)
excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False, excluded_precisions=excluded_precisions)
@@ -176,6 +173,7 @@

from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device
from tqdm import tqdm

def run_fn(model):
calib_iter = 0
for batch in tqdm(calib_dataloader, total=args.calib_iters):
@@ -186,16 +184,18 @@ def run_fn(model):
model(**batch)
else:
model(batch)

calib_iter += 1
if calib_iter >= args.calib_iters:
break
return

from utils import get_example_inputs

example_inputs = get_example_inputs(user_model, calib_dataloader)

from neural_compressor.torch.quantization import prepare, convert

user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(user_model)
user_model = convert(user_model)
@@ -207,6 +207,7 @@ def run_fn(model):
if args.int8 or args.int8_bf16_mixed:
print("load int8 model")
from neural_compressor.torch.quantization import load

tokenizer = AutoTokenizer.from_pretrained(args.model)
config = AutoConfig.from_pretrained(args.model)
user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
@@ -218,6 +219,7 @@ def run_fn(model):
if args.accuracy:
user_model.eval()
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser

eval_args = LMEvalParser(
model="hf",
user_model=user_model,
@@ -233,32 +235,25 @@ def run_fn(model):
else:
acc = results["results"][task_name]["acc,none"]
print("Accuracy: %.5f" % acc)
print('Batch size = %d' % args.batch_size)
print("Batch size = %d" % args.batch_size)

if args.performance:
user_model.eval()
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
batch_size, input_leng = args.batch_size, 512
example_inputs = torch.ones((batch_size, input_leng), dtype=torch.long)
print("Batch size = {:d}".format(batch_size))
print("The length of input tokens = {:d}".format(input_leng))
import time

samples = args.iters * args.batch_size
eval_args = LMEvalParser(
model="hf",
user_model=user_model,
tokenizer=tokenizer,
batch_size=args.batch_size,
tasks=args.tasks,
limit=samples,
device="cpu",
)
start = time.time()
results = evaluate(eval_args)
end = time.time()
for task_name in args.tasks.split(","):
if task_name == "wikitext":
acc = results["results"][task_name]["word_perplexity,none"]
else:
acc = results["results"][task_name]["acc,none"]
print("Accuracy: %.5f" % acc)
print('Throughput: %.3f samples/sec' % (samples / (end - start)))
print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
print('Batch size = %d' % args.batch_size)
total_iters = args.iters
warmup_iters = 5
with torch.no_grad():
for i in range(total_iters):
if i == warmup_iters:
start = time.time()
user_model(example_inputs)
end = time.time()
latency = (end - start) / ((total_iters - warmup_iters) * args.batch_size)
throughput = ((total_iters - warmup_iters) * args.batch_size) / (end - start)
print("Latency: {:.3f} ms".format(latency * 10**3))
print("Throughput: {:.3f} samples/sec".format(throughput))
@@ -394,7 +394,7 @@ def dash_separated_ints(value):
return value


def trace_model(args, dlrm, test_ld, inplace=True):
def trace_or_load_model(args, dlrm, test_ld, inplace=True):
dlrm.eval()
for j, inputBatch in enumerate(test_ld):
X, lS_o, lS_i, _, _, _ = unpack_batch(inputBatch)
@@ -462,7 +462,7 @@ def inference(
total_time = 0
total_iter = 0
if args.inference_only and trace:
dlrm = trace_model(args, dlrm, test_ld)
dlrm = trace_or_load_model(args, dlrm, test_ld)
if args.share_weight_instance != 0:
run_throughput_benchmark(args, dlrm, test_ld)
with torch.cpu.amp.autocast(enabled=args.bf16):
@@ -833,11 +833,11 @@ def eval_func(model):

# calibration
def calib_fn(model):
calib_number = 0
calib_iter = 0
for X_test, lS_o_test, lS_i_test, T in train_ld:
if calib_number < 100:
if calib_iter < 100:
model(X_test, lS_o_test, lS_i_test)
calib_number += 1
calib_iter += 1
else:
break

@@ -857,8 +857,22 @@ def calib_fn(model):
dlrm.save(args.save_model)
exit(0)
if args.benchmark:
# To do
print('Not implemented yet')
dlrm = trace_or_load_model(args, dlrm, test_ld, inplace=True)
import time
X_test, lS_o_test, lS_i_test, T = next(iter(test_ld))
total_iters = 100
warmup_iters = 5
with torch.no_grad():
for i in range(total_iters):
if i == warmup_iters:
start = time.time()
dlrm(X_test, lS_o_test, lS_i_test)
end = time.time()
latency = (end - start) / ((total_iters - warmup_iters) * args.mini_batch_size)
throughput = ((total_iters - warmup_iters) * args.mini_batch_size) / (end - start)
print('Batch size = {:d}'.format(args.mini_batch_size))
print('Latency: {:.3f} ms'.format(latency * 10**3))
print('Throughput: {:.3f} samples/sec'.format(throughput))
exit(0)

if args.accuracy_only:
@@ -934,7 +948,7 @@ def update_training_performance(time, iters, training_record=training_record):
training_record[0] += time
training_record[1] += 1

def print_training_performance( training_record=training_record):
def print_training_performance(training_record=training_record):
if training_record[0] == 0:
print("num-batches larger than warm up iters, please increase num-batches or decrease warmup iters")
exit()
@@ -80,7 +80,7 @@ function run_tuning {
--save-model ${tuned_checkpoint} --test-freq=2048 --print-auc $ARGS \
--load-model=${input_model} --accuracy_only
elif [[ ${mode} == "performance" ]]; then
python -u $MODEL_SCRIPT \
incbench --num_cores_per_instance 4 -u $MODEL_SCRIPT \
--raw-data-file=${dataset_location}/day --processed-data-file=${dataset_location}/terabyte_processed.npz \
--data-set=terabyte --benchmark \
--memory-map --mlperf-bin-loader --round-targets=True --learning-rate=1.0 \