Create TorchaxMergedColumnParallelLinearWithLoRA lora wrapper for single chip #496
base: main
Conversation
Force-pushed from 39f220e to 0b0c7d6.
Also cc @kyuyeunk. This is the PR that I'm working on. I think your refactoring main...support_fp8_quant should be able to play well with my PR.
Force-pushed from a1a77dd to 0dad381.
Thanks @vanbasten23! This is the proper PR for the refactoring work: #512. As mentioned in the description, my PR removes all the custom torchax layers (like …).
I think so. All this PR needs is for this line to work, i.e. this forward pass.
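For readers without the linked code handy, here is a minimal, self-contained PyTorch sketch of what a merged-column-parallel LoRA forward pass conceptually does (base projection plus a shrink/expand LoRA delta per merged slice). All names and shapes are illustrative assumptions, not the actual vLLM or tpu_commons code:

import torch

def merged_column_lora_forward(x, base_weight, loras):
    # x: [num_tokens, in_dim]; base_weight: [sum(out_dims), in_dim]
    # loras: one (lora_a [rank, in_dim], lora_b [out_dim_i, rank]) pair per
    # merged slice (e.g. gate_proj and up_proj of a MergedColumnParallelLinear)
    base_out = x @ base_weight.T
    outputs = []
    offset = 0
    for lora_a, lora_b in loras:
        out_dim = lora_b.shape[0]
        delta = (x @ lora_a.T) @ lora_b.T  # LoRA shrink, then expand
        outputs.append(base_out[:, offset:offset + out_dim] + delta)
        offset += out_dim
    return torch.cat(outputs, dim=-1)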
@classmethod
def get_punica_wrapper(cls) -> str:
-    return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
+    return "tpu_commons.lora.torch_punica_tpu.PunicaWrapperTPU"
Maybe make it reusable for JAX as well.
Maybe check if MODEL_IMPL_TYPE=vllm here.
Great callout. If we add LoRA to JAX, then we'll check MODEL_IMPL_TYPE=vllm here.
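To make the suggestion above concrete, a rough sketch of how such a check could look. The host class name and the non-vllm branch are assumptions for illustration; only the returned wrapper path comes from the diff above:

import os

class TpuPlatform:  # hypothetical host class, for illustration only
    @classmethod
    def get_punica_wrapper(cls) -> str:
        impl = os.environ.get("MODEL_IMPL_TYPE", "vllm").lower()
        if impl == "vllm":
            # Torchax path exercised by this PR.
            return "tpu_commons.lora.torch_punica_tpu.PunicaWrapperTPU"
        # Placeholder until LoRA lands for the JAX model implementation.
        raise NotImplementedError(
            f"LoRA is not supported for MODEL_IMPL_TYPE={impl}")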
device=device)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Do you need to set jax.default_device here?
punica_wrapper.move_to_device(mesh)

jax_inputs = []
with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
ditto
# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
class DummyLoRAManager:
Is there any difference from the one in the vLLM main repo? I wonder if it is possible to just use the one in vLLM.
vllm_config.model_config.model,
vllm_config.scheduler_config.max_num_batched_tokens,
vllm_config.parallel_config.tensor_parallel_size,
"MergedColumnParallelLinear")
Why does it need to get fuse_matmuls here, instead of just using the pre-calculated value of fuse_matmuls as shard_merged_column_parallel_linear.fuse_matmuls?
Description
This PR creates a TorchaxMergedColumnParallelLinearWithLoRA LoRA wrapper, as we discussed in the design. This LoRA wrapper resembles MergedColumnParallelLinearWithLoRA in vLLM.
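For context, a hypothetical skeleton of what such a wrapper could look like, assuming it subclasses vLLM's MergedColumnParallelLinearWithLoRA (import path per recent vLLM) and simply runs the forward inside the torchax environment; this is an illustrative sketch, not the PR's actual implementation:

import torchax
from vllm.lora.layers import MergedColumnParallelLinearWithLoRA

class TorchaxMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """Single-chip torchax variant of vLLM's merged-column-parallel LoRA layer."""

    def forward(self, x):
        # Run vLLM's forward under the torchax environment so the torch ops
        # are traced and lowered to JAX/XLA on the TPU device.
        with torchax.default_env():
            return super().forward(x)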
Tests
MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax pytest -rs -vv tests/lora/test_layers.py -k test_column_parallel_packed
Checklist
Before submitting this PR, please make sure: