
Commit 2df1c2b

liangfu authored and jimpang committed

[Neuron] Support inference with transformers-neuronx (vllm-project#2569)

1 parent 3b89c83 commit 2df1c2b

18 files changed: +516 -42 lines


examples/offline_inference_neuron.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(
+    model="openlm-research/open_llama_3b",
+    max_num_seqs=8,
+    # The max_model_len and block_size arguments are required to be the same as
+    # the max sequence length when targeting the Neuron device. Currently, this
+    # is a known limitation in continuous batching support in transformers-neuronx.
+    # TODO(liangfu): Support paged-attention in transformers-neuronx.
+    max_model_len=128,
+    block_size=128,
+    # The device can be automatically detected when the AWS Neuron SDK is installed.
+    # The device argument can be either unspecified for automated detection, or explicitly assigned.
+    device="neuron")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

tests/lora/conftest.py

Lines changed: 5 additions & 3 deletions
@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model
 
-    def get_model_patched(model_config, device_config, lora_config=None):
-        return get_model_old(model_config, device_config,
-                             LoRAConfig(max_loras=4, max_lora_rank=8))
+    def get_model_patched(model_config, device_config, **kwargs):
+        return get_model_old(model_config,
+                             device_config,
+                             lora_config=LoRAConfig(max_loras=4,
+                                                    max_lora_rank=8))
 
     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)

vllm/config.py

Lines changed: 35 additions & 6 deletions
@@ -8,7 +8,7 @@
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
 
 logger = init_logger(__name__)
 
@@ -380,13 +380,21 @@ def __init__(
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, assign TP=1 here to avoid sharding within vLLM directly.
+            # transformers-neuronx takes the neuron_tp_degree attribute and distributes the
+            # workload across multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
 
-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray workers are not supported for the Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()
 
@@ -465,8 +473,29 @@ def _verify_args(self) -> None:
 
 class DeviceConfig:
 
-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection.
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly.
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU.
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set the device with the device type.
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"
 
 
 @dataclass
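
The config code above relies on an is_neuron() helper in vllm.utils that is not shown in this excerpt. A minimal sketch of what such a helper could look like, assuming detection is based on whether transformers-neuronx is importable:

from functools import lru_cache

@lru_cache(maxsize=None)
def is_neuron() -> bool:
    # Treat an importable transformers_neuronx package as the signal that
    # the AWS Neuron SDK is installed and usable.
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True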

vllm/engine/arg_utils.py

Lines changed: 7 additions & 9 deletions
@@ -44,7 +44,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = 'auto'
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -171,7 +171,7 @@ def add_cli_args(
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128],
                             help='token block size')
         parser.add_argument('--seed',
                             type=int,
@@ -264,13 +264,11 @@ def add_cli_args(
                 help=('Maximum number of LoRAs to store in CPU memory. '
                       'Must be >= than max_num_seqs. '
                       'Defaults to max_num_seqs.'))
-        parser.add_argument(
-            "--device",
-            type=str,
-            default=EngineArgs.device,
-            choices=["cuda"],
-            help=('Device type for vLLM execution. '
-                  'Currently, only CUDA-compatible devices are supported.'))
+        parser.add_argument("--device",
+                            type=str,
+                            default=EngineArgs.device,
+                            choices=["auto", "cuda", "neuron"],
+                            help='Device type for vLLM execution.')
         return parser
 
     @classmethod
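
A quick usage sketch of the widened CLI surface, assuming the usual EngineArgs.from_cli_args helper and illustrative flag values:

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
# "--device" now accepts auto/cuda/neuron, and 128 is a valid block size;
# on Neuron, block_size must match max_model_len (see the example above).
args = parser.parse_args([
    "--model", "openlm-research/open_llama_3b",
    "--device", "neuron",
    "--block-size", "128",
    "--max-model-len", "128",
])
engine_args = EngineArgs.from_cli_args(args)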

vllm/engine/llm_engine.py

Lines changed: 18 additions & 3 deletions
@@ -3,6 +3,7 @@
 import os
 import time
 import pickle
+import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
@@ -20,7 +21,8 @@
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
+from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
+                        get_open_port, get_distributed_init_method)
 
 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -31,6 +33,12 @@
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
+# A map from the device type (in the device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -138,10 +146,17 @@ def __init__(
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()
 
         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")
@@ -243,7 +258,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()
 
         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
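
The worker dispatch is a small string-keyed lazy-import pattern. A self-contained sketch of the same idea, assuming each worker module exposes a Worker class the way vllm.worker.worker does:

import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}

def dispatch_worker(device_type: str):
    # Import lazily so CUDA-specific modules are never touched on Neuron,
    # and so CUDA_VISIBLE_DEVICES can be set before the import happens.
    worker_module = importlib.import_module(
        DEVICE_TO_WORKER_MODULE_MAP[device_type])
    return worker_module.Worker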

vllm/lora/layers.py

Lines changed: 4 additions & 0 deletions
@@ -795,6 +795,10 @@ def __init__(
         self.dtype = dtype
         self.device = device
 
+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
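
The new property is plain delegation, so code that reads the flag sees the same attribute whether or not the sampler is wrapped for LoRA; a reduced sketch with hypothetical class names:

class BaseSampler:
    # Mirrors Sampler.logits_as_hidden_states from the sampler change below.
    logits_as_hidden_states = False

class SamplerWrapper:
    def __init__(self, base_layer: BaseSampler) -> None:
        self.base_layer = base_layer

    @property
    def logits_as_hidden_states(self) -> bool:
        # Forward to the wrapped sampler instead of storing a second copy.
        return self.base_layer.logits_as_hidden_states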

vllm/model_executor/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model
 
 __all__ = [
     "InputMetadata",

vllm/model_executor/layers/sampler.py

Lines changed: 13 additions & 5 deletions
@@ -10,6 +10,7 @@
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron
 
 
 class Sampler(nn.Module):
@@ -32,6 +33,8 @@ def __init__(self,
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
 
@@ -55,10 +58,14 @@ def forward(
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
 
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
@@ -407,7 +415,7 @@
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
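
A reduced sketch of the new forward-path branch with stand-in tensors; the matrix product is a simplified stand-in for what _get_logits computes:

import torch

def compute_logits(hidden_states: torch.Tensor,
                   lm_head_weight: torch.Tensor,
                   logits_as_hidden_states: bool) -> torch.Tensor:
    if logits_as_hidden_states:
        # On Neuron, transformers-neuronx already returns logits, so no
        # pruning or LM-head projection is needed.
        return hidden_states
    # Otherwise, project hidden states to vocabulary logits (simplified).
    return hidden_states @ lm_head_weight.t()

# Example shapes: 2 tokens, hidden size 16, vocabulary size 32.
hidden = torch.randn(2, 16)
weight = torch.randn(32, 16)
print(compute_logits(hidden, weight, logits_as_hidden_states=False).shape)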

vllm/model_executor/model_loader.py

Lines changed: 5 additions & 5 deletions
@@ -1,11 +1,11 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Type
+from typing import Type
 
 import torch
 import torch.nn as nn
 
-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")
 
 
-def get_model(model_config: ModelConfig,
-              device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)
 
     # Get the (maybe quantized) linear method.
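
The get_model signature change is just keyword forwarding: callers without a LoRA config can drop the argument, while LoRA-aware callers pass it by keyword. A toy sketch of the pattern, with hypothetical stand-in values:

from typing import Any, Optional

def get_model_sketch(model_config: Any, device_config: Any, **kwargs) -> str:
    # Optional extras such as lora_config travel through **kwargs, so device
    # backends that do not use them never have to mention them.
    lora_config: Optional[Any] = kwargs.get("lora_config", None)
    return f"loading on {device_config} with lora_config={lora_config!r}"

print(get_model_sketch("model-config", "neuron"))
print(get_model_sketch("model-config", "cuda", lora_config="my-lora-config"))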

vllm/model_executor/models/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -4,7 +4,7 @@
 import torch.nn as nn
 
 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron
 
 logger = init_logger(__name__)
 
@@ -61,6 +61,9 @@
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+
 
 class ModelRegistry:
 
@@ -77,8 +80,15 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")
 
         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
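
A condensed sketch of the registry lookup after this change, reusing the mapping from the diff; the _MODELS entry and the on_neuron parameter are illustrative stand-ins for the real registry state and the is_neuron() check:

import importlib
from typing import Optional, Type

import torch.nn as nn

_MODELS = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}

def load_model_cls(model_arch: str,
                   on_neuron: bool) -> Optional[Type[nn.Module]]:
    if on_neuron and model_arch not in _NEURON_SUPPORTED_MODELS:
        raise ValueError(f"Model architecture {model_arch} is not supported "
                         "by Neuron for now.")
    module_name, model_cls_name = _MODELS[model_arch]
    if on_neuron:
        # Swap in the Neuron-specific implementation module.
        module_name = _NEURON_SUPPORTED_MODELS[model_arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, model_cls_name, None)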
