Support inference with transformers-neuronx #2569
@@ -0,0 +1,33 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="openlm-research/open_llama_3b",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be the same as the
    # max sequence length when targeting the neuron device. Currently, this is a known
    # limitation in continuous batching support in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=128,
    block_size=128,
    # The device can be automatically detected when the AWS Neuron SDK is installed.
    # The device argument can either be unspecified for automatic detection,
    # or explicitly assigned.
    device="neuron")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
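Since the ParallelConfig change further down forwards the requested tensor parallelism to transformers-neuronx as neuron_tp_degree, the same example can presumably also target multiple NeuronCores. A minimal sketch; the tensor_parallel_size value is illustrative and not part of the example file above:

# Hedged sketch: same settings as the example above, but requesting two NeuronCores.
# vLLM keeps its own tensor parallel degree at 1 and forwards the value 2 to
# transformers-neuronx as neuron_tp_degree (see the ParallelConfig change below).
llm_tp2 = LLM(
    model="openlm-research/open_llama_3b",
    max_num_seqs=8,
    max_model_len=128,
    block_size=128,
    tensor_parallel_size=2,
    device="neuron")
outputs_tp2 = llm_tp2.generate(prompts, sampling_params)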
@@ -8,7 +8,7 @@

 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

 logger = init_logger(__name__)
@@ -380,13 +380,21 @@ def __init__(
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, we assign TP=1 here to avoid sharding within vLLM directly.
+            # transformers-neuronx takes the neuron_tp_degree attribute and distributes the workload
+            # across multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce

-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray workers are not supported for the Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()

Review comment (on the is_neuron() branch): Please comment on why the … Author reply: Comment added above.
Review comment (on the world_size check): Please comment on why the … Author reply: Comment added above.
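For illustration, a hedged, standalone sketch (not vLLM's actual class) of what the branch above produces when a user asks for tensor_parallel_size=2 on a Neuron host:

class ParallelSketch:
    """Self-contained mirror of the branching above, for illustration only."""

    def __init__(self, pipeline_parallel_size: int, tensor_parallel_size: int,
                 on_neuron: bool) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        if on_neuron:
            # Keep vLLM's own TP degree at 1; transformers-neuronx shards the model
            # across NeuronCores using neuron_tp_degree instead.
            self.tensor_parallel_size = 1
            self.neuron_tp_degree = tensor_parallel_size
        else:
            self.tensor_parallel_size = tensor_parallel_size
        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        # Ray workers are only enabled when vLLM itself spans multiple workers,
        # which never happens on the Neuron path above.
        self.worker_use_ray = self.world_size > 1 and not on_neuron


cfg = ParallelSketch(pipeline_parallel_size=1, tensor_parallel_size=2, on_neuron=True)
assert cfg.tensor_parallel_size == 1
assert cfg.neuron_tp_degree == 2
assert cfg.worker_use_ray is False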
@@ -465,8 +473,29 @@ def _verify_args(self) -> None:

 class DeviceConfig:

-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection.
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly.
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU.
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set the device from the device type.
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"


 @dataclass
Review comment: Why is the change in this file needed?
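As a rough illustration of the resolution order in the DeviceConfig diff above, here is a hedged, standalone sketch; the helper name and the boolean flags are made up for the example, whereas the real class consults torch.cuda.is_available() and is_neuron() directly:

import torch


def resolve_device(device: str = "auto", cuda_available: bool = False,
                   neuron_installed: bool = True):
    """Standalone sketch of the DeviceConfig logic above."""
    if device == "auto":
        if cuda_available:
            device_type = "cuda"
        elif neuron_installed:
            device_type = "neuron"
        else:
            raise RuntimeError("No supported device detected.")
    else:
        device_type = device
    # Neuron keeps input processing on CPU even though execution happens on
    # NeuronCores, so the torch device is "cpu" in that case.
    torch_device = torch.device("cpu" if device_type == "neuron" else device_type)
    return device_type, torch_device


print(resolve_device())                                     # ('neuron', device(type='cpu'))
print(resolve_device(device="cuda", cuda_available=True))   # ('cuda', device(type='cuda'))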
@@ -4,7 +4,7 @@
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

 logger = init_logger(__name__)
@@ -59,6 +59,9 @@
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}


 class ModelRegistry:
@@ -75,8 +78,15 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")

         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)

Review comment (on lines +88 to +89): Why do we need these lines?
Author reply: The module name from _MODELS becomes the entry in _NEURON_SUPPORTED_MODELS (e.g. "llama" becomes "neuron.llama"), so the import below resolves to the Neuron-specific model implementation.
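To make the effect of that override concrete, a hedged sketch of the resulting import path. The shape of the _MODELS entry is an assumption based on the surrounding code (architecture mapped to a module name and class name); only the _NEURON_SUPPORTED_MODELS entry appears in this diff:

# Assumed registry shapes, for illustration only.
_MODELS = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}


def resolve_import_path(model_arch: str, on_neuron: bool) -> str:
    """Mirror the module-name override in load_model_cls above."""
    module_name, _model_cls_name = _MODELS[model_arch]
    if on_neuron:
        module_name = _NEURON_SUPPORTED_MODELS[model_arch]
    return f"vllm.model_executor.models.{module_name}"


print(resolve_import_path("LlamaForCausalLM", on_neuron=False))  # vllm.model_executor.models.llama
print(resolve_import_path("LlamaForCausalLM", on_neuron=True))   # vllm.model_executor.models.neuron.llama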
Review comment: It seems the online server APIs of vLLM, such as the API server and the OpenAI-compatible server, do not work at the moment. Could you support them as well?

Author reply: It's a bit tricky for now. I extended the --block-size argument with 128 being an option. We will be able to run the API server with the following setup:

python3 -m vllm.entrypoints.api_server --model openlm-research/open_llama_3b --tensor-parallel-size 2 --max-num-seqs 2 --max-model-len 128 --block-size 128 &

On the client side, we need to assume --n=1:

python3 api_client.py --stream --n=1 --prompt="one two three four "
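For reference, a hedged sketch of what such a client request could look like without the helper script, assuming the demo server's default host and port and the /generate route that api_client.py talks to; the field names and values here are illustrative and may need adjusting:

import json

import requests

# Single-completion (n=1), non-streaming request against the demo API server.
response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "one two three four ", "n": 1, "stream": False, "max_tokens": 16},
)
print(json.loads(response.text))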