forked from mesolitica/vllm-whisper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Misc] Add CustomOp interface for device portability (vllm-project#5255)
- Loading branch information
1 parent
c6282ef
commit e78010a
Showing
7 changed files
with
100 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch.nn as nn | ||
|
||
from vllm.utils import is_cpu, is_hip | ||
|
||
|
||
class CustomOp(nn.Module): | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__() | ||
self._forward_method = self.dispatch_forward() | ||
|
||
def forward(self, *args, **kwargs): | ||
return self._forward_method(*args, **kwargs) | ||
|
||
def forward_native(self, *args, **kwargs): | ||
"""PyTorch-native implementation of the forward method. | ||
This method is optional. If implemented, it can be used with compilers | ||
such as torch.compile or PyTorch XLA. Also, it can be used for testing | ||
purposes. | ||
""" | ||
raise NotImplementedError | ||
|
||
def forward_cuda(self, *args, **kwargs): | ||
raise NotImplementedError | ||
|
||
def forward_hip(self, *args, **kwargs): | ||
# By default, we assume that HIP ops are compatible with CUDA ops. | ||
return self.forward_cuda(*args, **kwargs) | ||
|
||
def forward_xpu(self, *args, **kwargs): | ||
# By default, we assume that XPU ops are compatible with CUDA ops. | ||
# NOTE(woosuk): This is a placeholder for future extensions. | ||
return self.forward_cuda(*args, **kwargs) | ||
|
||
def forward_cpu(self, *args, **kwargs): | ||
# By default, we assume that CPU ops are compatible with CUDA ops. | ||
return self.forward_cuda(*args, **kwargs) | ||
|
||
def forward_tpu(self, *args, **kwargs): | ||
# By default, we assume that TPU ops are compatible with the | ||
# PyTorch-native implementation. | ||
# NOTE(woosuk): This is a placeholder for future extensions. | ||
return self.forward_native(*args, **kwargs) | ||
|
||
def forward_gaudi(self, *args, **kwargs): | ||
# By default, we assume that Gaudi ops are compatible with the | ||
# PyTorch-native implementation. | ||
# NOTE(woosuk): This is a placeholder for future extensions. | ||
return self.forward_native(*args, **kwargs) | ||
|
||
def dispatch_forward(self): | ||
# NOTE(woosuk): Here we assume that vLLM was built for only one | ||
# specific backend. Currently, we do not support dynamic dispatching. | ||
if is_hip(): | ||
return self.forward_hip | ||
elif is_cpu(): | ||
return self.forward_cpu | ||
else: | ||
return self.forward_cuda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters