
[Misc] Add CustomOp interface for device portability #5255

Merged: 15 commits into main, Jun 5, 2024

Conversation

WoosukKwon (Collaborator) commented Jun 4, 2024:

Currently, the custom layers have two issues. First, they directly import _custom_ops, which is not available on devices such as TPU and Gaudi. Second, they assume that the custom ops are implemented in the same way for all devices. To address these issues, this PR adds a CustomOp interface, an indirection layer that implements device-specific forward methods. This allows the custom kernels to be lazily imported only for the associated device.

import torch
import torch.nn as nn

# NOTE: is_hip() and is_cpu() live in vllm.utils in the vLLM codebase.
from vllm.utils import is_cpu, is_hip


class CustomOp(nn.Module):

    def forward(self, *args, **kwargs):
        # Dispatch lazily on the first call and cache the chosen method.
        if not hasattr(self, "_forward_method"):
            self._forward_method = self.dispatch_forward()
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method."""
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        ...

    def forward_xpu(self, *args, **kwargs):
        ...

    def forward_cpu(self, *args, **kwargs):
        ...

    def forward_tpu(self, *args, **kwargs):
        ...

    def forward_gaudi(self, *args, **kwargs):
        ...

    def dispatch_forward(self):
        # NOTE: vLLM is bound to a single backend at build time, so the
        # dispatch decision is static; dynamic switching is not supported.
        if is_hip():
            return self.forward_hip
        elif is_cpu():
            return self.forward_cpu
        else:
            return self.forward_cuda
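
Continuing the snippet above, a minimal sketch of how a layer could subclass CustomOp; the layer name and its behavior are made up for illustration and are not taken from this PR:

class ScaledSiluAndScale(CustomOp):
    """Hypothetical example layer built on CustomOp."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference implementation.
        return torch.nn.functional.silu(x) * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # The path dispatch_forward selects on NVIDIA GPUs. A real op would
        # call into vLLM's custom CUDA kernels here; this sketch reuses the
        # native implementation.
        return self.forward_native(x)

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        # On CPU builds, simply fall back to the native implementation.
        return self.forward_native(x)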

According to the benchmarks, the lazy import does not affect performance:

$ python benchmarks/benchmark_latency.py --model JackFram/llama-68m
# main
Avg latency: 0.3685314357979223 seconds
# This PR
Avg latency: 0.3665724984719418 seconds

WoosukKwon marked this pull request as draft on June 4, 2024 at 17:30
WoosukKwon marked this pull request as ready for review on June 4, 2024 at 17:49
WoosukKwon requested a review from comaniac on June 4, 2024 at 17:50
Resolved review threads: vllm/model_executor/layers/activation.py, vllm/model_executor/custom_op.py
Comment on lines 50 to 56
    def dispatch_forward(self):
        if is_hip():
            return self.forward_hip
        elif is_cpu():
            return self.forward_cpu
        else:
            return self.forward_cuda
Collaborator:

I feel we need more flexibility here in the future. For example, we may build a wheel with both CPU and CUDA enabled, but we want to configure which one to use on the fly. On the other hand, this may not be necessary at this moment.
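
Purely to illustrate that idea, on-the-fly selection could look like an environment-variable override in dispatch_forward. Nothing below exists in vLLM; the variable name and logic are hypothetical:

import os

def dispatch_forward(self):
    # Hypothetical override: force a backend even when several are built in.
    backend = os.environ.get("VLLM_CUSTOM_OP_BACKEND")  # made-up variable name
    if backend == "cpu":
        return self.forward_cpu
    if backend == "cuda":
        return self.forward_cuda
    # Otherwise fall back to the static build-time choice.
    if is_hip():
        return self.forward_hip
    elif is_cpu():
        return self.forward_cpu
    else:
        return self.forward_cuda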

WoosukKwon (Collaborator, Author) replied:

Good point. For now, vLLM is bound to a specific backend at build time. I added a note that we do not currently support dynamic dispatching.

youkaichao (Member) commented:

PyTorch has quite a lot of dispatching utilities; can we reuse some? forward usually has a special meaning in PyTorch, and it gets special handling in torch.compile. Having these new forward_xxx methods might break future torch.compile integration.

WoosukKwon (Collaborator, Author) replied:

PyTorch has quite a lot of dispatching utilities; can we reuse some? forward usually has a special meaning in PyTorch, and it gets special handling in torch.compile. Having these new forward_xxx methods might break future torch.compile integration.

@youkaichao Good point. I moved the dispatching logic to __init__ so that it is not included in the scope of torch.compile. After this change, I believe the PR itself does not break torch.compile.
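
Roughly, the revised pattern looks like the following; this is a paraphrase of the change described above, not the exact code merged in the PR:

class CustomOp(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        # The backend is chosen once here, outside anything torch.compile
        # would trace; forward itself stays a plain method call.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)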

comaniac (Collaborator) left a comment:

LGTM!

bnellnm (Contributor) commented Jun 4, 2024:

Hi @WoosukKwon, would you mind taking a look at #5047 before you land this? I've been working on registering all the custom operations via TORCH_LIBRARY (which also supports per-device dispatching). I'm worried these changes might be at odds with TORCH_LIBRARY/PyTorch dispatching.
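
(For context, a tiny Python-level sketch of what per-backend registration through torch.library looks like; the namespace, op name, and implementations are illustrative only and are not from #5047:)

import torch
from torch.library import Library

# Define a toy op schema and register one implementation per dispatch key.
_lib = Library("toy_ops", "DEF")
_lib.define("scale(Tensor x, float factor) -> Tensor")

def _scale_cpu(x, factor):
    return x * factor

def _scale_cuda(x, factor):
    return x * factor  # a real custom kernel would go here

_lib.impl("scale", _scale_cpu, "CPU")
_lib.impl("scale", _scale_cuda, "CUDA")

# PyTorch picks the implementation based on the input tensor's device.
out = torch.ops.toy_ops.scale(torch.ones(4), 2.0)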

WoosukKwon (Collaborator, Author) replied:

@bnellnm Thanks for bringing it up. If I understand correctly, this PR is orthogonal to yours. Basically, I believe your PR does NOT include per-device dispatching, because vLLM always builds the custom library for at most one device. Also, in our situation, dispatching can't be implemented at the C++ level, because we'd like to use Python libraries to implement some custom ops.

tlrmchlsmth (Collaborator) left a comment:

lgtm :)

WoosukKwon merged commit 41ca62c into main on Jun 5, 2024
100 of 103 checks passed
WoosukKwon deleted the dispatcher branch on June 5, 2024 at 16:18
blinkbear pushed a commit to blinkbear/vllm that referenced this pull request Jun 6, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 27, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024