Commit 8a5c7dd

Custom generate
1 parent 5d0e03d commit 8a5c7dd

File tree

src/main.py
src/metrics.py
src/pipeline.py
src/utils.py
transformers (submodule)

5 files changed (+162, -40 lines)

src/main.py

Lines changed: 18 additions & 11 deletions
@@ -26,16 +26,19 @@ def get_arg_parser() -> ArgumentParser:
     parser.add_argument("config_args", nargs="*")
 
     # Runtime
+    parser.add_argument("-c", "--custom_generate", action="store_true")
     parser.add_argument("--pipeline_class", default="HF_Pipeline")
     parser.add_argument("--device", default="cuda", type=torch.device)
     parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
     parser.add_argument("--local_rank", type=int)
-    parser.add_argument("--no_fast_init", dest="fast_init", action="store_false")
+    parser.add_argument("--no_fast_init","--nf", dest="fast_init", action="store_false")
+    parser.add_argument("--no_cache","--nc", dest="use_cache", action="store_false")
+    parser.add_argument("--no_prefill","--np", dest="do_prefill", action="store_false")
 
     # Input and output
-    parser.add_argument("--batch_size", default=1, type=int)
-    parser.add_argument("--max_input_length", default=-1, type=int)
-    parser.add_argument("--max_new_tokens", default=100, type=int)
+    parser.add_argument("--batch_size","-b", default=1, type=int)
+    parser.add_argument("--max_input_length","-i", default=-1, type=int)
+    parser.add_argument("--max_new_tokens","-g", default=100, type=int)
 
     # Cleanup
     parser.add_argument("--clear_every_run", action="store_true")
@@ -47,10 +50,11 @@ def get_arg_parser() -> ArgumentParser:
 
     # Profiling and logging
     parser.add_argument("--max_log_outputs", type=int)
-    parser.add_argument("--profile", action="store_true")
-    parser.add_argument("--profile_cycles", type=int)
-    parser.add_argument("--full_trace", action="store_true")
-    parser.add_argument("--show_op_names", action="store_true")
+    parser.add_argument("--breakdown_latency","--bl", action="store_true")
+    parser.add_argument("--profile","-p", action="store_true")
+    parser.add_argument("--profile_cycles","--pc", type=int)
+    parser.add_argument("--full_trace","--pt", action="store_true")
+    parser.add_argument("--show_op_names","--pn", action="store_true")
     parser.add_argument("--save", type=Path)
 
     return parser
@@ -61,7 +65,6 @@ def main(argv: Optional[List[str]] = None) -> None:
     parser = get_arg_parser()
     args = parser.parse_args(argv)
     config_args = parse_config_args(args.config_args)
-    generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False}
     inputs = get_dummy_batch(args.batch_size, args.max_input_length)
     separate_profile = args.profile and args.profile_cycles is not None
     warmup = args.profile if args.warmup is None else args.warmup
@@ -88,6 +91,10 @@ def main(argv: Optional[List[str]] = None) -> None:
         dtype=args.dtype,
         fast_init=args.fast_init,
         trust_remote_code=args.trust_remote_code,
+        custom_generate=args.custom_generate,
+        use_cache=args.use_cache,
+        do_prefill=args.do_prefill,
+        breakdown_latency=args.breakdown_latency,
     )
 
     all_metrics = []
@@ -104,7 +111,7 @@ def main(argv: Optional[List[str]] = None) -> None:
         profiler = contextlib.nullcontext()
 
     benchmark_metrics = {
-        **generate_kwargs,
+        "max_new_tokens": args.max_new_tokens,
        "Model parameters": pipeline.get_num_parameters(),
        "Cycles (warmup)": args.skip + warmup,
        "Cycles (benchmark)": args.cycles,
@@ -124,7 +131,7 @@ def main(argv: Optional[List[str]] = None) -> None:
            if step == args.skip + warmup:
                t2 = time.perf_counter()
                benchmark_metrics[Metrics.RUNTIME_WARMUP] = t2 - t1
-            generated_text, metrics = pipeline(inputs, **generate_kwargs)
+            generated_text, metrics = pipeline(inputs, args.max_new_tokens)
            if args.profile:
                p.step()

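Because main() forwards argv straight to parser.parse_args, the new switches can be exercised directly from Python as well as from the shell. A minimal sketch (the model and config arguments expected by the unchanged parts of the parser are omitted here):

    # Hypothetical invocation of the benchmark entry point with the new flags.
    from src.main import main

    main(["--custom_generate", "--no_prefill", "--breakdown_latency", "-b", "8", "-g", "100"])

Note that --no_prefill only makes sense together with --custom_generate and the KV cache enabled, matching the assertions added in src/pipeline.py below.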
src/metrics.py

Lines changed: 9 additions & 1 deletion
@@ -17,14 +17,20 @@ def format_ms(t: float) -> str:
     return f"{1000 * t:.2f} ms"
 
 
+def format_ms_dict(t_dict: Dict[str,float]) -> Dict[str,str]:
+    return {key:format_ms(value) for key, value in t_dict.items()}
+
+
 def format_mib(m: float) -> str:
     return f"{m/2**20:.0f} MiB"
 
 
 class Metrics:
     LATENCY_E2E = "Latency (end to end)"
     LATENCY_TOKEN = "Latency (tokenization)"
-    LATENCY_MODEL = "Latency (model)"
+    LATENCY_MODEL = "Latency (generate)"
+    LATENCY_GENERATE_START = "Latency (prepare for generation)"
+    LATENCY_GENERATE_BREAKDOWN = "Latency (generate breakdown)"
     LATENCY_DECODE = "Latency (decode)"
     LATENCY_MAX = "Latency (max)"
     LATENCY_MIN = "Latency (min)"
@@ -59,6 +65,8 @@ class Metrics:
         LATENCY_E2E: format_ms,
         LATENCY_TOKEN: format_ms,
         LATENCY_MODEL: format_ms,
+        LATENCY_GENERATE_START: format_ms,
+        LATENCY_GENERATE_BREAKDOWN: format_ms_dict,
         LATENCY_DECODE: format_ms,
         LATENCY_MAX: format_ms,
         LATENCY_MIN: format_ms,
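format_ms_dict applies format_ms to every value and leaves the keys untouched, so a per-token latency dictionary such as the generate breakdown renders like this:

    format_ms_dict({10: 0.01234, 11: 0.01198})
    # {10: '12.34 ms', 11: '11.98 ms'}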

src/pipeline.py

Lines changed: 128 additions & 25 deletions
@@ -18,6 +18,7 @@
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedModel,
+    GPTBigCodeConfig,GPTBigCodeForCausalLM
 )
 
 
@@ -37,25 +38,41 @@ def __init__(
         dtype: torch.dtype,
         fast_init: bool = True,
         trust_remote_code: bool = False,
+        custom_generate:bool=False,
+        use_cache: bool = True,
+        do_prefill: bool = True,
+        breakdown_latency=False,
     ):
         self.global_metrics = {}
         log_rank_n("*** Setting up tokenizer", logger.info)
-        t0 = time.perf_counter()
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+        t0 = self._get_time()
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, padding_side="left")
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token=self.tokenizer.eos_token
 
-        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        t1 = time.perf_counter()
+        t1 = self._get_time()
 
         self.device = device
+        if self.device==torch.device("cuda"):
+            self.device=torch.device("cuda:0")
+
         self.dtype = dtype
         self.is_int8 = self.dtype == torch.int8
         self.fast_init = fast_init
         self.trust_remote_code = trust_remote_code
-        if self.is_int8 and self.device != torch.device("cuda"):
+        self.use_cache = use_cache
+        self.do_prefill = do_prefill
+        if not self.do_prefill:
+            assert custom_generate
+            assert self.use_cache
+        self.breakdown_latency=breakdown_latency
+        if self.is_int8 and self.device != torch.device("cuda:0"):
             raise ValueError(f"Model quantization not supported on device {self.device}")
 
+        self._generate=self._generate_custom if custom_generate else self._generate_hf
+
         self.config = self._get_config(model_type, pretrained_config or pretrained_model, config_args)
-        t2 = time.perf_counter()
+        t2 = self._get_time()
 
         logger.info(f"Model configuration: {self.config}")
 
@@ -67,27 +84,27 @@ def __init__(
         self.model = self._load_pretrained(pretrained_model)
 
         self.model.eval()
-        t3 = time.perf_counter()
+        t3 = self._get_time()
         self.global_metrics[Metrics.INIT_TOKEN] = t1 - t0
         self.global_metrics[Metrics.INIT_CONFIG] = t2 - t1
         self.global_metrics[Metrics.INIT_TOTAL] = t3 - t0
 
     def _create_model(self) -> PreTrainedModel:
-        t0 = time.perf_counter()
+        t0 = self._get_time()
         log_rank_n("*** Creating model", logger.info)
         with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
             torch_dtype = torch.float16 if self.is_int8 else self.dtype
             model = AutoModelForCausalLM.from_config(
                 config=self.config, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code
             )
-        t1 = time.perf_counter()
+        t1 = self._get_time()
         log_rank_n("*** Moving to device", logger.info)
         model.to(self.device)
-        t2 = time.perf_counter()
+        t2 = self._get_time()
         log_rank_n("*** Initializing weights", logger.info)
         # Initialization is ~1000x faster on GPU.
         model.init_weights()
-        t3 = time.perf_counter()
+        t3 = self._get_time()
         self.global_metrics[Metrics.INIT_CREATE] = t1 - t0
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
         self.global_metrics[Metrics.INIT_WEIGHTS] = t3 - t2
@@ -101,14 +118,14 @@ def _reload_model(self):
         self.model = self._load_pretrained("tmp")
 
     def _save_pretrained(self, pretrained_model: str):
-        t0 = time.perf_counter()
+        t0 = self._get_time()
         log_rank_n(f"*** Saving model to {pretrained_model}", logger.info)
-        t1 = time.perf_counter()
+        t1 = self._get_time()
         self.global_metrics[Metrics.INIT_SAVE] = t1 - t0
         self.model.save_pretrained(pretrained_model)
 
     def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
-        t0 = time.perf_counter()
+        t0 = self._get_time()
         log_rank_n(f"*** Loading model from {pretrained_model}", logger.info)
         kwargs = {"load_in_8bit": True, "device_map": "auto"} if self.is_int8 else {"torch_dtype": self.dtype}
         with fast_init(self.device) if self.fast_init else contextlib.nullcontext():
@@ -120,12 +137,12 @@ def _load_pretrained(self, pretrained_model: str) -> PreTrainedModel:
                 trust_remote_code=self.trust_remote_code,
                 **kwargs,
             )
-        t1 = time.perf_counter()
+        t1 = self._get_time()
         self.global_metrics["load pretrained model"] = t1 - t0
         if not self.is_int8:
             log_rank_n("*** Moving to device", logger.info)
             model = model.to(self.device)
-        t2 = time.perf_counter()
+        t2 = self._get_time()
         self.global_metrics[Metrics.INIT_DEVICE] = t2 - t1
         return model
 
@@ -171,26 +188,103 @@ def _get_config(
 
         return config
 
-    def __call__(self, text: List[str], **generate_kwargs) -> Tuple[List[str], Dict[str, Any]]:
-        t0 = time.perf_counter()
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+    def _get_time(self, synchronize=False):
+        if synchronize:
+            torch.cuda.synchronize()
+        return time.perf_counter()
+
+    def _generate_custom(self, inputs:Dict, max_new_tokens:int):
+        t0 = self._get_time(self.breakdown_latency)
+        batch_size, input_length = inputs["input_ids"].shape
+        output_length = input_length + max_new_tokens
+        input_ids = torch.empty([batch_size, output_length], dtype=torch.int64, device=self.device)
+        input_ids[:, :input_length].copy_(inputs["input_ids"])
 
+        attention_mask = torch.empty([batch_size, output_length], dtype=torch.bool, device=self.device)
+        attention_mask[:, :input_length].copy_(inputs["attention_mask"])
+        attention_mask[:, input_length:].fill_(True)
+
+        position_ids = attention_mask.long().cumsum(-1, dtype=torch.int64) - 1
+        # TODO: Useless?
+        position_ids[:, :input_length].masked_fill_(attention_mask[:, :input_length] == 0, 1)
+
+        if self.do_prefill or input_length<=1:
+            past_key_values=None
+            past_key_length=0
+        else:
+            # Generate mock `past_key_values`
+            past_key_length=input_length-1
+            if isinstance(self.config, GPTBigCodeConfig):
+                if self.config.pre_allocate_kv_cache:
+                    past_key_values=[past_key_length]*self.config.n_layer
+                    for block in self.model.transformer.h:
+                        block.attn.get_kv_cache(batch_size, past_key_length, dtype=self.dtype, device=self.device).normal_()
+                else:
+                    kv_dim=self.config.n_embd // self.config.n_head if self.config.multi_query else self.config.n_embd
+                    past_key_values=[torch.randn([batch_size, past_key_length, 2*kv_dim], dtype=self.dtype, device=self.device) for _ in range(self.config.n_layer)]
+            else:
+                past_key_values = [
+                    [torch.randn([batch_size, past_key_length, self.config.n_embd], dtype=self.dtype, device=self.device) for _ in range(2)] for _ in
+                    range(self.config.n_layer)]
+
+        t1 = self._get_time(self.breakdown_latency)
+        last_time=t1
+        generate_times={}
+        for key_length in range(input_length, output_length):
+            outputs = self.model(
+                input_ids=input_ids[:, past_key_length:key_length],
+                past_key_values=past_key_values,
+                attention_mask=attention_mask[:, :key_length],
+                position_ids=position_ids[:, past_key_length:key_length],
+                return_dict=True,
+                use_cache=self.use_cache,
+            )
+            if self.use_cache:
+                past_key_values=outputs.past_key_values
+                past_key_length=key_length
+            next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
+            input_ids[:, key_length] = next_tokens
+            t2 = self._get_time(self.breakdown_latency)
+            generate_times[key_length]=t2-last_time
+            last_time=t2
+
+        metrics={}
+        if self.breakdown_latency:
+            metrics[Metrics.LATENCY_GENERATE_START]=t1-t0
+            metrics[Metrics.LATENCY_GENERATE_BREAKDOWN]=generate_times
+
+        return input_ids, metrics
+
+    def _generate_hf(self, inputs:Dict, max_new_tokens:int):
         inputs = {key: value.to(self.device) if torch.is_tensor(value) else value for key, value in inputs.items()}
+        output = self.model.generate(
+            **inputs,
+            return_dict_in_generate=True,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=self.use_cache,
+        )
+        return output.sequences, {}
 
-        t1 = time.perf_counter()
-        with torch.inference_mode():
-            output = self.model.generate(**inputs, return_dict_in_generate=True, **generate_kwargs)
-        t2 = time.perf_counter()
 
-        output_tokens = output.sequences
+    def __call__(self, text: List[str], max_new_tokens:int) -> Tuple[List[str], Dict[str, Any]]:
+        t0 = self._get_time()
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+
+        t1 = self._get_time()
+        with torch.inference_mode():
+            output_tokens, generate_metrics = self._generate(inputs, max_new_tokens)
+        t2 = self._get_time(True)
 
         batch_size, input_length = inputs["input_ids"].shape
         output_length = output_tokens.size(1)
 
         output_text = self.tokenizer.batch_decode(output_tokens.cpu(), skip_special_tokens=True)
-        t3 = time.perf_counter()
+        t3 = self._get_time()
 
         metrics = {
+            **generate_metrics,
             Metrics.BATCH_SIZE: batch_size,
            Metrics.INPUT_LENGTH: input_length,
            Metrics.OUTPUT_LENGTH: output_length,
@@ -218,14 +312,23 @@ def aggregate_metrics(self, metrics: List[Dict[str, Any]]):
                Metrics.TOKENS_BATCH,
                Metrics.LATENCY_TOKEN,
                Metrics.LATENCY_MODEL,
+                Metrics.LATENCY_GENERATE_START,
+                Metrics.LATENCY_GENERATE_BREAKDOWN,
                Metrics.LATENCY_DECODE,
                Metrics.LATENCY_E2E,
            )
        }
+
+        breakdown=all_metrics.pop(Metrics.LATENCY_GENERATE_BREAKDOWN, [])
+
        mean_metrics = {key: np.mean(value).item() for key, value in all_metrics.items() if len(value) > 0}
        throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_E2E]
        model_throughput = mean_metrics[Metrics.TOKENS_BATCH] / mean_metrics[Metrics.LATENCY_MODEL]
 
+        if len(breakdown) > 0:
+            mean_metrics[Metrics.LATENCY_GENERATE_BREAKDOWN] = {
+                str(key): np.mean([values[key] for values in breakdown]).item() for key in breakdown[0]}
+
        return {
            **self.global_metrics,
            **mean_metrics,
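The heart of this change is _generate_custom, which replaces model.generate with an explicit greedy loop so that each forward pass can be timed individually and prefill can optionally be skipped against a mock KV cache. A stripped-down, standalone sketch of the same token-by-token pattern (model name and prompt are placeholders, not taken from this repository):

    # Simplified sketch of greedy decoding with a KV cache and per-token timing.
    import time
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    inputs = tokenizer(["Hello, my name is"], return_tensors="pt")
    input_ids = inputs["input_ids"]
    past_key_values = None
    step_times = {}

    with torch.inference_mode():
        for _ in range(10):  # max_new_tokens
            t0 = time.perf_counter()
            # With a cache, only the most recent token is fed back into the model.
            outputs = model(
                input_ids=input_ids if past_key_values is None else input_ids[:, -1:],
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            step_times[input_ids.size(1)] = time.perf_counter() - t0  # latency of this token

    print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))

The per-step dictionary plays the same role as generate_times above: the first entry covers the prefill-sized forward pass, and later entries show the steady-state decode latency.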

src/utils.py

Lines changed: 6 additions & 2 deletions
@@ -82,9 +82,13 @@ def log_rank_n(msg: str, logger: Callable = logging.info, rank: int = 0):
         logger(line)
 
 
-def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0):
+def log_dict(data: dict, logger: Callable = logging.info, rank: int = 0, _prefix=""):
     for key, value in data.items():
-        log_rank_n(f"{key}: {value}", logger, rank)
+        if isinstance(value, dict):
+            log_rank_n(f"{_prefix}{key}:", logger, rank)
+            log_dict(value, logger, rank, _prefix+" ")
+        else:
+            log_rank_n(f"{_prefix}{key}: {value}", logger, rank)
 
 
 dummy_input_sentences = [
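With the recursive log_dict, nested metric dictionaries such as the per-token generate breakdown are printed with one extra level of indentation per nesting level. A small illustration, assuming a single-process run so the rank-0 filter in log_rank_n passes:

    log_dict({"Latency (generate breakdown)": {"10": "12.34 ms", "11": "11.98 ms"}}, print)
    # Latency (generate breakdown):
    #  10: 12.34 ms
    #  11: 11.98 ms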

transformers

Submodule transformers updated 2152 files
