Abstract accelerator (step 2) (#2560)
* Abstract accelerator (step 2)

* more flexible op_builder path for both installation and runtime

* add SpatialInferenceBuilder to cuda_accelerator.py

* use reflection to make cuda_accelerator adapt to CUDA op builder changes automatically

* clean up deepspeed/__init__.py

* add comments in cuda_accelerator for the no-torch path

* Update deepspeed/env_report.py

Change env_report.py according to the review suggestion

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>

* narrow the scope of try...except for better code clarity

* port deepspeed/ops/random_ltd/dropping_utils.py to the accelerator abstraction

* move accelerator to the top-level directory and create a symlink under deepspeed

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
4 people authored Jan 7, 2023
1 parent 95d9a1b commit 9548d48
Showing 32 changed files with 206 additions and 158 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -4,3 +4,4 @@ recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include op_builder *.py
recursive-include benchmarks *.py
recursive-include accelerator *.py
1 change: 1 addition & 0 deletions MANIFEST_win.in
@@ -6,3 +6,4 @@ recursive-include deepspeed *.tr
recursive-exclude deepspeed/ops/csrc *.cpp *.h *.cu *.cuh *.cc
prune csrc
prune op_builder
prune accelerator
File renamed without changes.
File renamed without changes.
accelerator/cuda_accelerator.py
@@ -1,12 +1,40 @@
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import torch.cuda
import os
import pkgutil
import importlib

from .abstract_accelerator import DeepSpeedAccelerator
# During the setup stage torch may not be installed yet; passing when torch
# is missing allows the op builder related APIs to still be executed.
try:
import torch.cuda
except ImportError:
pass


class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'

# begin initialize for create_op_builder()
# put all valid class name <--> class type mapping into class_dict
op_builder_dir = self.op_builder_dir()
op_builder_module = importlib.import_module(op_builder_dir)

for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder' and module_name != 'builder_names':
module = importlib.import_module("{}.{}".format(
op_builder_dir,
module_name))
for member_name in module.__dir__():
if member_name.endswith(
'Builder'
) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes
if not member_name in self.class_dict:
self.class_dict[member_name] = getattr(module, member_name)
# end initialize for create_op_builder()

# Device APIs
def device_name(self, device_index=None):
if device_index == None:
@@ -194,44 +222,21 @@ def on_accelerator(self, tensor):
return False

def op_builder_dir(self):
return "deepspeed.ops.op_builder"
try:
# at installation time op_builder is importable as a top-level module; otherwise fall back to deepspeed.ops.op_builder
import op_builder # noqa: F401
return "op_builder"
except ImportError:
return "deepspeed.ops.op_builder"

# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
class_dict = {}

def create_op_builder(self, class_name):
from deepspeed.ops.op_builder import AsyncIOBuilder, CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, FusedLambBuilder, QuantizerBuilder, SparseAttnBuilder, StochasticTransformerBuilder, TransformerBuilder, InferenceBuilder, UtilsBuilder
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder as AsyncIOBuilderName
from deepspeed.ops.op_builder.builder_names import CPUAdagradBuilder as CPUAdagradBuilderName
from deepspeed.ops.op_builder.builder_names import CPUAdamBuilder as CPUAdamBuilderName
from deepspeed.ops.op_builder.builder_names import FusedAdamBuilder as FusedAdamBuilderName
from deepspeed.ops.op_builder.builder_names import FusedLambBuilder as FusedLambBuilderName
from deepspeed.ops.op_builder.builder_names import QuantizerBuilder as QuantizerBuilderName
from deepspeed.ops.op_builder.builder_names import SparseAttnBuilder as SparseAttnBuilderName
from deepspeed.ops.op_builder.builder_names import StochasticTransformerBuilder as StochasticTransformerBuilderName
from deepspeed.ops.op_builder.builder_names import TransformerBuilder as TransformerBuilderName
from deepspeed.ops.op_builder.builder_names import InferenceBuilder as InferenceBuilderName
from deepspeed.ops.op_builder.builder_names import UtilsBuilder as UtilsBuilderName

if class_name == AsyncIOBuilderName:
return AsyncIOBuilder()
elif class_name == CPUAdagradBuilderName:
return CPUAdagradBuilder()
elif class_name == CPUAdamBuilderName:
return CPUAdamBuilder()
elif class_name == FusedAdamBuilderName:
return FusedAdamBuilder()
elif class_name == FusedLambBuilderName:
return FusedLambBuilder()
elif class_name == QuantizerBuilderName:
return QuantizerBuilder()
elif class_name == SparseAttnBuilderName:
return SparseAttnBuilder()
elif class_name == StochasticTransformerBuilderName:
return StochasticTransformerBuilder()
elif class_name == TransformerBuilderName:
return TransformerBuilder()
elif class_name == InferenceBuilderName:
return InferenceBuilder()
elif class_name == UtilsBuilderName:
return UtilsBuilder()
if class_name in self.class_dict:
return self.class_dict[class_name]()
else:
return None

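The reflection loop above replaces the long if/elif chain removed below it. As
a minimal standalone sketch of the same discovery pattern (the package layout
is taken from this diff; anything else here is illustrative):

import importlib
import os
import pkgutil

def discover_builders(package_name):
    # Walk every module in the op builder package and collect each concrete
    # *Builder class, mirroring the loop in CUDA_Accelerator.__init__.
    abstract = {'OpBuilder', 'CUDAOpBuilder', 'TorchCPUOpBuilder'}
    skipped = {'all_ops', 'builder', 'builder_names'}  # avoid self references
    package = importlib.import_module(package_name)
    class_dict = {}
    for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(package.__file__)]):
        if module_name in skipped:
            continue
        module = importlib.import_module('{}.{}'.format(package_name, module_name))
        for member_name in dir(module):
            if member_name.endswith('Builder') and member_name not in abstract:
                class_dict.setdefault(member_name, getattr(module, member_name))
    return class_dict

# e.g. discover_builders('deepspeed.ops.op_builder')['TransformerBuilder']()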
accelerator/real_accelerator.py
@@ -24,7 +24,7 @@ def get_accelerator():
_validate_accelerator(ds_accelerator)
return ds_accelerator

from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
from .cuda_accelerator import CUDA_Accelerator
ds_accelerator = CUDA_Accelerator()
_validate_accelerator(ds_accelerator)
return ds_accelerator
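A short usage sketch of the public entry point; the concrete return values
below assume an NVIDIA/CUDA stack:

from deepspeed.accelerator import get_accelerator

accel = get_accelerator()                  # CUDA_Accelerator unless overridden
print(accel.device_name())                 # 'cuda'
print(accel.device_name(0))                # 'cuda:0'
print(accel.communication_backend_name())  # 'nccl'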
5 changes: 3 additions & 2 deletions csrc/aio/py_test/aio_bench_perf_sweep.py
@@ -15,6 +15,8 @@
from test_ds_aio_utils import refine_integer_value
from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder

OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'test_ds_aio.py'
@@ -277,8 +279,7 @@ def script_path():


def async_io_setup():
from deepspeed.ops.aio import AsyncIOBuilder
return AsyncIOBuilder().is_compatible()
return get_accelerator().create_op_builder(AsyncIOBuilder).is_compatible()


def get_block_size_and_count(io_bytes):
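Note the import swap above: builder_names exports plain class-name strings,
so importing it pulls in no torch or CUDA code. A minimal sketch of the
lookup, assuming DeepSpeed is installed:

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder

# AsyncIOBuilder is the string 'AsyncIOBuilder' here; the accelerator maps it
# to the class collected by the reflection loop in CUDA_Accelerator.__init__.
builder = get_accelerator().create_op_builder(AsyncIOBuilder)
print(builder.is_compatible())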
38 changes: 22 additions & 16 deletions csrc/aio/py_test/ds_aio_basic.py
@@ -8,9 +8,10 @@
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder


def pre_basic(args, tid, read_op):
@@ -19,7 +20,10 @@ def pre_basic(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'

task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
buffer = get_accelerator().pin_memory(
torch.empty(num_bytes,
dtype=torch.uint8,
device='cpu'))
task_log(
tid,
f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
@@ -56,13 +60,14 @@ def post_basic(pool_params):
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
get_accelerator().create_op_builder(AsyncIOBuilder).load().aio_read(
ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time

@@ -72,13 +77,14 @@ def main_basic_read(pool_params):
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
get_accelerator().create_op_builder(AsyncIOBuilder).load().aio_write(
ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time

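The edits above route pinned-host allocation through the accelerator instead
of calling Tensor.pin_memory() directly, so a non-CUDA backend can substitute
its own pinning. A condensed sketch of the pattern:

import torch
from deepspeed.accelerator import get_accelerator

num_bytes = 1024
# Allocate on the host, then let the active accelerator pin the buffer.
buffer = get_accelerator().pin_memory(
    torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
print(buffer.is_pinned())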
23 changes: 15 additions & 8 deletions csrc/aio/py_test/ds_aio_handle.py
@@ -9,8 +9,9 @@
import os
import time
from multiprocessing import Pool, Barrier
from deepspeed.ops.aio import AsyncIOBuilder
from test_ds_aio_utils import report_results, task_log, task_barrier
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder


def pre_handle(args, tid, read_op):
@@ -20,20 +21,26 @@ def pre_handle(args, tid, read_op):

task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
if args.gpu:
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda')
buffer = torch.empty(num_bytes,
dtype=torch.uint8,
device=get_accelerator().device_name())
else:
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
buffer = get_accelerator().pin_memory(
torch.empty(num_bytes,
dtype=torch.uint8,
device='cpu'))
task_log(
tid,
f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
)

io_parallel = args.io_parallel if args.io_parallel else 1
handle = AsyncIOBuilder().load().aio_handle(args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
io_parallel)
handle = get_accelerator().create_op_builder(AsyncIOBuilder).load().aio_handle(
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
io_parallel)
task_log(tid, f'created deepspeed aio handle')

ctxt = {}
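Device placement above likewise asks the accelerator for its device string
rather than hard-coding 'cuda'. A minimal sketch:

import torch
from deepspeed.accelerator import get_accelerator

# Resolves to 'cuda' on an NVIDIA stack; other accelerators report their own name.
gpu_buffer = torch.empty(1024,
                         dtype=torch.uint8,
                         device=get_accelerator().device_name())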
5 changes: 3 additions & 2 deletions csrc/aio/py_test/validate_async_io.py
@@ -4,5 +4,6 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
from deepspeed.ops.aio import AsyncIOBuilder
assert AsyncIOBuilder().is_compatible()
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import AsyncIOBuilder
assert get_accelerator().create_op_builder(AsyncIOBuilder).is_compatible()
1 change: 1 addition & 0 deletions deepspeed/accelerator
7 changes: 5 additions & 2 deletions deepspeed/comm/comm.py
@@ -31,6 +31,7 @@

from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from deepspeed.accelerator import get_accelerator


class ReduceOp(Enum):
@@ -127,7 +128,7 @@ def log_wrapper(*args, **kwargs):
finally:
if comms_logger.enabled:
# Need to make op blocking for accurate logging
torch.cuda.synchronize()
get_accelerator().synchronize()
# If we're using MPI, we can't simply sync the stream
if cdb.using_mpi:
cdb.barrier()
@@ -595,7 +596,7 @@ def get_global_rank(group=None, group_rank=0):


# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend="nccl",
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True,
@@ -650,6 +651,8 @@ def init_distributed(dist_backend="nccl",
utils.logger.info('Distributed backend already initialized')
else:
assert isinstance(timeout, timedelta)
if dist_backend == None:
dist_backend = get_accelerator().communication_backend_name()
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info(
'Initializing TorchBackend in DeepSpeed with backend {}'.format(
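With the default changed to None, callers no longer need to name 'nccl'
explicitly; the backend is resolved from the active accelerator at call time.
A hypothetical minimal launch:

import deepspeed

# Backend is taken from get_accelerator().communication_backend_name()
# ('nccl' for CUDA) unless one is passed explicitly.
deepspeed.init_distributed()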
35 changes: 19 additions & 16 deletions deepspeed/env_report.py
@@ -2,8 +2,9 @@
import deepspeed
import subprocess
import argparse
from .ops.op_builder import ALL_OPS
from .ops.op_builder.all_ops import ALL_OPS
from .git_version_info import installed_ops, torch_info
from deepspeed.accelerator import get_accelerator

GREEN = '\033[92m'
RED = '\033[91m'
@@ -79,31 +80,33 @@ def debug_report():
def debug_report():
max_dots = 33

hip_version = None
if hasattr(torch.version, 'hip'):
hip_version = torch.version.hip

report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}")),
)
]
if get_accelerator().device_name() == 'cuda':
hip_version = getattr(torch.version, "hip", None)
report.extend([("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}"))])
else:
report.extend([("deepspeed wheel compiled w.",
f"torch {torch_info['version']} ")])

print("DeepSpeed general environment info:")
for name, value in report:
print(name, "." * (max_dots - len(name)), value)
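After this change the torch cuda/hip/nvcc rows are printed only when the
active accelerator reports itself as 'cuda'. A quick way to exercise the new
path from Python (debug_report is the function defined in this file):

from deepspeed.env_report import debug_report

debug_report()  # generic rows always print; CUDA-specific rows are conditional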
2 changes: 1 addition & 1 deletion deepspeed/git_version_info.py
@@ -11,7 +11,7 @@
git_hash = '[none]'
git_branch = '[none]'

from .ops.op_builder import ALL_OPS
from .ops.op_builder.all_ops import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"}
6 changes: 4 additions & 2 deletions deepspeed/ops/adagrad/cpu_adagrad.py
@@ -3,7 +3,8 @@
'''

import torch
from ..op_builder import CPUAdagradBuilder
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder.builder_names import CPUAdagradBuilder
from deepspeed.utils.logging import should_log_le


@@ -24,7 +25,8 @@ def __init__(self,
self.opt_id = DeepSpeedCPUAdagrad.optimizer_id
DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1
self.fp32_optimizer_states = fp32_optimizer_states
self.ds_opt_adagrad = CPUAdagradBuilder().load()
self.ds_opt_adagrad = get_accelerator().create_op_builder(
CPUAdagradBuilder).load()

self.ds_opt_adagrad.create_adagrad(self.opt_id,
lr,
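A hypothetical end-to-end check of the ported optimizer; constructing it now
loads the CPUAdagrad extension through the accelerator rather than importing
the builder directly:

import torch
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad

params = [torch.nn.Parameter(torch.randn(8))]
opt = DeepSpeedCPUAdagrad(params, lr=0.01)  # triggers CPUAdagradBuilder load
params[0].grad = torch.randn(8)
opt.step()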
