    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
+    GPTBigCodeConfig, GPTBigCodeForCausalLM,
)

@@ -37,25 +38,41 @@ def __init__(
        dtype: torch.dtype,
        fast_init: bool = True,
        trust_remote_code: bool = False,
+        custom_generate: bool = False,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency: bool = False,
    ):
        self.global_metrics = {}
        log_rank_n("*** Setting up tokenizer", logger.info)
-        t0 = time.perf_counter()
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+        t0 = self._get_time()
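+        # Left-pad so batched prompts all end at the last position and generation appends directly after them.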
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, padding_side="left")
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token

-        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        t1 = time.perf_counter()
+        t1 = self._get_time()

        self.device = device
+        if self.device == torch.device("cuda"):
+            self.device = torch.device("cuda:0")
+
        self.dtype = dtype
        self.is_int8 = self.dtype == torch.int8
        self.fast_init = fast_init
        self.trust_remote_code = trust_remote_code
-        if self.is_int8 and self.device != torch.device("cuda"):
+        self.use_cache = use_cache
+        self.do_prefill = do_prefill
+        if not self.do_prefill:
+            assert custom_generate
+            assert self.use_cache
+        self.breakdown_latency = breakdown_latency
+        if self.is_int8 and self.device != torch.device("cuda:0"):
            raise ValueError(f"Model quantization not supported on device {self.device}")

+        self._generate = self._generate_custom if custom_generate else self._generate_hf
+
        self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args)
-        t2 = time.perf_counter()
+        t2 = self._get_time()

        logger.info(f"Model configuration: {self.config}")

@@ -67,27 +84,27 @@ def __init__(
        self.model = self._load_pretrained(pretrained_model)

        self.model.eval()
-        t3 = time.perf_counter()
+        t3 = self._get_time()
        self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0
        self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1
        self.global_metrics[Metrics.INIT_TOTAL] = t3 - t0

    def _create_model(self) -> PreTrainedModel:
-        t0 = time.perf_counter()
+        t0 = self._get_time()
        log_rank_n("*** Creating model", logger.info)
        with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
            torch_dtype = torch.float16 if self.is_int8 else self.dtype
            model = AutoModelForCausalLM.from_config(
                config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code
            )
-        t1 = time.perf_counter()
+        t1 = self._get_time()
        log_rank_n("*** Moving to device", logger.info)
        model.to(self.device)
-        t2 = time.perf_counter()
+        t2 = self._get_time()
        log_rank_n("*** Initializing weights", logger.info)
        # Initialization is ~1000x faster on GPU.
        model.init_weights()
-        t3 = time.perf_counter()
+        t3 = self._get_time()
        self.global_metrics[Metrics.INIT_CREATE] = t1 - t0
        self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
        self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2
@@ -101,14 +118,14 @@ def _reload_model(self):
        self.model = self._load_pretrained("tmp")

    def _save_pretrained(self, pretrained_model: str):
-        t0 = time.perf_counter()
+        t0 = self._get_time()
        log_rank_n(f"*** Saving model to {pretrained_model}", logger.info)
-        t1 = time.perf_counter()
+        t1 = self._get_time()
        self.global_metrics[Metrics.INIT_SAVE] = t1 - t0
        self.model.save_pretrained(pretrained_model)

    def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
-        t0 = time.perf_counter()
+        t0 = self._get_time()
        log_rank_n(f"*** Loading model from {pretrained_model}", logger.info)
        kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype}
        with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
@@ -120,12 +137,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
                trust_remote_code=self.trust_remote_code,
                **kwargs,
            )
-        t1 = time.perf_counter()
+        t1 = self._get_time()
        self.global_metrics["load pretrained model"] = t1 - t0
        if not self.is_int8:
            log_rank_n("*** Moving to device", logger.info)
            model = model.to(self.device)
-        t2 = time.perf_counter()
+        t2 = self._get_time()
        self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
        return model

@@ -171,26 +188,103 @@ def _get_config(
        return config

-    def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[str, Any]]:
-        t0 = time.perf_counter()
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
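+    # Optionally synchronize CUDA before reading the clock so pending GPU work is counted in the timing.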
+    def _get_time(self, synchronize=False):
+        if synchronize:
+            torch.cuda.synchronize()
+        return time.perf_counter()
+
+    def _generate_custom(self, inputs: Dict, max_new_tokens: int):
+        t0 = self._get_time(self.breakdown_latency)
+        batch_size, input_length = inputs["input_ids"].shape
+        output_length = input_length + max_new_tokens
+        input_ids = torch.empty([batch_size, output_length], dtype=torch.int64, device=self.device)
+        input_ids[:, :input_length].copy_(inputs["input_ids"])

+        attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device)
+        attention_mask[:, :input_length].copy_(inputs["attention_mask"])
+        attention_mask[:, input_length:].fill_(True)
+
+        position_ids = attention_mask.long().cumsum(-1, dtype=torch.int64) - 1
+        # TODO: Useless?
+        position_ids[:, :input_length].masked_fill_(attention_mask[:, :input_length] == 0, 1)
+
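+        # When prefill is skipped, fabricate a random KV cache covering all but the last prompt token,
+        # so the loop below only measures single-token decode steps.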
+        if self.do_prefill or input_length <= 1:
+            past_key_values = None
+            past_key_length = 0
+        else:
+            # Generate mock `past_key_values`
+            past_key_length = input_length - 1
+            if isinstance(self.config, GPTBigCodeConfig):
+                if self.config.pre_allocate_kv_cache:
+                    past_key_values = [past_key_length] * self.config.n_layer
+                    for block in self.model.transformer.h:
+                        block.attn.get_kv_cache(batch_size, past_key_length, dtype=self.dtype, device=self.device).normal_()
+                else:
+                    kv_dim = self.config.n_embd // self.config.n_head if self.config.multi_query else self.config.n_embd
+                    past_key_values = [
+                        torch.randn([batch_size, past_key_length, 2 * kv_dim], dtype=self.dtype, device=self.device)
+                        for _ in range(self.config.n_layer)
+                    ]
+            else:
+                past_key_values = [
+                    [torch.randn([batch_size, past_key_length, self.config.n_embd], dtype=self.dtype, device=self.device) for _ in range(2)]
+                    for _ in range(self.config.n_layer)
+                ]
+
+        t1 = self._get_time(self.breakdown_latency)
+        last_time = t1
+        generate_times = {}
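+        # Greedy decoding, one token per iteration; without use_cache the full prefix is re-run at every step.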
233
+ for key_length in range (input_length , output_length ):
234
+ outputs = self .model (
235
+ input_ids = input_ids [:, past_key_length :key_length ],
236
+ past_key_values = past_key_values ,
237
+ attention_mask = attention_mask [:, :key_length ],
238
+ position_ids = position_ids [:, past_key_length :key_length ],
239
+ return_dict = True ,
240
+ use_cache = self .use_cache ,
241
+ )
242
+ if self .use_cache :
243
+ past_key_values = outputs .past_key_values
244
+ past_key_length = key_length
245
+ next_tokens = torch .argmax (outputs .logits [:, - 1 , :], dim = - 1 )
246
+ input_ids [:, key_length ] = next_tokens
247
+ t2 = self ._get_time (self .breakdown_latency )
248
+ generate_times [key_length ]= t2 - last_time
249
+ last_time = t2
250
+
251
+ metrics = {}
252
+ if self .breakdown_latency :
253
+ metrics [Metrics .LATENCY_GENERATE_START ]= t1 - t0
254
+ metrics [Metrics .LATENCY_GENERATE_BREAKDOWN ]= generate_times
255
+
256
+ return input_ids , metrics
257
+
258
+ def _generate_hf (self , inputs :Dict , max_new_tokens :int ):
178
259
inputs = {key : value .to (self .device ) if torch .is_tensor (value ) else value for key , value in inputs .items ()}
260
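+        # Reference path: greedy generation through model.generate, for comparison with the custom loop.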
+        output = self.model.generate(
+            **inputs,
+            return_dict_in_generate=True,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=self.use_cache,
+        )
+        return output.sequences, {}

-        t1 = time.perf_counter()
-        with torch.inference_mode():
-            output = self.model.generate(**inputs, return_dict_in_generate=True, **generate_kwargs)
-        t2 = time.perf_counter()

-        output_tokens = output.sequences
+    def __call__(self, text: List[str], max_new_tokens: int) -> Tuple[List[str], Dict[str, Any]]:
+        t0 = self._get_time()
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+
+        t1 = self._get_time()
+        with torch.inference_mode():
+            output_tokens, generate_metrics = self._generate(inputs, max_new_tokens)
+        t2 = self._get_time(True)

        batch_size, input_length = inputs["input_ids"].shape
        output_length = output_tokens.size(1)

        output_text = self.tokenizer.batch_decode(output_tokens.cpu(), skip_special_tokens=True)
-        t3 = time.perf_counter()
+        t3 = self._get_time()

        metrics = {
+            **generate_metrics,
            Metrics.BATCH_SIZE: batch_size,
            Metrics.INPUT_LENGTH: input_length,
            Metrics.OUTPUT_LENGTH: output_length,
@@ -218,14 +312,23 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]):
                Metrics.TOKENS_BATCH,
                Metrics.LATENCY_TOKEN,
                Metrics.LATENCY_MODEL,
+                Metrics.LATENCY_GENERATE_START,
+                Metrics.LATENCY_GENERATE_BREAKDOWN,
                Metrics.LATENCY_DECODE,
                Metrics.LATENCY_E2E,
            )
        }
+
+        breakdown = all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, [])
+
        mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0}
        throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E]
        model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL]

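+        # Average the per-position decode latencies across runs; the breakdown dicts are keyed by key_length.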
+        if len(breakdown) > 0:
+            mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = {
+                str(key): np.mean([values[key] for values in breakdown]).item() for key in breakdown[0]
+            }
+
        return {
            **self.global_metrics,
            **mean_metrics,