
[RFC] Initial Support for AWS Inferentia #1866

Closed
@liangfu

Description

Proposal

We propose to integrate transformers-neuronx as the execution engine in vLLM to support LLM inference on AWS Inferentia. This requires changes to both transformers-neuronx and vLLM.

Changes to transformers-neuronx

  1. Support batch-size-1 prompt encoding that shares the same cache space with max-batch-size decoding.
  2. Support batch-dependent KV cache updates: each sequence has its own position_id at which its cache is updated (see the sketch after this list).
  3. Support virtual dynamic batching, so that multi-batch prompt encoding is handled transparently from vLLM's perspective.
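
As an illustration of item 2, the following sketch shows, in plain PyTorch rather than the actual transformers-neuronx implementation, what a batch-dependent KV cache update looks like: each sequence in the batch writes its new key/value vectors at its own position id. Tensor names and shapes here are assumptions made for the example.

import torch

def update_kv_cache(kv_cache: torch.Tensor,      # [batch, max_len, num_heads, head_dim]
                    new_kv: torch.Tensor,        # [batch, 1, num_heads, head_dim]
                    position_ids: torch.Tensor,  # [batch], write position per sequence
                    ) -> torch.Tensor:
    # Each sequence may be at a different decoding step, so the write index
    # differs per batch entry (this is the batch-dependent part).
    batch_size = kv_cache.shape[0]
    kv_cache[torch.arange(batch_size), position_ids] = new_kv.squeeze(1)
    return kv_cache

# Example: sequence 0 is at step 5, sequence 1 is at step 12.
cache = torch.zeros(2, 128, 8, 64)
new_kv = torch.randn(2, 1, 8, 64)
cache = update_kv_cache(cache, new_kv, torch.tensor([5, 12]))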

Changes to vLLM

Implementation Details

Model-specific forward function (e.g. Llama-specific code for Neuron)

def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[KVCache],
    input_metadata: InputMetadata,
    cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
    with torch.inference_mode():
        block_size = self.model.context_buckets[-1]
        if input_metadata.num_generation_tokens == 0:
            # Prompt (context encoding) phase: derive each prompt's sequence id
            # from the cache slot assigned to its first token.
            num_prompts = input_metadata.num_prompts
            seq_ids = torch.zeros(num_prompts, 1, dtype=torch.int64, device='cpu')
            anchor = 0
            for prompt_id in range(num_prompts):
                seq_ids[prompt_id] = input_metadata.slot_mapping[anchor] // block_size
                anchor += input_metadata.prompt_lens[prompt_id]
        else:
            # Decode (generation) phase: the block tables already map each
            # sequence to its cache position.
            seq_ids = input_metadata.block_tables

        # transformers-neuronx takes the token positions as cache_ids and the
        # per-sequence cache indices as start_ids.
        logits = self.model(input_ids, cache_ids=positions, start_ids=seq_ids)
        next_tokens = self.sampler(logits, input_metadata)

    return next_tokens
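
For concreteness, the prompt-phase seq_ids computation above can be illustrated in isolation. The numbers below (a block size of 128 and two prompts of lengths 3 and 2) are made up for the example; in the forward function the block size comes from self.model.context_buckets[-1] and the tensors come from InputMetadata.

import torch

block_size = 128
prompt_lens = [3, 2]
# Flat cache slot assigned to each prompt token: prompt 0 lives in block 0,
# prompt 1 lives in block 1.
slot_mapping = torch.tensor([0, 1, 2, 128, 129])

seq_ids = torch.zeros(len(prompt_lens), 1, dtype=torch.int64)
anchor = 0
for prompt_id, prompt_len in enumerate(prompt_lens):
    # The slot of the prompt's first token, divided by the block size,
    # identifies the cache block (and hence the sequence id) of this prompt.
    seq_ids[prompt_id] = slot_mapping[anchor] // block_size
    anchor += prompt_len

print(seq_ids.squeeze(1).tolist())  # [0, 1]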

Model compilation

def load_weights(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,
                 load_format: str = "auto",
                 revision: Optional[str] = None,
                 **kwargs):
    from transformers_neuronx.llama.model import LlamaForSampling

    # Split the HuggingFace checkpoint once so that transformers-neuronx can
    # load the weights shard-by-shard with low host-memory usage.
    if not os.path.exists(f"{model_name_or_path}-split"):
        from transformers.models.llama import LlamaForCausalLM
        from transformers_neuronx.module import save_pretrained_split

        hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path, low_cpu_mem_usage=True)
        save_pretrained_split(hf_model, f"{model_name_or_path}-split")

    # Load the split checkpoint and compile the model for the NeuronCores.
    self.model = LlamaForSampling.from_pretrained(f"{model_name_or_path}-split", **kwargs)
    self.model.to_neuron()
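
Outside of vLLM, the compile-and-sample flow can be smoke-tested with transformers-neuronx directly. The snippet below is a minimal sketch that follows the package's documented usage; the model path, tp_degree, and sampling parameters are placeholders.

import torch
from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling

# Placeholder: a checkpoint directory produced by save_pretrained_split above.
model_dir = "openlm-research/open_llama_3b-split"

# Compile for two NeuronCores with fp16 weights; parameters are illustrative.
neuron_model = LlamaForSampling.from_pretrained(model_dir, batch_size=1,
                                                tp_degree=2, amp='f16')
neuron_model.to_neuron()

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = neuron_model.sample(input_ids, sequence_length=128, top_k=10)
print(tokenizer.decode(generated[0]))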

Model-agnostic code (e.g. the generic model loader)

# Load the weights from the cached or downloaded files.
from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig

# Share a single KV cache across the maximum number of sequences so that
# batch-size-1 prompt encoding and max-batch-size decoding can coexist.
continuous_batching_config = ContinuousBatchingConfig(
    batch_size_for_shared_caches=scheduler_config.max_num_seqs)
neuron_config = NeuronConfig(continuous_batching=continuous_batching_config)
model.load_weights(model_config.model, model_config.download_dir,
                   model_config.load_format, model_config.revision,
                   tp_degree=parallel_config.tp_degree,
                   amp='f32', neuron_config=neuron_config,
                   context_length_estimate=[scheduler_config.max_model_len],
                   n_positions=[scheduler_config.max_model_len],
                   batch_size=scheduler_config.max_num_seqs)

Related Resources

Stable release versions of the transformers-neuronx package are available from https://pip.repos.neuron.amazonaws.com/transformers-neuronx/. The package can be installed with

pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com
