
[Fix] Auto-detect XGrammar compiler threads based on CPU cores. #17737


Open · wants to merge 4 commits into main from main-dev/2025-05-06-fix-thread

Conversation

@Ubospica commented May 6, 2025

This PR sets the number of threads in XGrammar's GrammarCompiler based on the number of CPU cores, to achieve the best compilation performance.

Thanks to @russellb for the discussion.

@Ubospica requested review from mgoin and russellb as code owners May 6, 2025 18:09
github-actions bot commented May 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@russellb (Member) commented May 6, 2025

This change is based on observing that using about half of the physical cores gave the best result in this test:

from transformers import AutoTokenizer
import xgrammar as xgr
import time

# llama 3.1 8b
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "email": {"type": "string"},
        "street": {"type": "string"},
        "city": {"type": "string"},
        "state": {"type": "string"},
        "zip": {"type": "string"},
        "phone": {"type": "string"},
        "website": {"type": "string"},
        "company": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "email"],
}

results = []
for i in range(128):
    times = []
    for _ in range(10):
        grammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=i + 1)
        time_start = time.monotonic_ns()
        compiled_grammar = grammar_compiler.compile_json_schema(schema)
        time_end = time.monotonic_ns()
        times.append((time_end - time_start) / 1000000)
    avg_time = sum(times) / len(times)
    results.append((i + 1, avg_time))
    print(f"{i} threads - average time in ms: {avg_time}ms")

# find the minimum time
min_time = min(results, key=lambda x: x[1])
print(f"{min_time[0]} threads - minimum time: {min_time[1]}ms")

def _get_grammar_compiler_threads(self) -> int:
    """Get the number of threads to use for the grammar compiler. The number of physical
    threads is used to determine the thread count, and it is capped at 32 because having too
    many threads does not yield additional benefits.
Member:

Do you think there's something special about 32? Or did it just happen to be the number of physical cores on a single CPU in the test environment?

@Ubospica (Author), May 6, 2025:

The maximum number of effective threads is the number of elements in the grammar, typically in the dozens. So 32 is a reasonable cap for schemas that are not very complex. If the grammar is extremely complex, more threads could be used effectively, but for general purposes I think an upper bound of 32 works.

Comment on lines 104 to 105
    num_cpus = os.cpu_count()
    return min(num_cpus / 2, 32)
Member:

Why are 32 threads needed for this? It seems like overkill and poses problems when running multiple instances on one host.

This also won't work well for systems with small numbers of cores, since several threads are already used just for multi-process input, the scheduler, and model workers. I think we need a fixed floor where you subtract some number of "reserved vLLM threads" to get an "available threads" estimate.
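A minimal sketch of that heuristic (the function name and the reserved-thread count are illustrative assumptions, not code from this PR):

import os

# Assumption: reserve a handful of threads for vLLM's own multi-process
# input handling, scheduler, and model workers.
RESERVED_VLLM_THREADS = 4

def estimate_available_threads(cap: int = 32) -> int:
    num_cpus = os.cpu_count() or 1
    # Fixed floor of 1 so the grammar compiler always gets at least one thread.
    return max(1, min(num_cpus - RESERVED_VLLM_THREADS, cap))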

Member:

If it's an 8-CPU-core system, this will result in 4.

Fair point about multiple instances on a single host. See my comment on the PR with a test script that helped find the sweet spot. I don't think we should hard-code 32, since there's nothing magic about that number.

Member:

@Ubospica why did you move away from using psutil to get the number of physical cores? Did you get an error?

Member:

Basically, I think we should increase the number of threads for structured output if we have an embarrassing number of cores, rather than taking up half by default.

@Ubospica (Author), May 6, 2025:

> @Ubospica why did you move away from using psutil to get the number of physical cores? Did you get an error?

I just updated it to resolve the comment.

> I don't think we want to ever use all CPUs

But I think half the physical cores would be too few in hyperthreading cases. E.g., on a 16-core/32-thread CPU we can use 16 threads, but half the physical cores (8 threads) is too few. So I just set it back to logical threads.

@Ubospica (Author), May 6, 2025:

> Why are 32 threads needed for this? It seems like overkill and poses problems when running multiple instances on one host.

@mgoin The 32-thread cap reflects the number of effective threads needed for typical schema sizes; see #17737 (comment). This might be appropriate since we're only setting an upper limit.

To allow multiple instances, maybe we can add a global setting to limit the number of threads, just like we did in V0 at

. But I'm not sure we need to do that in this PR.

@Ubospica force-pushed the main-dev/2025-05-06-fix-thread branch 2 times, most recently from ebd49cc to 17988a2, May 6, 2025 19:11
@tmehlinger commented May 6, 2025

It would be better to make this configurable (say, VLLM_XGRAMMAR_COMPILER_THREADS) rather than simply assuming a thread count based on the number of available cores. If a config value isn't provided, then fall back to your assumption. As it is, this has the potential to monopolize a substantial portion of available CPU time on a system that may be busy with other workloads, and operators have no way to tune the behavior.

@russellb (Member) commented May 6, 2025

> It would be better to make this configurable (say, VLLM_XGRAMMAR_COMPILER_THREADS) rather than simply assuming a thread count based on the number of available cores. If a config value isn't provided, then fall back to your assumption. As it is, this has the potential to monopolize a substantial portion of available CPU time on a system that may be busy with other workloads, and operators have no way to tune the behavior.

This makes sense. The simplest non-controversial thing we could do right now is add the env var, set it to 8 by default, and document that it's a tunable worth exploring in some extreme cases.

@Ubospica (Author) commented May 7, 2025

> It would be better to make this configurable (say, VLLM_XGRAMMAR_COMPILER_THREADS) rather than simply assuming a thread count based on the number of available cores. If a config value isn't provided, then fall back to your assumption. As it is, this has the potential to monopolize a substantial portion of available CPU time on a system that may be busy with other workloads, and operators have no way to tune the behavior.

> This makes sense. The simplest non-controversial thing we could do right now is add the env var, set it to 8 by default, and document that it's a tunable worth exploring in some extreme cases.

I agree we can have an env var. What about: if the env var is set, use it; otherwise, fall back to the method in this PR. This achieves better performance by default while allowing control in extreme cases.
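A minimal sketch of that fallback order (the env var name comes from the discussion; the function name and heuristic are illustrative, not the PR's final code):

import os

def get_compiler_threads() -> int:
    # An operator-provided value always wins.
    env_val = os.environ.get("VLLM_XGRAMMAR_COMPILER_THREADS")
    if env_val is not None:
        return int(env_val)
    # Otherwise auto-detect: half the logical CPUs, capped at 32, floor of 1.
    return max(1, min((os.cpu_count() or 1) // 2, 32))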


    Returns:
        int: The number of threads to use for the grammar compiler.
    """
    # check environment variable
    env_threads = os.environ.get(VLLM_XGRAMMAR_COMPILER_THREADS_ENV_VAR)
Member:

We have another pattern we typically use for env vars in vLLM. I'm happy to do the conversion for you if you don't mind me committing to your branch. See vllm.envs
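For reference, the vllm.envs pattern looks roughly like this (a simplified sketch; the real module defines many more variables, and the default of 8 here is taken from the discussion below):

import os
from typing import Any, Callable

# vllm/envs.py style: a dict of lazy getters plus a module-level __getattr__,
# so callers can write `import vllm.envs as envs` and read
# `envs.VLLM_XGRAMMAR_COMPILER_THREADS` as a plain attribute.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_XGRAMMAR_COMPILER_THREADS":
        lambda: int(os.getenv("VLLM_XGRAMMAR_COMPILER_THREADS", "8")),
}

def __getattr__(name: str) -> Any:
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")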

@russellb (Member) commented May 7, 2025

> I agree we can have an env var. What about: if the env var is set, use it; otherwise, fall back to the method in this PR. This achieves better performance by default while allowing control in extreme cases.

I think there is concern that we're increasing the number of threads when it may only provide a significant benefit in the more extreme cases. Instead, I suggest just keeping the existing default of 8, but allow tuning with the env var. I think that's the compromise that we can get merged.

mergify bot commented May 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Ubospica.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label May 12, 2025
@russellb (Member) left a comment:

I've changed this to:

  1. Change the env var handling to use vllm.envs like other vllm env vars

  2. Remove the automatic calculation of the thread-pool size, since that part was controversial. It still defaults to 8, but can now at least be changed using the env var if needed.
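With that in place, an operator can tune the pool at launch time, e.g. (hypothetical invocation, assuming the env var name discussed above):

VLLM_XGRAMMAR_COMPILER_THREADS=16 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct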

Ubospica and others added 4 commits May 12, 2025 13:17
Signed-off-by: Ubospica <ubospica@gmail.com>

update

Signed-off-by: Ubospica <ubospica@gmail.com>

update comments

fix pre-commit errors

Signed-off-by: Ubospica <ubospica@gmail.com>
Signed-off-by: Ubospica <ubospica@gmail.com>
This changes the XGRAMMAR threads env var to be handled like other vllm
env vars.

This also removes the logic for auto calculating the thread pool size,
since that was controversial in review.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
@russellb force-pushed the main-dev/2025-05-06-fix-thread branch from bd2c6b5 to 40ecba7, May 12, 2025 17:18
@mergify bot added the documentation label and removed the needs-rebase label May 12, 2025
@Ubospica (Author) left a comment:

LGTM!

@russellb dismissed their stale review May 13, 2025 12:07

dismissed my requested-changes review, but won't approve since the current patch is mostly my changes

@russellb requested a review from mgoin May 13, 2025 12:07
mergify bot commented May 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Ubospica.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label May 14, 2025
Labels: documentation, needs-rebase, structured-output, v1

4 participants