Support inference with transformers-neuronx #2569

Merged: 19 commits into vllm-project:main on Feb 28, 2024

Conversation

@liangfu (Contributor) commented on Jan 24, 2024:

This PR enables Llama model inference on Inferentia with the transformers-neuronx backend.

To demonstrate offline inference with transformers-neuronx, run

python3 examples/offline_inference_neuron.py
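
For reference, here is a minimal sketch of what such an offline-inference script can look like (this is not the exact contents of examples/offline_inference_neuron.py; the prompts, model name, and sampling settings are assumptions, while the engine arguments mirror the values discussed in the review below):

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
# temperature=0.0 gives greedy sampling, which is what the review below uses to
# compare Neuron outputs against the GPU backend.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

llm = LLM(
    model="openlm-research/open_llama_3b",  # assumed example model
    max_num_seqs=8,
    max_model_len=128,
    block_size=128,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)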

@liangfu force-pushed the neuron-2 branch 2 times, most recently from a9dbe4b to 242d6d1 on January 25, 2024.
@WoosukKwon self-requested a review on January 25, 2024.

@WoosukKwon (Collaborator) commented:

Just a note: we may want to merge #2503 before this PR.

@liangfu (Contributor, Author) commented on Jan 30, 2024:

> Just a note: we may want to merge #2503 before this PR.

Sure, I can work on a rebase once that's merged.

@WoosukKwon (Collaborator) commented:

Hey @liangfu, thanks for understanding. We are prioritizing some PRs for the upcoming v0.3.0 release and the meetup event. Apologies for the delay.

@liangfu (Contributor, Author) commented on Jan 30, 2024:

No worries. Removing the hardcoded device is necessary.

@zhuohan123 mentioned this pull request on Jan 31, 2024.

@WoosukKwon (Collaborator) commented:

Hi @liangfu, we just merged #2503. Could you rebase the PR? Thanks!

@liangfu (Contributor, Author) commented on Feb 3, 2024:

Hi @WoosukKwon, I rebased onto the latest main branch. Feel free to take a look.

@WoosukKwon (Collaborator) left a review:

@liangfu, thanks for submitting the PR! I've checked that the PR works on an inf2 instance and that the outputs of offline_inference_neuron.py match those of offline_inference.py on NVIDIA GPUs (when greedy sampling is used). Happy to see that we have a working implementation!

Also, I liked that the PR minimized the modifications to the codebase, except that it duplicated worker in neuron_worker for some reason I don't understand. Could you clarify this?

Also, I'd like to know more about the details of the implementation in transformers-neuronx. For now, things like max_model_len=128, block_size=128, and os.environ['MASTER_PORT'] = '12355' are quite mysterious to me.

Resolved review thread on vllm/model_executor/__init__.py (outdated).

max_num_seqs=8,
max_model_len=128,
block_size=128,
device="cpu")

@WoosukKwon (Collaborator) commented:

Why does the device need to be "cpu" here?

@WoosukKwon (Collaborator) commented:

What are the constraints on block_size, max_model_len, and max_num_seqs? If there are any, please specify them.

@liangfu (Contributor, Author) replied:

Removed the device='cpu' argument here; the device is now detected automatically.

@liangfu (Contributor, Author) replied:

Also specified the constraints for Neuron device support explicitly.

Resolved review threads (outdated) on:
examples/offline_inference_neuron.py
vllm/model_executor/neuron_model_loader.py (three threads)
vllm/worker/neuron_worker.py

@WoosukKwon (Collaborator) commented:

It seems the online serving entrypoints of vLLM, such as the API server and the OpenAI-compatible server, do not work at the moment. Could you support them as well?

@liangfu (Contributor, Author) replied:

It's a bit tricky for now. I extended the --block-size argument to accept 128 as an option.

We can then run the API server with the following setup:

python3 -m vllm.entrypoints.api_server --model openlm-research/open_llama_3b --tensor-parallel-size 2 --max-num-seqs 2 --max-model-len 128 --block-size 128 &

On the client side, we need to assume --n=1:

python3 api_client.py --stream --n=1 --prompt="one two three four "

@WoosukKwon (Collaborator) left a review:

LGTM! Thanks for submitting the PR and addressing my reviews. Happy to have the initial support for Inferentia! Let's keep optimizing the performance and making this integration more mature 🚀

Resolved review threads (outdated) on:
vllm/worker/model_runner.py (two threads)
vllm/model_executor/models/neuron/llama.py
vllm/lora/layers.py
vllm/model_executor/__init__.py
vllm/engine/llm_engine.py (two threads)
vllm/engine/arg_utils.py

@zhuohan123 (Member) left a review:

Thanks @liangfu for the work! I left more comments on styling.

@@ -374,13 +374,17 @@ def __init__(
        disable_custom_all_reduce: bool = False,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        if is_neuron():

@zhuohan123 (Member) commented:

Please add a comment on why the is_neuron check is needed here.

@liangfu (Contributor, Author) replied:

Comment added above.

-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        if self.world_size > 1 and not is_neuron():

@zhuohan123 (Member) commented:

Please add a comment on why the is_neuron check is needed here.

@liangfu (Contributor, Author) replied:

Comment added above.
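
For readers following along, a hedged sketch of what the commented logic can look like (attribute names like neuron_tp_degree and the body of the world_size branch are assumptions, not necessarily the PR's exact code): on Inferentia, tensor parallelism is delegated to transformers-neuronx, so vLLM itself keeps a tensor-parallel size of 1 and skips its own multi-worker path.

def is_neuron() -> bool:
    # Stand-in for the device-detection helper this PR introduces.
    return False

class ParallelConfig:
    def __init__(self, pipeline_parallel_size: int, tensor_parallel_size: int) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        if is_neuron():
            # transformers-neuronx shards the model across NeuronCores itself,
            # so vLLM keeps TP=1 and only records the requested degree for the
            # Neuron backend to use when compiling the model.
            self.tensor_parallel_size = 1
            self.neuron_tp_degree = tensor_parallel_size  # hypothetical attribute

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        if self.world_size > 1 and not is_neuron():
            # Assumed body: enable the Ray-based multi-worker execution path,
            # which the Neuron backend does not use.
            self.worker_use_ray = True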

Comment on lines 145 to 148:

if self.is_neuron:
    from vllm.worker.neuron_worker import Worker
else:
    from vllm.worker.worker import Worker

@zhuohan123 (Member) commented:

Can we change this to a dictionary mapping hardware to workers?

@liangfu (Contributor, Author) replied:

Changed to a dict that loads the worker module dynamically based on the device config.
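
A minimal sketch of the dictionary-based dispatch being described (module paths are taken from the snippet above; the function name dispatch_worker follows a later comment, but the exact mapping in the PR may differ):

import importlib

# Device type -> (module path, worker class name).
DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": ("vllm.worker.worker", "Worker"),
    "neuron": ("vllm.worker.neuron_worker", "Worker"),
}

def dispatch_worker(device_type: str):
    """Dynamically import and return the worker class for the given device."""
    module_name, class_name = DEVICE_TO_WORKER_MODULE_MAP[device_type]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Single call site, no per-device if/else:
# Worker = dispatch_worker(device_config.device_type)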

Comment on lines 250 to 253:

if self.is_neuron:
    from vllm.worker.neuron_worker import Worker
else:
    from vllm.worker.worker import Worker

@zhuohan123 (Member) commented:

Again, I personally believe a dictionary would be clearer.

@liangfu (Contributor, Author) replied:

Now reusing the same dispatch_worker function, which uses the dict to load the device-specific worker.

Comment on lines 368 to 369:

if not self.is_neuron:
    self._run_workers("warm_up_model")

@zhuohan123 (Member) commented:

Why don't we warm up for Neuron?

I feel like the warm-up here should just be a no-op. We should minimize the number of if is_neuron checks.

@liangfu (Contributor, Author) replied:

Nice catch. The warm-up function is now a no-op for the Neuron backend.
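
In code form, a sketch of that resolution (not the exact PR code): the Neuron worker simply implements warm_up_model as a no-op, so the engine can call it unconditionally.

class NeuronWorker:
    def warm_up_model(self) -> None:
        # CUDA-graph capture / GPU warm-up does not apply to the Neuron backend,
        # so this is intentionally a no-op; no if is_neuron() guard is needed in
        # the engine.
        pass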

Comment on lines +87 to +89:

if is_neuron():
    module_name = _NEURON_SUPPORTED_MODELS[model_arch]

@zhuohan123 (Member) commented:

Why do we need these lines?

@liangfu (Contributor, Author) replied:

The module_name becomes neuron.llama instead of llama; we use a separate module for Neuron support.

vllm/model_executor/models/llama.py

becomes

vllm/model_executor/models/neuron/llama.py
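
A sketch of the registry lookup being described (the dictionary entries and helper name are illustrative; the real tables in the PR may list more architectures):

import importlib

# Architecture name -> module under vllm/model_executor/models/.
_MODELS = {
    "LlamaForCausalLM": "llama",
}
# Neuron uses separate implementations under vllm/model_executor/models/neuron/.
_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
}

def load_model_cls(model_arch: str, neuron: bool = False):
    module_name = _NEURON_SUPPORTED_MODELS[model_arch] if neuron else _MODELS[model_arch]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, model_arch)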

@@ -39,6 +38,9 @@ def __init__(
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        if is_neuron():

@zhuohan123 (Member) commented:

Please add a comment on why the is_neuron check is needed here and explain the logic.

@liangfu (Contributor, Author) replied:

Added a self.logits_as_hidden_states flag.
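
A hedged sketch of how such a flag can be used (this interpretation is an assumption: the Neuron model's forward pass already returns logits, so when the flag is set the sampler treats the incoming "hidden states" as logits and skips the lm_head projection):

import torch

class Sampler(torch.nn.Module):
    def __init__(self, logits_as_hidden_states: bool = False) -> None:
        super().__init__()
        # True for the Neuron backend, whose model forward already emits logits.
        self.logits_as_hidden_states = logits_as_hidden_states

    def forward(self, hidden_states: torch.Tensor, lm_head_weight: torch.Tensor) -> torch.Tensor:
        if self.logits_as_hidden_states:
            logits = hidden_states                        # already vocabulary logits
        else:
            logits = hidden_states @ lm_head_weight.t()   # project to vocabulary
        return torch.argmax(logits, dim=-1)               # greedy pick, for illustration only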

Comment on lines 84 to 89:

if self.is_neuron:
    self.model = get_model(self.model_config, self.parallel_config,
                           self.scheduler_config)
else:
    self.model = get_model(self.model_config, self.device_config,
                           self.lora_config)

@zhuohan123 (Member) commented:

Why do we have different function arguments for the Neuron get_model?

@liangfu (Contributor, Author) replied on Feb 21, 2024:

The function arguments are different because:
1. transformers-neuronx doesn't support LoRA inference yet.
2. transformers-neuronx requires parallel_config and scheduler_config to compile the kernels.

This has been changed to utils.get_model, which fetches the device-specific get_model dynamically, so the call site of get_model is exactly the same across devices.
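
A sketch of the dynamic dispatch being described (module paths follow the files touched in this PR, e.g. vllm/model_executor/neuron_model_loader.py, but the exact signature is an assumption):

import importlib

# Device type -> module providing a device-specific get_model().
DEVICE_TO_MODEL_LOADER_MAP = {
    "cuda": "vllm.model_executor.model_loader",
    "neuron": "vllm.model_executor.neuron_model_loader",
}

def get_model(model_config, device_config, **kwargs):
    """Single call site for every device; each loader picks the kwargs it needs
    (e.g. lora_config on GPU, parallel_config/scheduler_config on Neuron)."""
    loader = importlib.import_module(DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type])
    return loader.get_model(model_config, device_config, **kwargs)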

@@ -309,7 +314,8 @@ def _prepare_decode(
        use_captured_graph = (
            not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
-           and max_context_len <= self.max_context_len_to_capture)
+           and max_context_len <= self.max_context_len_to_capture
+           and not self.is_neuron)

@zhuohan123 (Member) commented:

In the config, can we just set enforce_eager for Neuron? This would help reduce the number of is_neuron checks in the code.

@liangfu (Contributor, Author) replied:

Changed to set enforce_eager for Neuron support.
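
A minimal sketch of that change (placement and helper names are assumptions): eager execution is forced whenever the Neuron backend is active, so the model runner no longer needs its own is_neuron check around CUDA-graph capture.

def is_neuron() -> bool:
    # Stand-in for the device-detection helper this PR introduces.
    return False

class ModelConfig:
    def __init__(self, enforce_eager: bool = False) -> None:
        # CUDA graphs are a GPU-only optimization; the Neuron backend always
        # runs eagerly, so use_captured_graph can rely on enforce_eager alone.
        self.enforce_eager = True if is_neuron() else enforce_eager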

Resolved review thread on vllm/worker/model_runner.py (outdated).

@liangfu force-pushed the neuron-2 branch 2 times, most recently from 614a4f0 to 9645f08 on February 21, 2024.
@liangfu requested a review from @zhuohan123 on February 26, 2024.
@WoosukKwon merged commit 3b7178c into vllm-project:main on Feb 28, 2024. 22 checks passed.

@WoosukKwon (Collaborator) commented:

@liangfu Thanks for submitting the PR, and sorry for the delays in the review! Let's merge the PR and then refactor to isolate the Neuron backend from the others.

xjpang pushed a commit to xjpang/vllm that referenced this pull request on Mar 4, 2024.

@zhuohan123 (Member) commented:

Why is the change in this file needed?

njhill added a commit to njhill/vllm that referenced this pull request on Mar 19, 2024:

Care is taken in the code to avoid initializing CUDA prior to CUDA_VISIBLE_DEVICES being set in the worker, but an instance of this was inadvertently introduced in vllm-project#2569.

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request on Sep 6, 2024.