
Generation is too slow when using RegexParser(r'.*') or RegexParser(r'.+') #84

Open
ZhouGongZaiShi opened this issue Mar 8, 2024 · 3 comments


ZhouGongZaiShi commented Mar 8, 2024

```python
output_regex = r"\d*.*"
parser = RegexParser(output_regex)
logits_processor = build_vllm_logits_processor(llm_tokenizer_data, parser, analyze=False)
sampling_params.logits_processors = [logits_processor]
```

On an A100 GPU, I load a local model with vLLM and use lm-format-enforcer.

With output_regex = r"\d*\w*", generation runs at about 200 tokens per second.
With output_regex = r"\d*.*", it drops to about 5 tokens per second.

That is roughly a forty-fold slowdown, and sometimes "." is necessary in a regular expression.
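One plausible reason "." is so much more expensive than "\w" (an illustration only, not lm-format-enforcer's actual internals): "." admits almost every token in the vocabulary, so any per-pattern scan over candidate tokens has far more matches to collect. A small sketch with a toy vocabulary:

```python
import re

# Toy vocabulary standing in for a tokenizer's token strings.
vocab = ["abc", "42", " the", "!!", "def(", "  ", "99%"]

def matching_tokens(pattern, tokens):
    """Return the tokens whose first character the pattern accepts."""
    compiled = re.compile(pattern)
    return [t for t in tokens if compiled.match(t)]

# "." matches any non-newline character, so every token qualifies;
# "\w" matches only word characters, so far fewer do.
dot_matches = matching_tokens(r".", vocab)
word_matches = matching_tokens(r"\w", vocab)
```

With a real vocabulary of tens of thousands of tokens, that gap in the number of admitted tokens is correspondingly larger.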

noamgat (Owner) commented Mar 13, 2024

Hi,
I was not able to reproduce this issue on my computer.
Can you check if this slowdown also occurs if you use the same logits processor instance for multiple requests?

RegexParser is implemented behind the scenes as a state machine.
The first time a state is encountered, all of its legal tokens are computed, which can take some time.
On subsequent visits to that state, nothing is recomputed. The regex you gave as an example has only 3 FSM states, so after the first generation, the overhead on later generations should be negligible.

Can you check if this is the case?
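The caching behavior described above can be sketched in a few lines (hypothetical names, not lm-format-enforcer's actual code): the expensive vocabulary scan runs only on the first visit to a state, and every later visit is a dictionary lookup.

```python
# Minimal sketch of per-state token-mask caching, assuming a
# user-supplied `allowed_fn` that maps an FSM state to a token predicate.
class CachedTokenMasks:
    def __init__(self, vocab, allowed_fn):
        self.vocab = vocab
        self.allowed_fn = allowed_fn
        self._cache = {}
        self.computations = 0  # counts how often the expensive scan ran

    def allowed_tokens(self, state):
        if state not in self._cache:
            self.computations += 1  # expensive: scan the whole vocabulary
            pred = self.allowed_fn(state)
            self._cache[state] = [t for t in self.vocab if pred(t)]
        return self._cache[state]

masks = CachedTokenMasks(
    ["a", "1", "!"],
    lambda s: str.isdigit if s == "digits" else str.isalnum,
)
masks.allowed_tokens("digits")
masks.allowed_tokens("digits")  # cache hit: the scan does not run again
```

This is also why reusing the same logits processor instance across requests matters: a fresh instance starts with an empty cache and pays the scan cost again.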

goutham794 commented
Hello,

I'm trying to use regex-controlled generation. For me, processing does not even start when I pass a large batch; it did work when I gave it a single example, albeit slowly. My initial regex was meant to match a valid Python list of strings, and even after simplifying it, generation still does not start:

```python
list_regex = r'\[.*\]'
parser = RegexParser(list_regex)

logits_processor = build_vllm_logits_processor(tokenizer_data, parser)

sampling_params = SamplingParams(max_tokens=100, logits_processors=[logits_processor])
results = llm.generate([p['text'] for p in dataset], sampling_params=sampling_params)
```

goutham794 commented

It started after a long time, and the throughput is much lower: about 2.5 it/s, compared to 110 it/s without the logits processor.
