Skip to content

Commit

Permalink
revert benchmark fix (#8747)
Browse files: browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Jul 10, 2024
1 parent 3241120 commit a1dbc39
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions llm/predict/predictor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -849,8 +849,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
self.free_list = [i for i in range(self.max_block_nums)][::-1]
self.used_list = [[] for _ in range(config.batch_size)]

self.benchmark = config.benchmark

def init_inputs(self, config: PredictorArgument):
self.inputs = {}

Expand Down Expand Up @@ -967,22 +965,19 @@ def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=1000
return rot_emb

def _preprocess(self, source):
if not self.benchmark and self.tokenizer.chat_template is not None:
if self.tokenizer.chat_template is not None:
source = [source] if isinstance(source, str) else source
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

for i, text in enumerate(source):
add_special_tokens = self.tokenizer.chat_template is None or isinstance(
self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)
)
add_special_tokens = add_special_tokens if not self.benchmark else False
tokens = self.tokenizer(
text,
return_tensors="np",
padding=True,
max_length=self.config.src_length,
# if use chat_template, it will not add special_tokens
add_special_tokens=add_special_tokens,
add_special_tokens=self.tokenizer.chat_template is None
or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
)
input_ids = tokens["input_ids"][0]
length = len(input_ids)
Expand Down Expand Up @@ -1622,10 +1617,6 @@ def predict():

predictor = create_predictor(predictor_args, model_args)

if predictor_args.benchmark:
benchmark(predictor, predictor_args, model_args)
return

source_texts = []
target_texts = []
if model_args.data_file:
Expand Down Expand Up @@ -1669,12 +1660,14 @@ def predict():
out = {"src": source, "tgt": target, "output": output}
f.write(json.dumps(out, ensure_ascii=False) + "\n")

if predictor_args.benchmark:
benchmark(predictor, predictor_args, model_args)


def benchmark(predictor, predictor_args, model_args):
# Just construct a simple benchmark input. We pad input to the src_length.
benchmark_texts = [
predictor.tokenizer.pad_token * predictor_args.src_length for _ in range(predictor_args.batch_size)
]
test_texts = "hello world, how are you?"
benchmark_texts = [test_texts + "<pad>" * predictor_args.src_length for _ in range(predictor_args.batch_size)]

batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
print("***********Start Benchmark**********")
Expand Down

0 comments on commit a1dbc39

Please sign in to comment.