
[RFC]: AWS Neuron 2.23 NxD Inference with vLLM V0 #15970

Motivation.

AWS Neuron has released the NeuronX Distributed (NxD) Inference library, a PyTorch-based solution with performance optimizations for AWS Trainium and Inferentia instances. NxD Inference is the path forward for optimized inference on Neuron; the Transformers NeuronX (TNx) library will soon reach end of support.

This RFC integrates NxD Inference into vLLM and adds minor features to TNx. The integration currently targets vLLM's V0 architecture, with plans to migrate to the V1 architecture.

These changes streamline Neuron serving with vLLM while maintaining compatibility and performance for inference workloads on AWS Trainium and Inferentia.

AWS Neuron is committed to supporting vLLM and is planning an engineering roadmap with deeper integration. We will share the next RFC with the vLLM community for feedback once it’s ready.

We are adding the following features in this RFC:

  1. NeuronX Distributed (NxD) Inference Support
  2. Speculative Decoding
  3. Dynamic On-device Sampling
  4. Quantized Model Support (limited to TNx)
  5. Multi-Modal Support (Llama-3.2)
  6. Multi-LoRA Serving

Note: The changes will be isolated to Neuron-specific logic and will not impact other platforms.

Proposed Change.

  1. NeuronX Distributed (NxD) Inference Support

    1. Allow customers to select a framework based on preference or availability. Default to neuronx-distributed-inference (NxD); if unavailable, fall back to transformers-neuronx (TNx).
    2. Support inference using NxD by adding worker/neuronx_distributed_model_runner.py.
    3. Add a framework detection utility that returns the framework currently in use (see the first sketch after this list).
  2. Speculative Decoding

    1. To enable speculative decoding with NxD, add worker/multi_step_neuronx_distributed_model_runner.py.
    2. To enable speculative decoding with TNx, add worker/multi_step_neuron_model_runner.py. neuron_worker.py selects this model runner when speculation is enabled (see the second sketch after this list).
  3. Dynamic On-device Sampling

    1. Extract the sampling params (top_k, top_p, temperature) from each request and pass them to execute_model() (see the third sketch after this list).
  4. Multi-modal model support

    1. Add support for Mllama (Llama-3.2) multi-modal models.
  5. Quantized model support (limited to TNx)

    1. Support INT8 and FP8 quantization.
  6. Multi-LoRA Serving

    1. Allow loading and using LoRA adapters with NxD (see the usage example after this list).
    2. Support loading LoRA adapters only at server startup; dynamic LoRA loading will be added along with V1 support.
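
To make item 1 concrete, here is a minimal sketch of the detection-and-fallback utility. The enum and function names are hypothetical, not the actual vLLM API; only the package import names come from the libraries themselves.

```python
# Minimal sketch of the framework-detection utility (names hypothetical).
# Prefers NxD and falls back to TNx, matching the default described above.
from enum import Enum
from typing import Optional


class NeuronFramework(Enum):
    NXD = "neuronx-distributed-inference"
    TNX = "transformers-neuronx"


def detect_neuron_framework() -> Optional[NeuronFramework]:
    try:
        import neuronx_distributed_inference  # noqa: F401
        return NeuronFramework.NXD
    except ImportError:
        pass
    try:
        import transformers_neuronx  # noqa: F401
        return NeuronFramework.TNX
    except ImportError:
        return None
```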
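
For item 2, a hedged sketch of the runner selection, reusing NeuronFramework from the previous sketch. The runner classes are stubs named after the new files; the real constructors take configuration arguments that are elided here.

```python
# Stubs standing in for the runner classes added by this RFC.
class NeuronModelRunner: ...
class MultiStepNeuronModelRunner: ...
class NeuronxDistributedModelRunner: ...
class MultiStepNeuronxDistributedModelRunner: ...


def create_model_runner(framework: NeuronFramework, speculation_enabled: bool):
    """Mirror the selection described for neuron_worker.py."""
    if framework is NeuronFramework.NXD:
        cls = (MultiStepNeuronxDistributedModelRunner
               if speculation_enabled else NeuronxDistributedModelRunner)
    else:
        cls = (MultiStepNeuronModelRunner
               if speculation_enabled else NeuronModelRunner)
    return cls()
```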
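
For item 3, a sketch of how per-request sampling parameters could be batched before being handed to execute_model(). The container and function names are illustrative; each request object is assumed to expose a sampling_params attribute, as in vLLM V0.

```python
# Illustrative batching of sampling params for dynamic on-device sampling.
from dataclasses import dataclass
from typing import List


@dataclass
class OnDeviceSamplingParams:
    top_k: List[int]
    top_p: List[float]
    temperature: List[float]


def gather_sampling_params(requests) -> OnDeviceSamplingParams:
    # Each request is assumed to carry a vLLM SamplingParams object.
    return OnDeviceSamplingParams(
        top_k=[r.sampling_params.top_k for r in requests],
        top_p=[r.sampling_params.top_p for r in requests],
        temperature=[r.sampling_params.temperature for r in requests],
    )
```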
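
Finally, for item 6, a usage sketch built on vLLM's existing offline LoRA interface (LLM(enable_lora=True) and LoRARequest are standard vLLM APIs). The model name and adapter path are placeholders, and the Neuron-specific startup registration of adapters is not shown.

```python
# Standard vLLM multi-LoRA usage; model and adapter paths are placeholders.
# On Neuron, adapters must be known at server startup (no dynamic loading yet).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True)
outputs = llm.generate(
    "Summarize the AWS Neuron SDK in one sentence.",
    SamplingParams(temperature=0.8, top_p=0.95),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```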

Feedback Period.

1 week (due on April 9, 2025).

CC List.

Any Other Things.

  • The RFC is focused on the V0 architecture and does not implement V1 support for Neuron. V1 architecture support is being actively planned and will be shared in a separate RFC.
  • The RFC introduces significant code changes to Neuron-related paths, which are organized into feature-specific PRs to streamline the review process.
