 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
+class HostEvent:
+    def __init__(self):
+        self.event_time = None
+
+    def record(self):
+        self.event_time = time.perf_counter()
+
+    def elapsed_time(self, other_event):
+        if self.event_time is None:
+            raise ValueError("Event not recorded!")
+        # return ms to match cuda event
+        return abs(other_event.event_time - self.event_time) * 1000
+
+def device_timer(device):
+    if "cuda" in device:
+        return torch.cuda.Event(enable_timing=True)
+    elif ("cpu" in device) or ("mps" in device):
+        return HostEvent()
+    else:
+        print(f"device={device} is not yet supported")
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
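Two notes on the additions above: setting `SparseSemiStructuredTensor._FORCE_CUTLASS = False` presumably lets the 2:4 sparse ops use the cuSPARSELt kernels instead of forcing CUTLASS, and `HostEvent` deliberately mirrors the `record()` / `elapsed_time()` (milliseconds) interface of `torch.cuda.Event`, so the same timing code runs on CPU and MPS. A minimal usage sketch, assuming it lives in this module where `device_timer` is defined:

```python
# Sketch only: drives the device-agnostic timers the same way generate() will.
import torch

def time_matmul_ms(device: str) -> float:
    start, end = device_timer(device), device_timer(device)
    x = torch.randn(1024, 1024, device=device)
    start.record()
    _ = x @ x
    end.record()
    if "cuda" in device:
        torch.cuda.synchronize(device)  # CUDA events record asynchronously; sync before reading
    return start.elapsed_time(end)      # milliseconds for both HostEvent and torch.cuda.Event

print(f"matmul took {time_matmul_ms('cpu'):.3f} ms")
```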
@@ -98,6 +121,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
+    prefill_start_event: Optional[torch.cuda.Event] = None,
+    prefill_end_event: Optional[torch.cuda.Event] = None,
+    decode_start_event: Optional[torch.cuda.Event] = None,
+    decode_end_event: Optional[torch.cuda.Event] = None,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
@@ -128,12 +155,21 @@ def generate(
     model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

     # execute prefill
+    if prefill_start_event is not None:
+        prefill_start_event.record()
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     seq[:, T] = next_token.squeeze()
+    if prefill_end_event is not None:
+        prefill_end_event.record()
+
     # execute token generation
+    if decode_start_event is not None:
+        decode_start_event.record()
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens - 1, callback=callback, **sampling_kwargs)
     seq = torch.cat((seq[:, :T + 1], *generated_tokens), dim=-1)
+    if decode_end_event is not None:
+        decode_end_event.record()

     return seq
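Splitting the measurement into separate prefill and decode event pairs, rather than only timing the whole call, is what lets the benchmark driver below report time-to-first-token and decode-only throughput independently of the end-to-end tokens/sec figure.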
@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
+    prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
@@ -166,6 +203,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
@@ -181,6 +219,10 @@ def main(
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """

+    if prefill_size is not None and prefill_size > 0:
+        # create prompt of prefill size
+        prompt = "prompt " * (int(prefill_size) - 3)
+
     torchao.quantization.utils.recommended_inductor_config_setter()

     assert checkpoint_path.is_file(), checkpoint_path
@@ -205,6 +247,14 @@ def main(

     torch.manual_seed(1234)

+    def ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn
+
+    def not_ffn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)
+
+    def ffn_or_attn_only(mod, fqn):
+        return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn)

     if quantization:
         from torchao.quantization import (
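These predicates are handed to `quantize_`/`sparsify_` as `filter_fn`, which visits the model's modules and only transforms the linears whose fully-qualified name matches. A small self-contained sketch of the selection logic (the toy module layout is an assumption for illustration, not the real Llama block structure):

```python
# Demonstrates how a (module, fqn) predicate picks out only the feed-forward linears.
import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = torch.nn.ModuleDict({"wq": torch.nn.Linear(64, 64)})
        self.feed_forward = torch.nn.ModuleDict({"w1": torch.nn.Linear(64, 64)})

def ffn_only(mod, fqn):
    return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

model = ToyBlock()
selected = [fqn for fqn, mod in model.named_modules() if ffn_only(mod, fqn)]
print(selected)  # ['feed_forward.w1'] -- attention.wq is left alone
```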
@@ -228,9 +278,14 @@ def main(
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        elif "int8dq" in quantization:
-            quantize_(model, int8_dynamic_activation_int8_weight())
-        elif "int4wo" in quantization:
+        if "int8dq" in quantization:
+            if sparsity and "semi" in sparsity:
+                from torchao.dtypes import SemiSparseLayout
+                quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
+                quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
+            else:
+                quantize_(model, int8_dynamic_activation_int8_weight())
+        if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq = True
             else:
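In the semi-sparse branch the model is processed in two passes: the FFN linears get int8 dynamic-activation / int8-weight quantization on a `SemiSparseLayout`, so their matmuls can run on the 2:4 sparse kernels, while every remaining linear takes the ordinary dense int8 path via `not_ffn_only`.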
@@ -250,9 +305,9 @@ def main(
                     layout=MarlinQQQLayout(),
                 ),
             )
-        else:
+        elif "semi" in sparsity:
             from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         elif "embed-int8wo" in quantization:
@@ -440,6 +495,13 @@ def main(
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(model)

+    # standalone sparsity
+    elif sparsity:
+        from torchao.sparsity import semi_sparse_weight, sparsify_
+        if "semi" in sparsity:
+            # TODO there is a bug here, need to fix
+            sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)
+
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

     if save:
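For readers unfamiliar with the "semi" option: 2:4 semi-structured sparsity keeps at most two non-zero values in every group of four, and `semi_sparse_weight()` swaps the selected FFN weights over to that representation. A rough CUDA-only sketch of the underlying PyTorch mechanism, using naive magnitude pruning as a stand-in (an illustration, not the torchao code path):

```python
# Prune a Linear weight to a 2:4 pattern and hand it to PyTorch's semi-structured kernels.
import torch
from torch.sparse import to_sparse_semi_structured

linear = torch.nn.Linear(128, 128, bias=False).half().cuda()
w = linear.weight.detach()
groups = w.reshape(-1, 4)
drop = groups.abs().argsort(dim=1)[:, :2]              # two smallest magnitudes per group of 4
mask = torch.ones_like(groups).scatter_(1, drop, 0.0)  # keep the two largest
linear.weight = torch.nn.Parameter(to_sparse_semi_structured((groups * mask).reshape_as(w)))

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")
out = linear(x)  # dispatches to the sparse matmul (cuSPARSELt here, given _FORCE_CUTLASS = False)
```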
@@ -465,6 +527,9 @@ def main(

     aggregate_metrics = {
         'tokens_per_sec': [],
+        'time': [],
+        'decode_tokens_per_sec': [],
+        'prefill_time': [],
     }
     start = -1 if compile else 0
@@ -499,6 +564,8 @@ def callback(x):
         else:
             callback = lambda x : x
         t0 = time.perf_counter()
+        prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
+        decode_start_event, decode_end_event = device_timer(device), device_timer(device)
         import contextlib
         if (i != num_samples - 1 or not profile):
             prof = contextlib.nullcontext()
@@ -518,6 +585,10 @@ def callback(x):
                 kv_cache_quantization=kv_cache_quantization,
                 cache_size=cache_size,
                 linear_causal_mask=linear_causal_mask,
+                prefill_start_event=prefill_start_event,
+                prefill_end_event=prefill_end_event,
+                decode_start_event=decode_start_event,
+                decode_end_event=decode_end_event,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -527,7 +598,7 @@ def callback(x):
         device_sync(device=device) # MKG
         t = time.perf_counter() - t0

-        if not interactive:
+        if not interactive and prefill_size is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
@@ -537,7 +608,14 @@ def callback(x):
         tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+        aggregate_metrics['time'].append(t)
+        decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
+        decode_tokens_sec = tokens_generated / decode_time
+        aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
+        prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
+        aggregate_metrics['prefill_time'].append(prefill_time)
+        print(f"Sample {i + 1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
+              f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

@@ -558,8 +636,15 @@ def callback(x):
             break
     print("==========")

+    # ignore first sample for warmup
     tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
+    ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
+    decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
     bandwidth = model_size * tokpersec
+    mem = torch.cuda.max_memory_reserved() / 1e9
+    print(f"Average overall tokens/sec: {tokpersec:.2f}")
+    print(f"Average decode tokens/sec: {decode_tokpersec:.04f}")
+    print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() / 1e9
     elif device == "xpu":
@@ -571,15 +656,17 @@ def callback(x):
     print(f"Peak Memory Usage: {mem:.02f} GB")
     print(f"Model Size: {model_size:.02f} GB")
     if write_result:
-        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
+        result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
         result_txt += f"repro: python generate.py "
         result_txt += f"--quantization {quantization} " if quantization else ""
+        result_txt += f"--sparsity {sparsity} " if sparsity else ""
         result_txt += f"--checkpoint_path {checkpoint_path} "
         result_txt += f"--device {device} "
         result_txt += f"--precision {precision} "
         result_txt += f"--compile " if compile else ""
         result_txt += f"--compile_prefill " if compile_prefill else ""
+        result_txt += f"--prefill_size {prefill_size} " if prefill_size else ""
         result_txt += f"--profile {profile} " if profile else ""
         result_txt += f"--profile {memory_profile} " if memory_profile else ""
         result_txt += f"--interactive " if interactive else ""
@@ -601,7 +688,7 @@ def callback(x):
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')
-
+    parser.add_argument('--prefill_size', type=int, default=None, help='Prefill size (prompt length in tokens); used for TTFT benchmarking')
     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -617,6 +704,11 @@ def callback(x):
             +'embed-int8wo, marlin_qqq'
         )
     )
+    parser.add_argument('-s', '--sparsity', type=str,
+        help=(
+            'Which sparsity techniques to apply: semi-structured'
+        )
+    )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -631,6 +723,6 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
+        args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
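With everything wired through, a TTFT-oriented benchmark run would look something like `python generate.py --compile --quantization int8dq --sparsity semi-structured --prefill_size 8000` (an illustrative invocation, not one taken from the commit).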