Skip to content

Commit

Permalink
DeepNVMe GDS (#5852)
Browse files Browse the repository at this point in the history
PR for the GDS AIO code.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Ubuntu <deepspeed@H100-VM2.shlnn55tgwve1eacvp21ie45dg.jx.internal.cloudapp.net>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
5 people authored Aug 19, 2024
1 parent c2e3a70 commit 5f0d177
Show file tree
Hide file tree
Showing 43 changed files with 2,144 additions and 761 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
38 changes: 38 additions & 0 deletions csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepspeed_aio_op_desc.h"

using namespace std;

// Base descriptor for one asynchronous file I/O operation (read or write)
// whose transfer is fanned out across `num_threads` worker threads.
//
// read_op        - true for a file->tensor read, false for a tensor->file write.
// buffer         - tensor supplying (write) or receiving (read) the data.
// fd             - open file descriptor of the target file.
// filename       - path of the target file (kept for diagnostics/validation).
// file_num_bytes - total number of bytes to transfer for the whole file.
// num_threads    - worker threads sharing the transfer, one equal slice each.
// validate       - if true, derived classes re-check the completed transfer.
io_op_desc_t::io_op_desc_t(const bool read_op,
                           const torch::Tensor& buffer,
                           const int fd,
                           const char* filename,
                           const long long int file_num_bytes,
                           const int num_threads,
                           const bool validate)
    : _read_op(read_op),
      _buffer(buffer),
      _fd(fd),
      _filename(filename),
      _file_num_bytes(file_num_bytes),
      _num_threads(num_threads),
      // Integer division: the remainder (file_num_bytes % num_threads) is not
      // assigned to any slice here. NOTE(review): assumes file_num_bytes is
      // divisible by num_threads, or that a derived run() handles the tail --
      // confirm with callers. Also assumes num_threads > 0 (division by zero).
      _num_bytes_per_thread(file_num_bytes / num_threads),
      _validate(validate)
{
}

// Raw byte pointer to the contiguous staging buffer targeted by the I/O engine.
char* io_op_desc_t::data_ptr() const
{
    return static_cast<char*>(_contiguous_buffer.data_ptr());
}

// Post-I/O completion hook; the base implementation has no work to do.
void io_op_desc_t::finish() {}

// Optional post-transfer verification hook; no-op in the base class.
void io_op_desc_t::validate() {}

// Per-thread transfer entry point; the base implementation performs no I/O.
// Derived classes issue this thread's slice of the transfer here.
// tid        - worker thread index, expected in [0, _num_threads).
// aio_ctxt   - per-thread AIO context used to submit operations.
// aio_config - AIO configuration (queue depth, overlap mode, ...).
void io_op_desc_t::run(const int tid,
                       std::unique_ptr<aio_context>& aio_ctxt,
                       deepspeed_aio_config_t* aio_config)
{
}
41 changes: 41 additions & 0 deletions csrc/aio/py_lib/deepspeed_aio_op_desc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#ifndef _IO_OP_DESC_T_
#define _IO_OP_DESC_T_
#include <memory>
#include <queue>
#include "deepspeed_py_aio.h"

// Abstract descriptor for one asynchronous file I/O operation whose transfer
// is split evenly across `_num_threads` worker threads. Derived classes
// (e.g. cpu_op_desc_t) implement the actual per-thread transfer in run().
struct io_op_desc_t {
    const bool _read_op;                  // true: file -> tensor; false: tensor -> file
    torch::Tensor _buffer;                // user tensor being read into / written from
    int _fd;                              // open descriptor of the target file
    const std::string _filename;          // target path, kept for diagnostics/validation
    const long long int _file_num_bytes;  // total bytes transferred for the file
    const int _num_threads;               // worker threads sharing the transfer
    // Widened from `int`: _file_num_bytes is long long, so the per-thread slice
    // (file_num_bytes / num_threads) can exceed INT_MAX for very large files,
    // which would silently truncate if stored as int.
    const long long int _num_bytes_per_thread;
    torch::Tensor _contiguous_buffer;     // contiguous staging buffer targeted by I/O
    const bool _validate;                 // re-check the transfer after completion

    io_op_desc_t(const bool read_op,
                 const torch::Tensor& buffer,
                 const int fd,
                 const char* filename,
                 const long long int file_num_bytes,
                 const int num_threads,
                 const bool validate);

    // Polymorphic base: a virtual destructor makes deletion through an
    // io_op_desc_t* (the base type used by the worker queue) well-defined.
    virtual ~io_op_desc_t() = default;

    // Execute this worker thread's slice of the transfer.
    virtual void run(const int tid,
                     std::unique_ptr<aio_context>& aio_ctxt,
                     deepspeed_aio_config_t* aio_config);

    // Raw byte pointer to the contiguous staging buffer.
    virtual char* data_ptr() const;

    // Optional post-transfer verification hook.
    virtual void validate();

    // Post-I/O completion hook (e.g. copy staged data back to the device).
    virtual void finish();
};
#endif // _IO_OP_DESC_T_
55 changes: 1 addition & 54 deletions csrc/aio/py_lib/deepspeed_aio_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,8 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.

#include "deepspeed_aio_thread.h"

#if defined(__ENABLE_CANN__)
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/framework/utils/UtilForOpAdapter.h"
#endif

using namespace std;

io_op_desc_t::io_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const long long int num_bytes,
const bool validate)
: _read_op(read_op),
_buffer(buffer),
_fd(fd),
_filename(filename),
_num_bytes(num_bytes),
_validate(validate)
{
_cpu_buffer = (_buffer.is_cuda() || _buffer.is_xpu()
#if defined(__ENABLE_CANN__)
|| torch_npu::utils::is_npu(_buffer)
#endif
)
? _buffer.to(torch::kCPU).pin_memory()
: _buffer;
_contiguous_buffer = _cpu_buffer.contiguous();
}

char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }

void io_op_desc_t::fini()
{
if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
if (_read_op && _buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
#if defined(__ENABLE_CANN__)
if (_read_op && torch_npu::utils::is_npu(_buffer)) {
auto device = at::Device("npu:0");
_buffer.copy_(_cpu_buffer.to(device));
}
#endif
}

deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config)
: _tid(tid),
_aio_config(aio_config),
Expand All @@ -79,18 +37,7 @@ void deepspeed_aio_thread_t::run()
}

if (next_io_op) {
const auto base_offset = next_io_op->_num_bytes * _tid;

std::unique_ptr<io_xfer_ctxt> xfer_ctxt(new io_xfer_ctxt(
next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr()));

if (_aio_config._overlap_events) {
do_aio_operation_overlap(
next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
} else {
do_aio_operation_sequential(
next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
}
next_io_op->run(_tid, _aio_ctxt, &_aio_config);

{
std::lock_guard<std::mutex> lock(_complete_sync._mutex);
Expand Down
23 changes: 1 addition & 22 deletions csrc/aio/py_lib/deepspeed_aio_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include <condition_variable>
#include <memory>
#include <queue>
#include "deepspeed_py_aio.h"

struct io_op_desc_t {
const bool _read_op;
torch::Tensor _buffer;
int _fd;
const std::string _filename;
const long long int _num_bytes;
torch::Tensor _cpu_buffer;
torch::Tensor _contiguous_buffer;
const bool _validate;

io_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const long long int num_bytes,
const bool validate);

char* data_ptr() const;
void fini();
};
#include "deepspeed_cpu_op.h"

struct thread_sync_t {
std::mutex _mutex;
Expand Down
72 changes: 72 additions & 0 deletions csrc/aio/py_lib/deepspeed_cpu_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepspeed_cpu_op.h"

using namespace std;

// Descriptor for a CPU-staged AIO file operation. If `buffer` is not already
// page-locked host memory, a pinned CPU bounce buffer is allocated (read) or
// populated (write) so the AIO engine always targets pinned DRAM.
cpu_op_desc_t::cpu_op_desc_t(const bool read_op,
                             const torch::Tensor& buffer,
                             const int fd,
                             const char* filename,
                             const long long int file_num_bytes,
                             const int num_threads,
                             const bool validate)
    : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate),
      _cpu_buffer(buffer)
{
    // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory.
    _use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned());
    if (_use_bounce_buffer) {
        if (_read_op) {
            auto options = torch::TensorOptions()
                               .dtype(_buffer.dtype())
                               .layout(_buffer.layout())
                               .device(torch::kCPU);
            // Size the bounce buffer by element count (same shape as _buffer),
            // not by nbytes: passing nbytes as the element count over-allocates
            // by sizeof(dtype)x and yields a tensor whose shape mismatches
            // _buffer for the copy-back in finish().
            _cpu_buffer = torch::empty(_buffer.sizes(), options).pin_memory();
        } else {
            _cpu_buffer = _buffer.to(torch::kCPU).pin_memory();
        }
    }
    // I/O always targets a contiguous view of the staging tensor.
    _contiguous_buffer = _cpu_buffer.contiguous();
}

// Raw byte pointer into the contiguous pinned staging buffer used by the AIO engine.
char* cpu_op_desc_t::data_ptr() const
{
    return static_cast<char*>(_contiguous_buffer.data_ptr());
}

// Completion hook: for read operations, copy data staged in the pinned CPU
// buffer back to the caller's device tensor (CUDA, XPU, or NPU).
// NOTE(review): an unpinned *CPU* destination also goes through the bounce
// buffer (see ctor: _use_bounce_buffer), but no copy-back happens here for
// that case -- confirm whether reads into unpinned CPU tensors are supported.
void cpu_op_desc_t::finish()
{
    if (_read_op) {
        if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
        if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
#if defined(__ENABLE_CANN__)
        if (torch_npu::utils::is_npu(_buffer)) {
            // NOTE(review): hard-codes device "npu:0"; a buffer resident on
            // another NPU would be copied to the wrong device -- confirm.
            auto device = at::Device("npu:0");
            _buffer.copy_(_cpu_buffer.to(device));
        }
#endif
    }
}

// Re-check the completed transfer against the file contents.
void cpu_op_desc_t::validate()
{
    const char* path = _filename.c_str();
    validate_aio_operation(_read_op, path, data_ptr(), _file_num_bytes);
}

void cpu_op_desc_t::run(const int tid,
std::unique_ptr<aio_context>& aio_ctxt,
deepspeed_aio_config_t* aio_config)
{
assert(tid < _num_threads);
const auto base_offset = _num_bytes_per_thread * tid;

std::unique_ptr<io_xfer_ctxt> xfer_ctxt(
new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr()));

if (aio_config->_overlap_events) {
do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
} else {
do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
}
}
31 changes: 31 additions & 0 deletions csrc/aio/py_lib/deepspeed_cpu_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <memory>
#include <queue>
#include "deepspeed_aio_op_desc.h"

struct cpu_op_desc_t : io_op_desc_t {
torch::Tensor _cpu_buffer;
bool _use_bounce_buffer;

cpu_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const bool validate);

void run(const int tid,
std::unique_ptr<aio_context>& aio_ctxt,
deepspeed_aio_config_t* aio_config);

char* data_ptr() const;

void validate();

void finish();
};
3 changes: 0 additions & 3 deletions csrc/aio/py_lib/deepspeed_py_aio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// DeepSpeed Team

/*
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/

Expand Down
5 changes: 1 addition & 4 deletions csrc/aio/py_lib/deepspeed_py_aio.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
// DeepSpeed Team

/*
Copyright 2020 The Microsoft DeepSpeed Team
Licensed under the MIT license.
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
Functionality for swapping tensors to/from (NVMe) storage devices.
*/

#include <deepspeed_aio_common.h>
Expand Down
Loading

0 comments on commit 5f0d177

Please sign in to comment.