diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py
index ddc553f3a558..4fb107b0fa63 100644
--- a/accelerator/xpu_accelerator.py
+++ b/accelerator/xpu_accelerator.py
@@ -9,6 +9,9 @@
 import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
 import functools
 
+import importlib
+import inspect
+
 
 class XPU_Accelerator(DeepSpeedAccelerator):
 
@@ -17,6 +20,7 @@ def __init__(self):
         self._communication_backend_name = 'ccl'
         self._compile_backend = "inductor"
         self.aligned_tensors = []
+        self.class_dict = None
 
     def is_synchronized_device(self):
         return False
@@ -257,35 +261,29 @@ def on_accelerator(self, tensor):
         else:
             return False
 
+    def _lazy_init_class_dict(self):
+        if self.class_dict:
+            return
+
+        op_builder_module = importlib.import_module(self.op_builder_dir())
+
+        # get op builder class from op_builder/xpu/__init__.py
+        self.class_dict = {}
+        for class_name, class_obj in inspect.getmembers(op_builder_module, inspect.isclass):
+            self.class_dict[class_name] = class_obj
+
     # create an instance of op builder and return, name specified by class_name
-    def create_op_builder(self, op_name):
-        builder_class = self.get_op_builder(op_name)
-        if builder_class != None:
-            return builder_class()
-        return None
+    def create_op_builder(self, class_name):
+        builder_class = self.get_op_builder(class_name)
+        return builder_class()
 
     # return an op builder class, name specified by class_name
     def get_op_builder(self, class_name):
-        try:
-            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
-            # if successful this also means we're doing a local install and not JIT compile path
-            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
-        except ImportError:
-            from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
-
-        if class_name == "AsyncIOBuilder":
-            return AsyncIOBuilder
-        elif class_name == "CPUAdagradBuilder":
-            return CPUAdagradBuilder
-        elif class_name == "CPUAdamBuilder":
-            return CPUAdamBuilder
-        elif class_name == "FusedAdamBuilder":
-            return FusedAdamBuilder
-        elif class_name == "PackbitsBuilder":
-            return PackbitsBuilder
+        self._lazy_init_class_dict()
+        if class_name in self.class_dict:
+            return self.class_dict[class_name]
         else:
-            return None
+            return self.class_dict['NotImplementedBuilder']
 
     def build_extension(self):
         try:
diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py
index f2a6cb6a36dc..cdf4c030f5d6 100644
--- a/deepspeed/comm/ccl.py
+++ b/deepspeed/comm/ccl.py
@@ -8,13 +8,14 @@
 import torch
 from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import NotImplementedBuilder
 from .reduce_op import ReduceOp
 from .torch import TorchBackend
 
 
 def build_ccl_op():
     builder = get_accelerator().create_op_builder("CCLCommBuilder")
-    if builder is None:
+    if builder is None or isinstance(builder, NotImplementedBuilder):
         return None
     ccl_cpp_module = builder.load()
     print(f'DeepSpeed {builder.absolute_name()} built successfully')
     return ccl_cpp_module
diff --git a/op_builder/xpu/__init__.py b/op_builder/xpu/__init__.py
index bf82e4248338..093f90b30234 100755
--- a/op_builder/xpu/__init__.py
+++ b/op_builder/xpu/__init__.py
@@ -7,4 +7,7 @@
 from .cpu_adagrad import CPUAdagradBuilder
 from .fused_adam import FusedAdamBuilder
 from .async_io import AsyncIOBuilder
+from .inference import InferenceBuilder
+from .flash_attn import FlashAttentionBuilder
+from .no_impl import NotImplementedBuilder
 from .packbits import PackbitsBuilder
diff --git a/op_builder/xpu/flash_attn.py b/op_builder/xpu/flash_attn.py
new file mode 100644
index 000000000000..c8c2674d5d27
--- /dev/null
+++ b/op_builder/xpu/flash_attn.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .builder import SYCLOpBuilder
+
+
+class FlashAttentionBuilderObject():
+
+    def __init__(self):
+        pass
+
+    # general functions
+    def flash_attn_func_v2(self, q, k, v, dropout_p, softmax_scale, is_causal):
+        try:
+            import torch
+            import intel_extension_for_pytorch  # noqa
+            return torch.nn.functional.scaled_dot_product_attention(q,
+                                                                    k,
+                                                                    v,
+                                                                    dropout_p=dropout_p,
+                                                                    is_causal=is_causal,
+                                                                    scale=softmax_scale)
+        except ImportError:
+            raise ImportError(
+                "Please install pytorch and intel_extension_for_pytorch to include scaled dot product attention.")
+
+
+class FlashAttentionBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_FlashAttention"
+    NAME = "flash_attn"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.{self.NAME}_op'
+
+    def sources(self):
+        return
+
+    def include_paths(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def load(self):
+        return FlashAttentionBuilderObject()
diff --git a/op_builder/xpu/inference.py b/op_builder/xpu/inference.py
new file mode 100644
index 000000000000..9114dcc2c315
--- /dev/null
+++ b/op_builder/xpu/inference.py
@@ -0,0 +1,36 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .builder import SYCLOpBuilder
+
+
+class InferenceBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
+    NAME = "transformer_inference"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.transformer.inference.{self.NAME}_op'
+
+    def sources(self):
+        return
+
+    def include_paths(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def load(self):
+        try:
+            import intel_extension_for_pytorch.deepspeed
+            return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference
+        except ImportError:
+            raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.")
diff --git a/op_builder/xpu/no_impl.py b/op_builder/xpu/no_impl.py
new file mode 100644
index 000000000000..8b294f70c279
--- /dev/null
+++ b/op_builder/xpu/no_impl.py
@@ -0,0 +1,33 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .builder import SYCLOpBuilder
+
+
+class NotImplementedBuilder(SYCLOpBuilder):
+    BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
+    NAME = "deepspeed_not_implemented"
+
+    def __init__(self, name=None):
+        name = self.NAME if name is None else name
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'deepspeed.ops.{self.NAME}_op'
+
+    def load(self, verbose=True):
+        raise ValueError("This op has not been implemented on the XPU backend.")
+
+    def sources(self):
+        return []
+
+    def cxx_args(self):
+        return []
+
+    def extra_ldflags(self):
+        return []
+
+    def include_paths(self):
+        return []
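
For illustration only (not part of the patch): a minimal sketch of how the new builder lookup behaves after this change, assuming DeepSpeed is installed with the XPU accelerator active. The name "SomeUnsupportedBuilder" below is a hypothetical placeholder for any op builder that has no XPU implementation.

from deepspeed.accelerator import get_accelerator

acc = get_accelerator()

# A builder that exists under op_builder/xpu resolves to its class via the lazily
# built class_dict (populated by inspect.getmembers over the op_builder module).
fused_adam_cls = acc.get_op_builder("FusedAdamBuilder")
assert fused_adam_cls.__name__ == "FusedAdamBuilder"

# An unknown name no longer returns None; it falls back to NotImplementedBuilder,
# whose load() raises a clear error instead of failing later on a None builder.
builder = acc.create_op_builder("SomeUnsupportedBuilder")  # hypothetical name
try:
    builder.load()
except ValueError as err:
    print(err)  # "This op has not been implemented on the XPU backend."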