Commit 1a0dbf1
Add TTFT benchmarks + update sparsity benchmarks (#1140)

This PR adds time-to-first-token (TTFT) benchmarks to torchao, updates the benchmarking script to handle sparsity more cleanly and to use the 2:4 sparse checkpoints that are available, and adds padding support for int8 dynamic quant + 2:4 sparsity, which was missing before.
1 parent b7630f1 commit 1a0dbf1
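For orientation, TTFT here is prefill latency: the time from submitting the prompt to producing the first output token. Below is a minimal sketch (not from the diff) of the measurement pattern the PR adds to generate.py, with a stand-in matmul in place of the real prefill call:

```python
import torch

# Bracket the prefill with CUDA events, then read back elapsed milliseconds.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")
start.record()
y = x @ x  # stand-in for prefill(), which produces the first token
end.record()
torch.cuda.synchronize()  # events are asynchronous; sync before reading
print(f"TTFT: {start.elapsed_time(end) / 1000:.4f} s")  # elapsed_time() is in ms
```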

File tree

5 files changed: +136 −16 lines

scripts/prepare.sh

Lines changed: 4 additions & 0 deletions

@@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
 python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
+python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
+# neuralmagic doesn't come with tokenizer, so we need to copy it over
+mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4

test/prototype/test_sparse_api.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)

+        if compile:
+            model = torch.compile(model)
+
         torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

torchao/_models/llama/benchmarks.sh

Lines changed: 19 additions & 2 deletions

@@ -52,7 +52,7 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

@@ -62,7 +62,7 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

@@ -79,3 +79,20 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128
+
+# TTFT benchmarks
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured
+
+# 2:4 sparse model
+export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
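For convenience, the same TTFT matrix can be driven from Python instead of the shell script; this is a hedged sketch with the same flags as above (the checkpoint path assumes the layout prepare.sh sets up):

```python
import subprocess

# Run the TTFT benchmark matrix; mirrors the benchmarks.sh commands above.
ckpt = "checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth"
for quant in [None, "int8dq", "int8wo", "float8dq", "float8wo", "int4wo-64"]:
    cmd = ["python", "generate.py", "--checkpoint_path", ckpt,
           "--compile", "--compile_prefill", "--prefill_size", "8000",
           "--write_result", "benchmark_results.txt"]
    if quant is not None:
        cmd += ["--quantization", quant]
    subprocess.run(cmd, check=True)
```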

torchao/_models/llama/generate.py

Lines changed: 104 additions & 12 deletions

@@ -17,6 +17,29 @@
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
+class HostEvent:
+    def __init__(self):
+        self.event_time = None
+
+    def record(self):
+        self.event_time = time.perf_counter()
+
+    def elapsed_time(self, other_event):
+        if self.event_time is None:
+            raise ValueError("Event not recorded!")
+        # return ms to match cuda event
+        return abs(other_event.event_time - self.event_time) * 1000
+
+def device_timer(device):
+    if "cuda" in device:
+        return torch.cuda.Event(enable_timing=True)
+    elif ("cpu" in device) or ("mps" in device):
+        return HostEvent()
+    else:
+        print(f"device={device} is not yet suppported")
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
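A small sketch (not part of the diff) of why HostEvent exists: it mirrors the record()/elapsed_time() interface of torch.cuda.Event, so the timing code downstream stays device-agnostic. Assuming the HostEvent class from the hunk above is in scope:

```python
import time

# HostEvent mimics torch.cuda.Event on CPU/MPS; device_timer("cpu") returns one.
start, end = HostEvent(), HostEvent()
start.record()
time.sleep(0.05)  # stand-in for prefill or decode work
end.record()
print(f"{start.elapsed_time(end):.1f} ms")  # ~50 ms; same units as CUDA events
```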
@@ -98,6 +121,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
+    prefill_start_event: Optional[torch.cuda.Event]=None,
+    prefill_end_event: Optional[torch.cuda.Event]=None,
+    decode_start_event: Optional[torch.cuda.Event]=None,
+    decode_end_event: Optional[torch.cuda.Event]=None,
     **sampling_kwargs
 ) -> torch.Tensor:
     """

@@ -128,12 +155,21 @@
     model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

     # execute prefill
+    if prefill_start_event is not None:
+        prefill_start_event.record()
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     seq[:, T] = next_token.squeeze()
+    if prefill_end_event is not None:
+        prefill_end_event.record()
+
     # execute token generation
+    if decode_start_event is not None:
+        decode_start_event.record()
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
     seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
+    if decode_end_event is not None:
+        decode_end_event.record()

     return seq
@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
+    prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,

@@ -166,6 +203,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,

@@ -181,6 +219,10 @@ def main(
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """

+    if prefill_size is not None and prefill_size > 0:
+        # create prompt of prefill size
+        prompt = "prompt " * (int(prefill_size)-3)
+
     torchao.quantization.utils.recommended_inductor_config_setter()

     assert checkpoint_path.is_file(), checkpoint_path
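One detail worth noting in the hunk above: the synthetic prompt is the word "prompt " repeated prefill_size − 3 times. The offset presumably leaves headroom for BOS and other special tokens so the encoded length lands near prefill_size; that rationale is an assumption, not stated in the diff:

```python
# Illustration of the synthetic-prompt sizing above; the -3 offset is assumed
# to absorb BOS/special tokens so the encoded prompt is ~prefill_size tokens.
prefill_size = 8000
prompt = "prompt " * (prefill_size - 3)
print(len(prompt.split()))  # 7997 repetitions, roughly prefill_size once tokenized
```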
@@ -205,6 +247,14 @@ def main(
     torch.manual_seed(1234)

+    def ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn
+
+    def not_ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)
+
+    def ffn_or_attn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn)

     if quantization:
         from torchao.quantization import (
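quantize_ passes each (module, fully-qualified name) pair through filter_fn, so the predicates above let one model mix techniques per layer. A toy demonstration of the predicate semantics (the module names here are illustrative, not from the repo):

```python
import torch

def ffn_only(mod, fqn):
    # same predicate as defined in the diff above
    return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

toy = torch.nn.ModuleDict({
    "attention": torch.nn.Linear(8, 8),
    "feed_forward": torch.nn.Linear(8, 8),
})
for fqn, mod in toy.named_modules():
    if isinstance(mod, torch.nn.Linear):
        print(fqn, "-> FFN path" if ffn_only(mod, fqn) else "-> dense path")
```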
@@ -228,9 +278,14 @@ def main(
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        elif "int8dq" in quantization:
-            quantize_(model, int8_dynamic_activation_int8_weight())
-        elif "int4wo" in quantization:
+        if "int8dq" in quantization:
+            if sparsity and "semi" in sparsity:
+                from torchao.dtypes import SemiSparseLayout
+                quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
+                quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
+            else:
+                quantize_(model, int8_dynamic_activation_int8_weight())
+        if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
             else:

@@ -250,9 +305,9 @@ def main(
                     layout=MarlinQQQLayout(),
                 ),
             )
-        else:
+        elif "semi" in sparsity:
             from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         elif "embed-int8wo" in quantization:
@@ -440,6 +495,13 @@ def main(
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(model)

+    # standalone sparsity
+    elif sparsity:
+        from torchao.sparsity import semi_sparse_weight, sparsify_
+        if "semi" in sparsity:
+            #TODO there is a bug here, need to fix
+            sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)
+
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

     if save:
@@ -465,6 +527,9 @@ def main(

     aggregate_metrics = {
         'tokens_per_sec': [],
+        'time': [],
+        'decode_tokens_per_sec': [],
+        'prefill_time': [],
     }
     start = -1 if compile else 0
@@ -499,6 +564,8 @@ def callback(x):
         else:
             callback = lambda x : x
         t0 = time.perf_counter()
+        prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
+        decode_start_event, decode_end_event = device_timer(device), device_timer(device)
         import contextlib
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
@@ -518,6 +585,10 @@ def callback(x):
                 kv_cache_quantization=kv_cache_quantization,
                 cache_size=cache_size,
                 linear_causal_mask=linear_causal_mask,
+                prefill_start_event=prefill_start_event,
+                prefill_end_event=prefill_end_event,
+                decode_start_event=decode_start_event,
+                decode_end_event=decode_end_event,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -527,7 +598,7 @@ def callback(x):
         device_sync(device=device) # MKG
         t = time.perf_counter() - t0

-        if not interactive:
+        if not interactive and prefill_size is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
@@ -537,7 +608,14 @@ def callback(x):
         tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        aggregate_metrics['time'].append(t)
+        decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
+        decode_tokens_sec = tokens_generated / decode_time
+        aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
+        prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
+        aggregate_metrics['prefill_time'].append(prefill_time)
+        print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
+              f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

         if memory_profile and i==0:
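The split above is the point of the PR's new metrics: overall tokens/sec includes prefill time, so it collapses for long prompts, while decode tokens/sec does not. A small numeric illustration (the numbers are made up):

```python
# Why overall and decode throughput diverge once the prefill gets large.
tokens_generated = 200
prefill_time, decode_time = 0.8, 2.0  # seconds; illustrative only
overall = tokens_generated / (prefill_time + decode_time)   # ~71.4 tok/s
decode = tokens_generated / decode_time                     # 100.0 tok/s
print(f"overall {overall:.1f} tok/s | decode {decode:.1f} tok/s | ttft {prefill_time:.2f} s")
```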
@@ -558,8 +636,15 @@ def callback(x):
             break
         print("==========")

+    #ignore first sample for warmup
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
+    ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
+    decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
+    mem = torch.cuda.max_memory_reserved() /1e9
+    print(f"Average overall tokens/sec: {tokpersec:.2f}")
+    print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
+    print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() /1e9
     elif device == "xpu":
@@ -571,15 +656,17 @@ def callback(x):
     print(f"Peak Memory Usage: {mem:.02f} GB")
     print(f"Model Size: {model_size:.02f} GB")
     if write_result:
-        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
+        result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
         result_txt += f"repro: python generate.py "
         result_txt += f"--quantization {quantization} " if quantization else ""
+        result_txt += f"--sparsity {sparsity} " if sparsity else ""
         result_txt += f"--checkpoint_path {checkpoint_path} "
         result_txt += f"--device {device} "
         result_txt += f"--precision {precision} "
         result_txt += f"--compile " if compile else ""
         result_txt += f"--compile_prefill " if compile_prefill else ""
+        result_txt += f"--prefill_size {prefill_size}" if prefill_size else ""
         result_txt += f"--profile {profile} " if profile else ""
         result_txt += f"--profile {memory_profile} " if memory_profile else ""
         result_txt += f"--interactive " if interactive else ""
@@ -601,7 +688,7 @@ def callback(x):
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
-
+    parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode')
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -617,6 +704,11 @@ def callback(x):
             +'embed-int8wo, marlin_qqq'
         )
     )
+    parser.add_argument('-s', '--sparsity', type=str,
+        help=(
+            'Which sparsity techniques to apply: semi-structured'
+        )
+    )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -631,6 +723,6 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
+        args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
