@@ -1238,6 +1238,42 @@ def _dummy_run(
             )
         return hidden_states
 
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
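+        """Run the sampler on dummy inputs during profile_run's profiling."""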
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
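+        # Placeholder sampling parameters, one scalar per dummy request;
+        # penalty tensors are populated but left inert via no_penalties=True.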
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+        )
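+        # One sampling step over the dummy batch.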
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
+
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
@@ -1353,37 +1389,11 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
         else:
-            logits = None
             sampler_output = None
-            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
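
A minimal, self-contained sketch of the dummy_tensors pattern used in the new
helper (num_reqs and device are illustrative placeholders here; the diff
derives them from logits.size(0) and self.device):

    import torch

    num_reqs = 4
    device = torch.device("cpu")

    # One scalar sampling parameter broadcast across all dummy requests,
    # mirroring the dummy_tensors lambda above.
    dummy_tensors = lambda v: torch.full((num_reqs, ), v, device=device)

    temperature = dummy_tensors(0.5)  # tensor([0.5, 0.5, 0.5, 0.5])
    top_p = dummy_tensors(0.9)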