
Commit 75279ff

More options and benchmarking tools

1 parent 8a5c7dd

10 files changed: +287 -90 lines

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:23.01-py3
+FROM nvcr.io/nvidia/pytorch:23.03-py3

 ARG USER=1000
 ARG USERNAME=user
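The only change is bumping the NGC PyTorch base image from 23.01 to 23.03, picking up newer PyTorch/CUDA builds. A quick sanity check to run inside the updated container (a minimal sketch, not part of this commit; the printed versions depend on the image and are not asserted here):

    # check_env.py - hypothetical helper to confirm the container toolchain.
    import torch

    print("torch:", torch.__version__)
    print("cuda:", torch.version.cuda)
    print("cuda available:", torch.cuda.is_available())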

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ bitsandbytes
 safetensors
 deepspeed==0.7.7
 -e ./transformers
+flash-attn

 # TODO: Analysis only
 py-markdown-table
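The new flash-attn dependency supplies fused attention kernels, presumably backing the "flash" entry in the benchmark script's ATTN_NAME list below. A minimal standalone sketch, assuming the flash_attn_func entry point of flash-attn 2.x (releases contemporary with this commit exposed different function names, so treat this as illustrative rather than the repo's actual call site):

    # Hypothetical direct use of flash-attn; not code from this repo.
    import torch
    from flash_attn import flash_attn_func  # assumes flash-attn >= 2.0

    # Shape (batch, seqlen, nheads, headdim); the kernels require fp16/bf16 on CUDA.
    q = torch.randn(1, 2040, 16, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = flash_attn_func(q, k, v, causal=True)  # causal mask, as in decoding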
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+
+# Santacoder
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 5 0
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 5 0
+
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 1 2040 11 1
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1
+./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 256 2040 11 1
+
+# Large model
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 11 0
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 11 0
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 11 0
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 11 0 # OOM?
+
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 1 8190 29 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 8 8190 29 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 32 8190 29 1
+./scripts/run_benchmark_breakdown.sh large_model ./data/large-model 256 8190 29 1 # OOM?
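The six positional arguments in these commands map to MODEL_NAME, MODEL_PATH, BATCH_SIZE, MAX_NEW_TOKENS, TOKEN_STEP, and STEP_ID in scripts/run_benchmark_breakdown.sh below: each model is swept over batch sizes at a fixed generation length, once per benchmark step, with the largest runs flagged as possible OOMs.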

scripts/run_benchmark_breakdown.sh

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+
+# Santacoder decode.
+# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 5 0
+# Santacoder prefill (fewer data points because slower).
+# ./scripts/run_benchmark_breakdown.sh santacoder bigcode/gpt_bigcode-santacoder 32 2040 11 1
+MODEL_NAME=${1:-"santacoder"}
+MODEL_PATH=${2:-"bigcode/gpt_bigcode-santacoder"}
+BATCH_SIZE=${3:-32}
+MAX_NEW_TOKENS=${4:-2040}
+# Prime number to see key length padding effect.
+TOKEN_STEP=${5:-5}
+STEP_ID=${6:-""}
+
+SAVE_DIR=data/benchmarks/v2
+#BATCH_SIZES="1 2 4 8 16 24 32 48 64 96 128 160 224 256"
+RUN="python3 src/main.py --max_log_outputs=0 --dtype=float16 --device=cuda --custom_generate --breakdown_latency --ignore_oom"
+
+
+RUNTIME=("" "pre_allocate_kv_cache=True" "pre_allocate_kv_cache=True inference_runner=3")
+RUNTIME_NAMES=("base" "pre_allocate" "graph")
+
+ATTN_NAME=("jit" "flash" "torch" "torchflash" "torchmem" "torchcpp")
+
+
+STEP=("--no_prefill" "--no_cache")
+STEP_NAME=("decode" "prefill")
+
+COMMON="--pretrained_model=$MODEL_PATH --tokenizer=$MODEL_PATH --cycles=10 --max_input_length=1 --max_new_tokens=$MAX_NEW_TOKENS --key_length_step=$TOKEN_STEP --batch_size=$BATCH_SIZE"
+
+run () { # run(step, runtime, attn)
+    FILE_NAME="$SAVE_DIR"/"$MODEL_NAME"_bs_"$BATCH_SIZE"_tok_"$MAX_NEW_TOKENS"_step_"$TOKEN_STEP"_"${STEP_NAME[$1]}"/"${RUNTIME_NAMES[$2]}"_"${ATTN_NAME[$3]}".json
+    if [ -f "$FILE_NAME" ];
+    then
+        echo "Skipping existing $FILE_NAME"
+    else
+        $RUN $COMMON ${STEP[$1]} ${RUNTIME[$2]} "attention_implementation=$3" --save="$FILE_NAME"
+    fi
+}
+
+if [ "${STEP_ID}" -eq "0" ]
+then
+    # Decode
+    for runtime in {0..2}
+    do
+        for attn in {0..5}
+        do
+            run 0 $runtime $attn
+        done
+    done
+else
+    # Prefill (all runtimes are the same)
+    for attn in {0..5}
+    do
+        run 1 0 $attn
+    done
+fi
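Note on the usage comments: STEP[0]="--no_prefill" is the decode measurement and STEP[1]="--no_cache" the prefill one, so STEP_ID=0 selects the decode sweep (where the runtime variants matter) and STEP_ID=1 the prefill sweep; the header comments above were swapped relative to the body and are corrected here. Each run() call writes a single JSON result, so a decode sweep produces 3 runtimes x 6 attention implementations = 18 files and a prefill sweep 6. A small sketch of the expected layout, derived from the FILE_NAME pattern above (hypothetical helper, not part of the commit):

    # expected_files.py - enumerate the files one decode sweep should produce.
    from itertools import product
    from pathlib import Path

    save_dir = Path("data/benchmarks/v2/santacoder_bs_32_tok_2040_step_5_decode")
    runtimes = ["base", "pre_allocate", "graph"]
    attns = ["jit", "flash", "torch", "torchflash", "torchmem", "torchcpp"]

    for runtime, attn in product(runtimes, attns):
        print(save_dir / f"{runtime}_{attn}.json")  # 18 paths in total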

src/main.py

Lines changed: 30 additions & 16 deletions
@@ -31,14 +31,16 @@ def get_arg_parser() -> ArgumentParser:
     parser.add_argument("--device", default="cuda", type=torch.device)
     parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
     parser.add_argument("--local_rank", type=int)
-    parser.add_argument("--no_fast_init","--nf", dest="fast_init", action="store_false")
-    parser.add_argument("--no_cache","--nc", dest="use_cache", action="store_false")
-    parser.add_argument("--no_prefill","--np", dest="do_prefill", action="store_false")
+    parser.add_argument("--no_fast_init", "--nf", dest="fast_init", action="store_false")
+    parser.add_argument("--no_cache", "--nc", dest="use_cache", action="store_false")
+    parser.add_argument("--no_prefill", "--np", dest="do_prefill", action="store_false")
+    parser.add_argument("--key_length_step", "--ks", default=1, type=int)
+    parser.add_argument("--ignore_oom", "--oom", action="store_true")

     # Input and output
-    parser.add_argument("--batch_size","-b", default=1, type=int)
-    parser.add_argument("--max_input_length","-i", default=-1, type=int)
-    parser.add_argument("--max_new_tokens","-g", default=100, type=int)
+    parser.add_argument("--batch_size", "-b", default=1, type=int)
+    parser.add_argument("--max_input_length", "-i", default=-1, type=int)
+    parser.add_argument("--max_new_tokens", "-g", default=100, type=int)

     # Cleanup
     parser.add_argument("--clear_every_run", action="store_true")
@@ -50,11 +52,11 @@ def get_arg_parser() -> ArgumentParser:

     # Profiling and logging
     parser.add_argument("--max_log_outputs", type=int)
-    parser.add_argument("--breakdown_latency","--bl", action="store_true")
-    parser.add_argument("--profile","-p", action="store_true")
-    parser.add_argument("--profile_cycles","--pc", type=int)
-    parser.add_argument("--full_trace","--pt", action="store_true")
-    parser.add_argument("--show_op_names","--pn", action="store_true")
+    parser.add_argument("--breakdown_latency", "--bl", action="store_true")
+    parser.add_argument("--profile", "-p", action="store_true")
+    parser.add_argument("--profile_cycles", "--pc", type=int)
+    parser.add_argument("--full_trace", "--pt", action="store_true")
+    parser.add_argument("--show_op_names", "--pn", action="store_true")
     parser.add_argument("--save", type=Path)

     return parser
@@ -91,10 +93,6 @@ def main(argv: Optional[List[str]] = None) -> None:
         dtype=args.dtype,
         fast_init=args.fast_init,
         trust_remote_code=args.trust_remote_code,
-        custom_generate=args.custom_generate,
-        use_cache=args.use_cache,
-        do_prefill=args.do_prefill,
-        breakdown_latency=args.breakdown_latency,
     )

     all_metrics = []
@@ -128,10 +126,26 @@ def main(argv: Optional[List[str]] = None) -> None:
     t1 = time.perf_counter()
     with profiler as p:
         for step in range(args.skip + warmup + args.cycles):
+            log_rank_n(
+                (
+                    f"*** Running generation step {step} "
+                    f"({'skip' if step < args.skip else 'warmup' if step < args.skip + warmup else 'benchmark'})"
+                ),
+                logger.info,
+            )
             if step == args.skip + warmup:
                 t2 = time.perf_counter()
                 benchmark_metrics[Metrics.RUNTIME_WARMUP] = t2 - t1
-            generated_text, metrics = pipeline(inputs, args.max_new_tokens)
+            generated_text, metrics = pipeline(
+                inputs,
+                args.max_new_tokens,
+                custom_generate=args.custom_generate,
+                use_cache=args.use_cache,
+                do_prefill=args.do_prefill,
+                breakdown_latency=args.breakdown_latency,
+                key_length_step=args.key_length_step,
+                ignore_oom=args.ignore_oom,
+            )
            if args.profile:
                p.step()
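Generation options now travel with each pipeline call rather than being frozen at construction, so one pipeline instance can serve several benchmark configurations, and the two new flags (--key_length_step to stride the attention key length and expose padding effects, --ignore_oom to keep a sweep alive past out-of-memory failures) flow through the same path. The pipeline internals are not part of this diff; a minimal sketch of how the OOM flag might be honored, with all names assumed:

    # Hypothetical pipeline-side handling, for illustration only.
    import torch

    def generate_step(step_fn, ignore_oom: bool = False):
        """Run one benchmarked generation step, optionally swallowing OOMs."""
        try:
            return step_fn()
        except torch.cuda.OutOfMemoryError:
            if not ignore_oom:
                raise
            torch.cuda.empty_cache()  # free what we can and continue the sweep
            return None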

src/metrics.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ def format_ms(t: float) -> str:
     return f"{1000 * t:.2f} ms"


-def format_ms_dict(t_dict: Dict[str,float]) -> Dict[str,str]:
-    return {key:format_ms(value) for key, value in t_dict.items()}
+def format_ms_dict(t_dict: Dict[str, float]) -> Dict[str, str]:
+    return {key: format_ms(value) for key, value in t_dict.items()}


 def format_mib(m: float) -> str:
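For reference, the reformatted helper in use (values invented for the example):

    >>> format_ms_dict({"attn": 0.00123, "mlp": 0.00045})
    {'attn': '1.23 ms', 'mlp': '0.45 ms'}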

src/parse_breakdown_results.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import json
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import List, Optional
+
+
+def get_arg_parser() -> ArgumentParser:
+    parser = ArgumentParser()
+    parser.add_argument("input_dir", type=Path)
+    parser.add_argument("--title")
+    return parser
+
+
+def read_data(input_file: Path):
+    try:
+        with input_file.open("r") as f:
+            data = json.load(f)
+        data = {**data["config"], **data["results"]}
+    except (ValueError, OSError) as e:
+        raise ValueError(f"Cannot parse file {input_file} ({e})")
+    data["Setting"] = input_file.stem
+    return data
+
+
+def plot(data, title=None):
+    import matplotlib.pyplot as plt
+
+    fig = plt.figure()
+    ax = fig.add_subplot()
+
+    for dat in data:
+        latency_data = dat["Latency (generate breakdown)"]
+        ax.plot(
+            [int(k) for k in latency_data.keys()],
+            [v * 1000 for v in latency_data.values()],
+            label=dat["Setting"],
+            linewidth=1,
+        )  # , linestyle=":")#, markersize=1, marker="o")
+
+    ax.set_title(title)
+    ax.set_xlabel("Sequence length")
+    ax.set_ylabel("Latency (ms)")
+    ax.legend()
+    fig.show()
+    input("Press enter to continue")
+
+
+def main(argv: Optional[List[str]] = None) -> None:
+    parser = get_arg_parser()
+    args = parser.parse_args(argv)
+    data = [read_data(input_file) for input_file in args.input_dir.iterdir()]
+
+    if len(data) == 0:
+        raise RuntimeError("No data to show.")
+
+    plot(data, args.title)
+
+
+if __name__ == "__main__":
+    main()
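The script expects a directory of JSON files produced by run_benchmark_breakdown.sh and draws one latency-vs-sequence-length curve per file. A typical invocation, with the directory name derived from that script's SAVE_DIR/FILE_NAME pattern:

    python3 src/parse_breakdown_results.py data/benchmarks/v2/santacoder_bs_32_tok_2040_step_5_decode --title "Santacoder decode, batch size 32"

Note that matplotlib is imported lazily inside plot(), so it only needs to be installed where plots are actually rendered.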
