
[Core] Pipeline Parallel Support #4412

Merged
merged 6 commits into vllm-project:main on Jul 2, 2024

Conversation

@andoorve
Collaborator

andoorve commented Apr 27, 2024

Adds initial pipeline parallelism support to vLLM.

ToDo:

Milestone 1: POC Prototype

Milestone 2: Mergeable

  • Fix issues related to incorrect LLaMa outputs (bug filed against PyTorch: [Distributed] P2P Operations on NCCL do not respect tag pytorch/pytorch#125079, and worked around)
  • Refactor to move sending and receiving code out of the models.
  • Check if there is a simpler way to do weight loading
  • Enable multi-node
  • Add RFC for community benefit
  • Add some testing
  • Assert out models that are not yet supported, as well as LLMEngine.
  • Check if any PyNCCL changes are necessary
  • Rebase on latest
  • Tests passing

FIX #4461

Goals for this PR:

  • Functional eager-mode PP
  • Support AsyncLLMEngine
  • Support RayGPUExecutor
  • Support LLaMa/GPT2
  • Support chunked prefill

Non-goals for this PR (to be covered in future PRs):

  • Be fully optimized
  • Support LLMEngine (this may be removed in the future)
  • Support any other distributed backend
  • Support models other than LLaMa/GPT2
  • Support CUDA graphs (CUDA graph support is actually already included in this PR, but any issues with it should not block the merge)

cc: @zhuohan123 @WoosukKwon @simon-mo @youkaichao

@robertgshaw2-neuralmagic
Collaborator

@andoorve - Exciting!!!

@youkaichao
Member

@andoorve thanks for the effort! Can you write an RFC to describe the overall design so that people can easily understand it? example rfcs: https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc

@andoorve
Collaborator Author

@youkaichao Yes for sure, it is one of the TODO items above

@@ -746,7 +763,8 @@ def execute_model(
logits = self.model.compute_logits(hidden_states, sampling_metadata)

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
if (not (is_pipeline_model_parallel_last_rank()
Collaborator

So for TP, the first rank (driver) performs sampling, and for PP, the last rank (the last worker in the last PP rank's TP group) performs sampling. Is this correct?

Collaborator Author

It's the first worker of the last PP rank's TP group.
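In other words (a minimal sketch of the gating discussed here; the helper name and exact predicate are assumptions based on the diff above, not the PR's final code):

def should_sample(is_driver_worker: bool, is_last_pp_rank: bool) -> bool:
    # Sampling runs only on the driver worker (first rank of the TP group)
    # of the last pipeline-parallel stage; every other rank skips it.
    return is_driver_worker and is_last_pp_rank

# Example: with PP=2 and TP=2, only the first rank of the last stage's TP group samples.
assert should_sample(is_driver_worker=True, is_last_pp_rank=True)
assert not should_sample(is_driver_worker=False, is_last_pp_rank=True)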

@andoorve
Collaborator Author

Updated the RFC here: #4461 @youkaichao

Let me know if anything needs further elaboration

@andoorve
Collaborator Author

FYI, I'm pretty sure PyTorch has a bug; filed here: pytorch/pytorch#125079

Worked around this last week by making the send and receive phase for each model atomic, by concatenating the residuals and hidden states.
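A minimal sketch of that workaround, assuming plain torch.distributed blocking point-to-point ops and [num_tokens, hidden_size] tensors (the function names, shapes, and rank arguments are illustrative, not the PR's actual code):

import torch
import torch.distributed as dist

def send_stage_output(hidden_states: torch.Tensor, residual: torch.Tensor, dst: int) -> None:
    # Concatenate so the stage boundary is crossed with one send instead of two,
    # avoiding any reliance on NCCL respecting P2P tags (pytorch/pytorch#125079).
    packed = torch.cat([hidden_states, residual], dim=-1)
    dist.send(packed, dst=dst)

def recv_stage_input(num_tokens: int, hidden_size: int, src: int,
                     dtype: torch.dtype, device: torch.device):
    packed = torch.empty(num_tokens, 2 * hidden_size, dtype=dtype, device=device)
    dist.recv(packed, src=src)
    hidden_states, residual = packed.split(hidden_size, dim=-1)
    return hidden_states, residual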

@youkaichao
Member

Hi @andoorve, I already made the change to PyNCCL to support multiple groups in #4512. The first rank can be read from the group argument directly.

@andoorve
Collaborator Author

andoorve commented May 1, 2024

Sounds good @youkaichao, I can update mine once that's merged.

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

@youkaichao
Member

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

Yes, that's also in my plan. I will break #4460 down into small pieces to be merged, ETA this week.

@andoorve
Collaborator Author

andoorve commented May 1, 2024

Sounds good - I'll revert the PyNCCL changes in this PR and wait for that to be merged before adding them back in.

@GindaChen
Contributor

GindaChen commented May 1, 2024

Hey @andoorve - This is super exciting!

I'm trying to run a simple example with PP = 2, but encountered an error at runtime. I wrote my own example based on the simple example script examples/offline_inference.py and added pipeline_parallel_size=2 to the arguments.

- llm = LLM(model="facebook/opt-125m", load_format="dummy")
+ llm = LLM(model="facebook/opt-2.7b", pipeline_parallel_size=2, load_format="dummy")

This is the error I hit: error.txt. It seems to be complaining that a kv_caches list item was not found (the list is probably empty?).

ERROR 05-01 20:45:18 worker_base.py:147] Traceback (most recent call last):
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/worker_base.py", line 139, in execute_method
ERROR 05-01 20:45:18 worker_base.py:147]     return executor(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/worker.py", line 140, in determine_num_available_blocks
ERROR 05-01 20:45:18 worker_base.py:147]     self.model_runner.profile_run()
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/model_runner.py", line 844, in profile_run
ERROR 05-01 20:45:18 worker_base.py:147]     self.execute_model(seqs, kv_caches)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/model_runner.py", line 763, in execute_model
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = model_executable(**execute_model_kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 300, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 275, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     return self.decoder(input_ids, positions, kv_caches, attn_metadata)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 249, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
ERROR 05-01 20:45:18 worker_base.py:147] IndexError: list index out of range

I haven't dug into the code deeply enough yet, and I'm curious what the best way is to test and play around with it. If you can point me to a potential starting point, that would be awesome. Thanks!

@andoorve
Collaborator Author

andoorve commented May 1, 2024

Hey @GindaChen, there are a couple of things here:

We don't support OPT yet, and the LLMEngine entry point won't work either. We're only supporting AsyncLLMEngine right now.

@andoorve
Collaborator Author

andoorve commented May 1, 2024

What I would recommend is to try the online serving entrypoint with the LLaMa model. That'd be the best way to start playing around with it.

@GindaChen

@youkaichao
Member

@andoorve FYI: PyNCCL with multiple groups has landed in #4512.

@youkaichao
Member

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

@andoorve please check out #4566 and see if you need anything else.

@andoorve
Collaborator Author

andoorve commented May 2, 2024

LGTM - I guess one thing we can add is a PP PyNCCL group.

@youkaichao
Member

LGTM - I guess one thing we can add is PP PyNCCL group

That's in my plan. Which operations do you need for PP? Allreduce, gather, or anything else?

@andoorve
Collaborator Author

andoorve commented May 2, 2024

We only need point-to-point communication: blocking send and blocking recv. It's not critical, though, unless torch.distributed.* ops don't work well with CUDA graphs.
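For concreteness, a minimal sketch of the kind of blocking point-to-point exchange meant here, using plain torch.distributed ops between adjacent pipeline stages (the rank arithmetic assumes pipeline ranks map directly to global ranks, which is not the case once TP groups are involved):

import torch
import torch.distributed as dist

def send_to_next_stage(tensor: torch.Tensor, pp_rank: int, pp_size: int) -> None:
    # Blocking send of this stage's output activations to the next stage.
    if pp_rank < pp_size - 1:
        dist.send(tensor, dst=pp_rank + 1)

def recv_from_prev_stage(shape, dtype, device, pp_rank: int) -> torch.Tensor:
    # Blocking receive of the previous stage's output activations.
    assert pp_rank > 0, "the first stage has no previous stage to receive from"
    buf = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(buf, src=pp_rank - 1)
    return buf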

@SolitaryThinker
Contributor

Hi @andoorve,

While benchmarking using your PR, I've consistently encountered engine timeouts with smaller models on setups far below total VRAM capacity, which might relate to the issues you've linked (e.g., [Bug]: Engine iteration timed out #4293, #4430, #4135). I'm using commit 9d698fa.

Setup and Reproduction:
Models and Hardware:

  • Llama-2-7b-hf on 2x A100s
  • llama-160m on 2x RTX A4000s:
python -m vllm.entrypoints.openai.api_server --model JackFram/llama-160m \
--swap-space 16 \
--disable-log-requests \
--pipeline-parallel-size 2

python benchmarks/benchmark_serving.py --backend vllm --model JackFram/llama-160m \
--dataset-name sharegpt \
--dataset-path /workspace/sharegpt.json \
--num-prompts 3

Observation:
The engine hangs almost immediately with 3 running prompts; similar issues occur with larger models at a non-infinite --request-rate.

Proposed Solution:

I traced the issue to asyncio.gather(*coros) in ray_gpu_executor.py returning prematurely because it does not block on ray.ObjectRefs. Inserting ray.wait(coros[1:]) before the gather aligns with the intended code semantics and resolves the hanging.

Branch with fix: https://github.com/SolitaryThinker/vllm/tree/pipeline-parallel-fix

I noticed a new commit from you regarding TP+PP fix, but it didn’t resolve the issue in my environment. Could it be due to missing the latest pynccl changes with groups #4512?

This is my first time working with vLLM and Ray, so any insights or corrections on my understanding or approach would be greatly appreciated.

Additional technical details:
After some digging, I realized that asyncio.gather(*coros) is returning before the worker threads have finished. The cause is that coros consists of both futures and ray.ObjectRefs, the latter of which asyncio.gather does not appear to block on. Thus, back in run_engine_loop, the VE (virtual engine) that is assumed to have finished executing after this call:

 done, _ = await asyncio.wait(requests_in_progress, return_when=asyncio.FIRST_COMPLETED)

may in fact still have workers running when a new engine_step task for the VE is created. I'm not sure of the exact interaction that causes the hanging, but inserting a ray.wait(coros[1:]) before the gather seems to respect the intended semantics of the code, namely waiting for the ray.ObjectRefs to materialize.
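A rough sketch of the proposed change (assuming coros holds the driver's local coroutine at index 0 and ray.ObjectRefs from the workers afterwards, as described above; the surrounding executor code is paraphrased, not quoted):

import asyncio
import ray

async def _run_driver_and_workers(coros):
    # coros[0]: local coroutine for the driver; coros[1:]: ray.ObjectRefs from
    # the Ray workers (layout assumed from the description above).
    worker_refs = coros[1:]

    # Proposed fix: block until the worker ObjectRefs have materialized, so a new
    # engine_step task is not scheduled while the previous step's workers still run.
    # Note: ray.wait is synchronous; the discussion further down suggests awaiting
    # the refs via asyncio instead.
    if worker_refs:
        ray.wait(worker_refs, num_returns=len(worker_refs))

    return await asyncio.gather(*coros)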

Thanks
-will

@andoorve
Collaborator Author

andoorve commented May 6, 2024

@SolitaryThinker

Thanks for the thorough investigation and the fix!

It's indeed true that there are existing issues with hanging on the current vLLM mainline, and I have not rebased on the latest PyNCCL changes yet. I am also unable to reproduce this issue easily with GPT2 in my own testing. For these reasons I haven't investigated as deeply yet. I'll give your setup and fix a try once I check whether multi-node is functional.

I wonder if this is a similar reason to why the TP-only cases are hanging in the issues mentioned above, since there is no such ray.wait in that situation either. In the meantime, @rkooo567, maybe you have some comments?

@youkaichao
Member

FYI: I recently found that the cleanup logic is prone to hanging; this is "fixed" in #4508.

@andoorve
Collaborator Author

andoorve commented May 6, 2024

@SolitaryThinker I tried the model/commands above that are giving you issues. I was unable to reproduce on my setup.

My Setup

Started a fresh instance with the following:

GCP g2-standard-48 (4 x NVIDIA L4)
Image: Google Deep Learning VM (M120), Debian 11, Python 3.10, with CUDA 12.1 preinstalled.
vLLM install @ 04b5fe9

Experiments

Started vLLM with

python -m vllm.entrypoints.openai.api_server --model JackFram/llama-160m \
--swap-space 16 \
--disable-log-requests \
--pipeline-parallel-size 2

Ran the below 3 times:

python benchmarks/benchmark_serving.py --backend vllm --model JackFram/llama-160m \
--dataset-name sharegpt \
--dataset-path ~/sharegpt.json \
--num-prompts 3

Killed the vLLM server, then repeated the above experiment 2 more times, for a total of 3 separate serving instances, 9 benchmark runs, and 27 total requests sent.

Saw the expected benchmark results each time:

Traffic request rate: inf
100%|██████████| 3/3 [00:06<00:00,  2.18s/it]
============ Serving Benchmark Result ============
Successful requests:                     3
Benchmark duration (s):                  6.55
Total input tokens:                      72
Total generated tokens:                  1380
Request throughput (req/s):              0.46
Input token throughput (tok/s):          10.99
Output token throughput (tok/s):         210.70
---------------Time to First Token----------------
Mean TTFT (ms):                          29.55
Median TTFT (ms):                        27.27
P99 TTFT (ms):                           34.69
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.42
Median TPOT (ms):                        7.77
P99 TPOT (ms):                           7.80

I wonder if it might only be reproducible on other instances... needs further investigation though.

@zhengxingmao

zhengxingmao commented May 7, 2024

A very meaningful feature!
Hi @andoorve, I have run verification based on your PR; the service currently starts normally, but an error occurs when processing requests.
My env:
RTX-4090 2 nodes
vLLM install @ 04b5fe9

Here is the command:

python3 -m vllm.entrypoints.openai.api_server --trust-remote-code --model /data/llvm/llama_weight --gpu-memory-utilization 0.60 --pipeline-parallel-size 2 --port 8000 --host 0.0.0.0 --enforce-eager

And here is error stack:

ERROR 05-07 16:55:03 async_llm_engine.py:43] Engine background task failed
ERROR 05-07 16:55:03 async_llm_engine.py:43] Traceback (most recent call last):
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "python/ray/_raylet.pyx", line 902, in ray._raylet.prepare_args_internal
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 494, in serialize
ERROR 05-07 16:55:03 async_llm_engine.py:43]     return self._serialize_to_msgpack(value)
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 472, in _serialize_to_msgpack
ERROR 05-07 16:55:03 async_llm_engine.py:43]     pickle5_serialized_object = self._serialize_to_pickle5(
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 425, in _serialize_to_pickle5
ERROR 05-07 16:55:03 async_llm_engine.py:43]     raise e
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 420, in _serialize_to_pickle5
ERROR 05-07 16:55:03 async_llm_engine.py:43]     inband = pickle.dumps(
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py", line 88, in dumps
ERROR 05-07 16:55:03 async_llm_engine.py:43]     cp.dump(obj)
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py", line 733, in dump
ERROR 05-07 16:55:03 async_llm_engine.py:43]     return Pickler.dump(self, obj)
ERROR 05-07 16:55:03 async_llm_engine.py:43] TypeError: cannot pickle 'torch._C.Generator' object

@andoorve
Copy link
Collaborator Author

andoorve commented May 7, 2024

@zhengxingmao Thanks for reporting this! Does this happen without PP? If not, I think it could be some interaction between the following flags and PP:
--trust-remote-code --model /data/llvm/llama_weight --gpu-memory-utilization 0.60

Can you try without these flags and use a model directly from HF? (LLaMa)

@andoorve
Collaborator Author

andoorve commented May 7, 2024

@SolitaryThinker

I did some investigation into what you were saying. I think there are real hangs that appear. I tried LLaMa 3 8B with an effectively infinite request rate on 2 L4s and saw hangs; I'm not sure if this is the same situation you found yourself in. Strangely, if I did a warm-up request first, the hang went away.

The ray.wait solution doesn't help, and it's not intended for async contexts. See here https://docs.ray.io/en/latest/ray-core/api/doc/ray.wait.html:

This method will issue a warning if it’s running inside an async context. Instead of ray.wait(ray_waitables), you can use await asyncio.wait(ray_waitables).

Also from here, asyncio methods such as asyncio.wait and asyncio.gather should be sufficient:
https://docs.ray.io/en/latest/ray-core/actors/async_api.html
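For illustration, a minimal sketch of that async-friendly pattern (wrapping both the local driver coroutine and the Ray ObjectRefs, which are awaitable, in asyncio tasks; the function and variable names are illustrative):

import asyncio

async def wait_for_step(driver_coro, worker_refs):
    # ray.ObjectRef is awaitable, so the event loop can block on worker results
    # directly via asyncio.wait/gather, without calling the synchronous ray.wait().
    tasks = [asyncio.ensure_future(driver_coro)] + [asyncio.ensure_future(r) for r in worker_refs]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
    assert not pending
    return [t.result() for t in done]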

I resolved a hang on my end with:
df9b0c4

Maybe this helps for you?

@andoorve andoorve marked this pull request as ready for review May 7, 2024 20:46
@njhill
Member

njhill left a comment

Really awesome work, thanks @andoorve. None of my comments need to block merging this. I'd be happy to try out my suggestions post-merge.

It would also be nice to fast-follow with single-node support for the multiprocessing backend; I think that should be trivial.

Looking forward to seeing the perf test results!

if get_pp_group().is_last_rank:
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
else:
Member

nit: redundant else

Collaborator Author

Addressed

return self.output_buffers["hidden_states"]
if get_pp_group().is_last_rank:
    return self.output_buffers["hidden_states"]
else:
Member

nit: redundant else

Collaborator Author

Addressed

    remote()  # type: ignore
)
else:
    await self.engine.stop_remote_worker_execution_loop_async()
Member

@andoorve I wonder whether we could keep this within the engine and have each TP group stop/start their worker execution loops independently?

Collaborator Author

@njhill We could do that, we should talk more about this. I guess we can just pass in a TP parameter for that case. Also, how would this level of granularity help?

Comment on lines +298 to +305
self.scheduler = [
    Scheduler(scheduler_config, cache_config, lora_config,
              parallel_config.pipeline_parallel_size)
    for _ in range(parallel_config.pipeline_parallel_size)
]
Member

It would be nice to encapsulate all of these loops in a MultiScheduler class that overrides all the necessary scheduler methods. It could either be a subclass, or we could introduce an ABC or protocol for this.

That way we can also use the singular Scheduler directly in the PP=1 case.

Collaborator Author

Yes, I thought the same. I defaulted to a list of Schedulers because there weren't too many methods needed in this case, but it's definitely a good refactor for readability.
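A rough sketch of the MultiScheduler idea (class and method names are hypothetical; this PR keeps a plain list of Schedulers, and the Scheduler methods shown are assumed from vLLM's scheduler interface):

from typing import List

class MultiScheduler:
    # Hypothetical wrapper that fans calls out to one Scheduler per virtual
    # engine (pipeline stage) and can degenerate to a single Scheduler for PP=1.
    def __init__(self, schedulers: List["Scheduler"]):
        self._schedulers = schedulers

    def has_unfinished_seqs(self) -> bool:
        # Global queries aggregate over all per-virtual-engine schedulers.
        return any(s.has_unfinished_seqs() for s in self._schedulers)

    def schedule(self, virtual_engine: int):
        # Per-virtual-engine operations route to the corresponding scheduler.
        return self._schedulers[virtual_engine].schedule()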

@wukaixingxp
Contributor

Hi! I am super happy to see vLLM is going to support pipeline parallelism. I tried this PR's fork and got this error: ValueError: Total number of hidden layers (X) must be divisible by pipeline parallel size (N). I hope we can remove this hard requirement by allowing the first N-1 GPUs to load floor(X/N) layers each and letting the last GPU load the remaining layers, for an X-layer, N-GPU setup. Thank you so much!
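For illustration, a minimal layer-partitioning sketch of the requested relaxation (purely hypothetical; not how this PR assigns layers):

def partition_layers(num_layers: int, pp_size: int) -> list:
    # The first pp_size - 1 stages take floor(num_layers / pp_size) layers each;
    # the last stage absorbs the remainder, so no divisibility is required.
    base = num_layers // pp_size
    bounds = [i * base for i in range(pp_size)] + [num_layers]
    return [range(bounds[i], bounds[i + 1]) for i in range(pp_size)]

# Example: 30 layers on 4 GPUs -> stages with 7, 7, 7, and 9 layers.
print([len(r) for r in partition_layers(30, 4)])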

@youkaichao
Member

youkaichao left a comment

Thanks for the hard work! Very excited to see this new feature ❤️

@youkaichao youkaichao merged commit c5832d2 into vllm-project:main Jul 2, 2024
70 checks passed
@andoorve
Collaborator Author

andoorve commented Jul 2, 2024

@wukaixingxp Certainly! Please look out for a follow-up PR.

prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 3, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 7, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
@MuYu-zhi

MuYu-zhi commented Jul 9, 2024

@andoorve Hi, I have some confusion about the PP implementation in this PR and hope you can help clarify it. Why was a "coroutine + async" implementation adopted here, rather than the torch-style "multiprocess + RPC" approach? What considerations led to this choice? Is it because this approach aligns better with the current architecture of vLLM? And do you have any performance data comparing the two implementations?

@andoorve
Collaborator Author

andoorve commented Jul 9, 2024

Hi @MuYu-zhi, thanks for your question. Basically, it's due to two reasons: it aligns naturally with the existing vLLM async architecture, and it's quite simple to do. We don't have data on multiprocess + RPC, but CPU usage with async seems to be high. This may be addressed by other PRs.

@binxuan

binxuan commented Jul 11, 2024

Hi, thanks for the great PR! It looks like adding pipeline parallelism for Mixtral is pretty straightforward based on the code change for the LLaMa class. Is someone already working on this extension? Otherwise I can push my local changes as well.

@andoorve
Collaborator Author

Thanks @binxuan! Please push your local changes

@binxuan

binxuan commented Jul 12, 2024

Thanks @binxuan! Please push your local changes

So here are my local changes: link
Before I open the PR, could you point me to where I should add appropriate test cases?
I tested on two P4D instances using Mixtral 8x22B-Instruct, and the responses look reasonable to me.

@andoorve
Collaborator Author

Hey @binxuan, you can create the PR directly; please paste the responses in that PR. We are still working on correctness tests.

xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024

Successfully merging this pull request may close these issues.

[RFC]: Initial support for Pipeline Parallelism