Proposal
We propose to integrate transformers-neuronx as the execution engine in vLLM to support LLM inference on AWS Inferentia. This requires changes to both transformers-neuronx and vLLM.
Changes to transformers-neuronx
- Support prompt encoding at batch size 1 while sharing the same cache space with max-batch-size decoding.
- Support batch-dependent KV cache updates. Each sequence will have a specified position_id used to update its cache.
- Support virtual dynamic batching. This would enable multi-batch prompt encoding transparently to vLLM. (A sketch of the resulting calling pattern follows this list.)
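The sketch below illustrates the calling pattern these changes enable. It is a minimal, hedged sketch: model stands in for a transformers-neuronx model compiled with a KV cache shared across the maximum decode batch size, prompts, last_tokens, and next_positions are placeholder inputs, and the cache_ids/start_ids keywords mirror the forward function shown under Implementation Details; exact argument shapes are illustrative, not a final API contract.

import torch

# 1) Prompt encoding at batch size 1: each prompt writes into its own row of the
#    shared cache (selected by start_ids) at the positions given by cache_ids.
for seq_id, prompt in enumerate(prompts):                   # prompts: list of 1-D token tensors
    logits = model(prompt.unsqueeze(0),                     # [1, prompt_len]
                   cache_ids=torch.arange(prompt.numel()),  # positions 0 .. prompt_len - 1
                   start_ids=torch.tensor([seq_id]))        # cache row for this sequence

# 2) Decoding at the full batch size: every active sequence advances by one token,
#    each updating the shared cache at its own position (batch-dependent update).
logits = model(last_tokens,                                 # [batch, 1]
               cache_ids=next_positions,                    # per-sequence positions
               start_ids=torch.arange(last_tokens.shape[0]))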
Changes to vLLM
- Make CUDA kernel compilation optional, so that LLM inference on inf2 instances does not require compiling the CUDA kernels, while keeping CUDA kernel compilation enabled by default. (A possible shape of the opt-out is sketched after this list.) [Neuron] Add an option to build with neuron #2065
- Add the transformers-neuronx package as an optional third-party dependency of vLLM. Note that transformers-neuronx further depends on torch-neuronx, torch-xla, neuronx-cc, and others. [Neuron] Add an option to build with neuron #2065
- Configure transformers-neuronx to enable the continuous batching feature in the vLLM model loader. Support inference with transformers-neuronx #2569
- Compile the model after loading the weights. Support inference with transformers-neuronx #2569
- Execute the model with transformers-neuronx. Support inference with transformers-neuronx #2569
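For the first item, one possible shape of the setup.py opt-out is sketched below. The VLLM_BUILD_WITH_NEURON flag is a hypothetical name and the CUDA extension list is elided; the actual mechanism is whatever lands in #2065.

# setup.py (simplified sketch, not the actual change in #2065)
import os
from setuptools import setup

# Hypothetical opt-out flag for Neuron builds.
BUILD_WITH_NEURON = os.getenv("VLLM_BUILD_WITH_NEURON", "0") == "1"

ext_modules = []
if not BUILD_WITH_NEURON:
    # Default path: register vLLM's CUDA extensions exactly as today
    # (the existing CUDAExtension list is elided from this sketch).
    pass

install_requires = ["torch", "transformers"]
if BUILD_WITH_NEURON:
    # Optional Neuron stack; transformers-neuronx transitively pulls in
    # torch-neuronx, torch-xla, neuronx-cc, etc. from the Neuron pip repository.
    install_requires.append("transformers-neuronx")

setup(
    name="vllm",
    ext_modules=ext_modules,
    install_requires=install_requires,
)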
Implementation Details
Model-specific forward function (e.g., Llama-specific code for Neuron)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[KVCache],
    input_metadata: InputMetadata,
    cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
    batch_size, n_active_tokens = input_ids.shape
    with torch.inference_mode():
        seq_ids = []
        # The largest context bucket doubles as the block size, so each sequence
        # occupies one block of the shared Neuron KV cache.
        block_size = self.model.context_buckets[-1]
        if input_metadata.num_generation_tokens == 0:
            # Prompt phase: derive each prompt's cache row from the first slot
            # it was assigned by the vLLM block manager.
            num_prompts = input_metadata.num_prompts
            seq_ids = torch.zeros(num_prompts, 1, dtype=torch.int64, device='cpu')
            anchor = 0
            for prompt_id in range(num_prompts):
                seq_ids[prompt_id] = input_metadata.slot_mapping[anchor] // block_size
                anchor += input_metadata.prompt_lens[prompt_id]
        else:
            # Decode phase: the block tables already identify each sequence's cache row.
            seq_ids = input_metadata.block_tables
        # transformers-neuronx consumes the positions as cache_ids and the
        # per-sequence cache rows as start_ids.
        logits = self.model(input_ids, cache_ids=positions, start_ids=seq_ids)
    next_tokens = self.sampler(logits, input_metadata)
    return next_tokens
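To make the seq_ids derivation concrete, here is a small worked example with made-up numbers. It assumes vLLM's block size matches self.model.context_buckets[-1], so each sequence occupies exactly one block and its first slot index divided by the block size gives its row in the shared cache.

# Made-up values: two prompts of lengths 3 and 5, block_size = 2048, and a
# slot_mapping laid out so that each sequence owns one cache block.
block_size = 2048
prompt_lens = [3, 5]
slot_mapping = [0, 1, 2,                       # prompt 0 -> block 0
                2048, 2049, 2050, 2051, 2052]  # prompt 1 -> block 1

seq_ids, anchor = [], 0
for prompt_len in prompt_lens:
    seq_ids.append(slot_mapping[anchor] // block_size)  # first slot of each prompt
    anchor += prompt_len

assert seq_ids == [0, 1]  # each prompt maps to its own row of the shared cache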
Model compilation
def load_weights(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,
                 load_format: str = "auto",
                 revision: Optional[str] = None,
                 **kwargs):
    from transformers_neuronx.llama.model import LlamaForSampling
    # Create a split copy of the HuggingFace checkpoint (once) for
    # transformers-neuronx to load.
    if not os.path.exists(f"{model_name_or_path}-split"):
        from transformers.models.llama import LlamaForCausalLM
        from transformers_neuronx.module import save_pretrained_split
        hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True)
        save_pretrained_split(hf_model, f"{model_name_or_path}-split")
    # Load the split checkpoint and compile the model for Neuron.
    self.model = LlamaForSampling.from_pretrained(f"{model_name_or_path}-split", **kwargs)
    self.model.to_neuron()
Model-agnostic loading (e.g., the generic model loader)
# Load the weights from the cached or downloaded files.
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig

# Share the KV cache across up to max_num_seqs sequences so that vLLM's
# continuous batching can be used on Neuron.
continuous_batching_config = ContinuousBatchingConfig(
    batch_size_for_shared_caches=scheduler_config.max_num_seqs)
neuron_config = NeuronConfig(continuous_batching=continuous_batching_config)
model.load_weights(model_config.model, model_config.download_dir,
                   model_config.load_format, model_config.revision,
                   tp_degree=parallel_config.tp_degree,
                   amp='f32', neuron_config=neuron_config,
                   context_length_estimate=[scheduler_config.max_model_len],
                   n_positions=[scheduler_config.max_model_len],
                   batch_size=scheduler_config.max_num_seqs)
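Once the pieces above are in place, end-to-end usage from the vLLM API could look roughly like the sketch below. The model name and limits are illustrative, and it assumes the Neuron execution path is selected when running on an inf2 instance with the optional dependencies installed.

from vllm import LLM, SamplingParams

# Illustrative values only; assumes the Neuron path is active on an inf2 instance.
llm = LLM(
    model="openlm-research/open_llama_3b",  # any supported Llama checkpoint
    tensor_parallel_size=2,                 # maps to tp_degree on Neuron
    max_num_seqs=4,                         # batch size for the shared KV cache
    max_model_len=2048,                     # used for n_positions / context_length_estimate
)

outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.8, max_tokens=64),
)
for output in outputs:
    print(output.outputs[0].text)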
Related Resources
Stable release versions of the transformers-neuronx package can be found at https://pip.repos.neuron.amazonaws.com/transformers-neuronx/. The transformers-neuronx package can be installed with
pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com
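A quick sanity check that the package (and the Neuron dependencies mentioned above) resolved from that index:

# Post-install check; run in the environment where vLLM will execute.
from importlib import metadata

print(metadata.version("transformers-neuronx"))
print(metadata.version("torch-neuronx"))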