
Conversation

vanbasten23 (Collaborator) commented Aug 18, 2025

Description

This PR creates a TorchaxMergedColumnParallelLinearWithLoRA LoRA wrapper, as discussed in the design. The wrapper resembles MergedColumnParallelLinearWithLoRA in vLLM.
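
For context, a minimal sketch of the wrapper shape follows (not the actual implementation; the punica_wrapper.add_lora_linear call and the stacked-weight attributes are illustrative assumptions):

import torch
import torch.nn as nn


class TorchaxMergedColumnParallelLinearWithLoRASketch(nn.Module):
    """Sketch only: wrap a merged column-parallel linear layer and add LoRA deltas."""

    def __init__(self, base_layer: nn.Module, punica_wrapper=None):
        super().__init__()
        # The wrapped projection: JaxMergedColumnParallelLinear today, or
        # vLLM's MergedColumnParallelLinear after the planned refactoring.
        self.base_layer = base_layer
        self.punica_wrapper = punica_wrapper
        self.lora_a_stacked = None  # per-slot LoRA A weights (populated elsewhere)
        self.lora_b_stacked = None  # per-slot LoRA B weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer(x)  # base projection
        if self.punica_wrapper is not None and self.lora_a_stacked is not None:
            # Hypothetical punica call that adds the low-rank update (x @ A) @ B
            # for the active adapters on top of the base output.
            output = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked)
        return output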

Tests

MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax pytest -rs -vv tests/lora/test_layers.py -k test_column_parallel_packed

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
QiliangCui force-pushed the main branch 2 times, most recently from 39f220e to 0b0c7d6, on August 19, 2025 06:59
vanbasten23 requested review from hfan and lsy323, and removed the request for hfan, on August 20, 2025 05:15
vanbasten23 marked this pull request as ready for review on August 20, 2025 05:15
vanbasten23 (Collaborator, Author) commented:

@hfan @lsy323 could you please take a look when you get a chance? Thanks!

vanbasten23 changed the title from "[do not review yet] Create column parallel linear with lora wrapper" to "Create TorchaxMergedColumnParallelLinearWithLoRA lora wrapper for single chip" on Aug 20, 2025
vanbasten23 (Collaborator, Author) commented:

Also cc @kyuyeunk. This is the PR I'm working on. I think your refactoring (main...support_fp8_quant) should play well with my PR.

vanbasten23 force-pushed the xiowei/write_ColumnParallelLinearWithLoRA branch from a1a77dd to 0dad381 on August 20, 2025 18:40
kyuyeunk (Collaborator) commented:

> Also cc @kyuyeunk. This is the PR I'm working on. I think your refactoring (main...support_fp8_quant) should play well with my PR.

Thanks @vanbasten23! This is the proper PR for the refactoring work: #512. As mentioned in the description, my PR removes all the custom torchax layers (like JaxMergedColumnParallelLinear) in favor of utilizing pre-existing vLLM APIs. Will this PR still work despite the refactoring?

vanbasten23 (Collaborator, Author) commented:

> Thanks @vanbasten23! This is the proper PR for the refactoring work: #512. As mentioned in the description, my PR removes all the custom torchax layers (like JaxMergedColumnParallelLinear) in favor of utilizing pre-existing vLLM APIs. Will this PR still work despite the refactoring?

I think so. All this PR needs is for this line to keep working, i.e., the forward pass self.base_layer(x). In the current PR, self.base_layer is JaxMergedColumnParallelLinear. After your PR is in, self.base_layer becomes the MergedColumnParallelLinear defined in vLLM. Behavior-wise, JaxMergedColumnParallelLinear and MergedColumnParallelLinear should be identical.
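
To illustrate that point: the sketch in the description only requires base_layer to be callable on the input, so any stand-in module works (nn.Linear below is just a placeholder for either layer; this reuses the hypothetical sketch class above):

import torch
import torch.nn as nn

# nn.Linear stands in for JaxMergedColumnParallelLinear / MergedColumnParallelLinear;
# with no punica wrapper or LoRA weights set, the output is just the base projection.
wrapper = TorchaxMergedColumnParallelLinearWithLoRASketch(base_layer=nn.Linear(16, 32))
out = wrapper(torch.randn(4, 16))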

    @classmethod
    def get_punica_wrapper(cls) -> str:
-       return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
+       return "tpu_commons.lora.torch_punica_tpu.PunicaWrapperTPU"
Reviewer (Collaborator) commented:

Maybe make it reusable for jax as well

Reviewer (Collaborator) commented:

maybe check if MODEL_IMPL_TYPE=vllm here

vanbasten23 (Collaborator, Author) commented:

Great callout. If we add LoRA to JAX, then we'll check MODEL_IMPL_TYPE=vllm here.
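
A hedged sketch of that follow-up (the env-var check and the class name are hypothetical; only the torchax/vLLM path is covered by this PR):

import os


class PunicaWrapperSelectionSketch:
    @classmethod
    def get_punica_wrapper(cls) -> str:
        # Only the torchax/vLLM path exists today; a JAX-native punica wrapper
        # would get its own branch here once LoRA support lands for the JAX impl.
        if os.environ.get("MODEL_IMPL_TYPE", "vllm").lower() == "vllm":
            return "tpu_commons.lora.torch_punica_tpu.PunicaWrapperTPU"
        raise NotImplementedError(
            "LoRA is currently only supported with MODEL_IMPL_TYPE=vllm")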

device=device)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)

with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Reviewer (Collaborator) commented:

do you need to set jax.default_device here?

punica_wrapper.move_to_device(mesh)

jax_inputs = []
with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
Reviewer (Collaborator) commented:

ditto



# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
class DummyLoRAManager:
Reviewer (Collaborator) commented:

Is there any difference from the one in the vLLM main repo? I wonder if it is possible to just use the one in vLLM.

vllm_config.model_config.model,
vllm_config.scheduler_config.max_num_batched_tokens,
vllm_config.parallel_config.tensor_parallel_size,
"MergedColumnParallelLinear")
Reviewer (Collaborator) commented:

Why does it need to get fuse_matmuls here, instead of just using the pre-calculated value shard_merged_column_parallel_linear.fuse_matmuls?
