Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIO CPU Locked Tensor #6592

Merged
merged 26 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1413d69
fixed functionality of cpu locked tensor
jomayeri Sep 30, 2024
87a7c69
Merge branch 'master' into jomayeri/aio-locked-tensor
jomayeri Sep 30, 2024
c13bb10
enabling cpu locked in unittests, and fixing compilation errors
jomayeri Oct 1, 2024
248cec4
Merge branch 'jomayeri/aio-locked-tensor' of github.com:microsoft/Dee…
jomayeri Oct 1, 2024
b909702
passing gds tests
jomayeri Oct 2, 2024
b1ee711
renaming all instances of num_threads
jomayeri Oct 2, 2024
ada1b83
updating function names to match
jomayeri Oct 2, 2024
1cb88ce
fix formatting
jomayeri Oct 2, 2024
f5528da
variable name change to fix compilation
jomayeri Oct 2, 2024
f576d29
formatting
jomayeri Oct 2, 2024
5a47bf3
update references in tutorial
jomayeri Oct 3, 2024
fe93fdc
Merge branch 'master' into jomayeri/aio-locked-tensor
jomayeri Oct 7, 2024
b2866cb
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 7, 2024
9c93d2c
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 8, 2024
884c0fd
async_io operator for CPU accelerator
tjruwase Oct 9, 2024
a5ba643
Merge branch 'jomayeri/aio-locked-tensor' of github.com:microsoft/Dee…
tjruwase Oct 9, 2024
ea0e45b
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 9, 2024
98988cd
Formatting; Use int64_t
tjruwase Oct 9, 2024
b30cda5
Merge branch 'jomayeri/aio-locked-tensor' of github.com:microsoft/Dee…
tjruwase Oct 9, 2024
90e25da
Skip fp16 tests on CPU
tjruwase Oct 9, 2024
60ae3e0
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 9, 2024
a008d4c
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 9, 2024
59d8dfa
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 9, 2024
8a52388
Add Cuda 12.6
tjruwase Oct 9, 2024
d1afe4c
Merge branch 'jomayeri/aio-locked-tensor' of github.com:microsoft/Dee…
tjruwase Oct 9, 2024
0182208
Merge branch 'master' into jomayeri/aio-locked-tensor
tjruwase Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/aio/common/deepspeed_aio_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ int get_file_size(const char* filename, long long int& size)
return 0;
}

void* ds_page_aligned_alloc(const size_t size, const bool lock)
void* ds_page_aligned_alloc(const long long int size, const bool lock)
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
{
void* ptr;
int retval;
Expand Down
2 changes: 1 addition & 1 deletion csrc/aio/common/deepspeed_aio_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ struct io_prep_generator {
int prep_iocbs(const int n_iocbs, std::vector<struct iocb*>* iocbs);
};

void* ds_page_aligned_alloc(const size_t size, const bool lock = false);
void* ds_page_aligned_alloc(const long long int size, const bool lock = false);

int get_file_size(const char* filename, long long int& size);
8 changes: 5 additions & 3 deletions csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@ using namespace std;

io_op_desc_t::io_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate)
: _read_op(read_op),
_buffer(buffer),
_is_managed(is_managed),
_fd(fd),
_filename(filename),
_file_num_bytes(file_num_bytes),
_num_threads(num_threads),
_num_bytes_per_thread(file_num_bytes / num_threads),
_intra_op_parallelism(intra_op_parallelism),
_num_bytes_per_thread(file_num_bytes / intra_op_parallelism),
_validate(validate)
{
}
Expand Down
6 changes: 4 additions & 2 deletions csrc/aio/py_lib/deepspeed_aio_op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@
struct io_op_desc_t {
const bool _read_op;
torch::Tensor _buffer;
const bool _is_managed;
int _fd;
const std::string _filename;
const long long int _file_num_bytes;
const int _num_threads;
const int _intra_op_parallelism;
const long long int _num_bytes_per_thread;
torch::Tensor _contiguous_buffer;
const bool _validate;

io_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate);

virtual void run(const int tid,
Expand Down
24 changes: 17 additions & 7 deletions csrc/aio/py_lib/deepspeed_cpu_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,32 @@ using namespace std;

cpu_op_desc_t::cpu_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate)
: io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate),
: io_op_desc_t(read_op,
buffer,
is_managed,
fd,
filename,
file_num_bytes,
intra_op_parallelism,
validate),
_cpu_buffer(buffer)
{
// Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory.
_use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned());
_use_bounce_buffer = !(_buffer.is_cpu() && (_buffer.is_pinned() || _is_managed));
if (_use_bounce_buffer) {
if (_read_op) {
auto options = torch::TensorOptions()
.dtype(_buffer.dtype())
.layout(_buffer.layout())
.device(torch::kCPU);
_cpu_buffer = torch::empty(_buffer.nbytes(), options).pin_memory();
.device(torch::kCPU)
.requires_grad(false);
_cpu_buffer = torch::empty(_buffer.numel(), options).pin_memory();
} else {
_cpu_buffer = _buffer.to(torch::kCPU).pin_memory();
}
Expand All @@ -37,9 +46,10 @@ char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_pt

void cpu_op_desc_t::finish()
{
if (_read_op) {
if (_read_op && _use_bounce_buffer) {
if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
if (_buffer.is_cpu()) { _buffer.copy_(_cpu_buffer); }
#if defined(__ENABLE_CANN__)
if (torch_npu::utils::is_npu(_buffer)) {
auto device = at::Device("npu:0");
Expand All @@ -58,7 +68,7 @@ void cpu_op_desc_t::run(const int tid,
std::unique_ptr<aio_context>& aio_ctxt,
deepspeed_aio_config_t* aio_config)
{
assert(tid < _num_threads);
assert(tid < _intra_op_parallelism);
const auto base_offset = _num_bytes_per_thread * tid;

std::unique_ptr<io_xfer_ctxt> xfer_ctxt(
Expand Down
3 changes: 2 additions & 1 deletion csrc/aio/py_lib/deepspeed_cpu_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ struct cpu_op_desc_t : io_op_desc_t {

cpu_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate);

void run(const int tid,
Expand Down
15 changes: 12 additions & 3 deletions csrc/aio/py_lib/deepspeed_pin_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@ deepspeed_pin_tensor_t::~deepspeed_pin_tensor_t()
_locked_tensors.clear();
}

torch::Tensor deepspeed_pin_tensor_t::alloc(const size_t num_elem, const at::ScalarType& elem_type)
torch::Tensor deepspeed_pin_tensor_t::alloc(const long long int num_elem,
const at::ScalarType& elem_type)
{
const auto num_bytes = num_elem * elementSize(elem_type);
auto pinned_buffer = ds_page_aligned_alloc(num_bytes, true);
assert(nullptr != pinned_buffer);

_locked_tensors[pinned_buffer] = num_bytes;

auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU);
auto options = torch::TensorOptions().dtype(elem_type).device(torch::kCPU).requires_grad(false);

return at::from_blob(pinned_buffer, static_cast<long int>(num_bytes), options);
return at::from_blob(pinned_buffer, static_cast<long long int>(num_elem), options);
tjruwase marked this conversation as resolved.
Show resolved Hide resolved
}

bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor)
Expand All @@ -43,3 +44,11 @@ bool deepspeed_pin_tensor_t::free(torch::Tensor& locked_tensor)

return false;
}

// Returns true iff `buffer` is a CPU tensor whose backing storage was
// allocated (and page-locked) by this manager via alloc(), i.e. its data
// pointer is present in _locked_tensors. Non-CPU tensors are never managed.
bool deepspeed_pin_tensor_t::is_managed(const torch::Tensor& buffer)
{
    // Cheap device guard first: only CPU tensors can appear in the registry,
    // so skip the pointer lookup entirely for device tensors.
    if (!buffer.is_cpu()) { return false; }
    // Managed iff the tensor's data pointer matches a tracked allocation.
    // NOTE(review): this compares data_ptr() directly, so a view with a
    // nonzero storage offset into a locked allocation would NOT be
    // recognized as managed — confirm callers always pass whole tensors.
    return _locked_tensors.find(buffer.data_ptr()) != _locked_tensors.end();
}
6 changes: 4 additions & 2 deletions csrc/aio/py_lib/deepspeed_pin_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ Functionality for managing CPU tensors occupying page-locked memory.
#include "deepspeed_py_aio.h"

struct deepspeed_pin_tensor_t {
std::map<void*, size_t> _locked_tensors;
std::map<void*, long long int> _locked_tensors;

deepspeed_pin_tensor_t() = default;

~deepspeed_pin_tensor_t();

torch::Tensor alloc(const size_t num_elem, const at::ScalarType& elem_type);
torch::Tensor alloc(const long long num_elem, const at::ScalarType& elem_type);

bool free(torch::Tensor& locked_tensor);

bool is_managed(const torch::Tensor& buffer);
};
8 changes: 6 additions & 2 deletions csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads)
: deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
const int intra_op_parallelism)
: deepspeed_io_handle_t(block_size,
queue_depth,
single_submit,
overlap_events,
intra_op_parallelism)
{
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/aio/py_lib/deepspeed_py_aio_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct deepspeed_aio_handle_t : deepspeed_io_handle_t {
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads);
const int intra_op_parallelism);

~deepspeed_aio_handle_t();
};
23 changes: 12 additions & 11 deletions csrc/aio/py_lib/deepspeed_py_io_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads)
const int intra_op_parallelism)
: _aio_ctxt(new aio_context(block_size, queue_depth)),
_single_submit(single_submit),
_overlap_events(overlap_events),
_num_threads(num_threads),
_intra_op_parallelism(intra_op_parallelism),
_aio_config(block_size, queue_depth, single_submit, overlap_events, false),
_num_pending_ops(0),
_pinned_tensor_mgr(new deepspeed_pin_tensor_t())
{
for (auto i = 0; i < num_threads; ++i) {
for (auto i = 0; i < intra_op_parallelism; ++i) {
_thread_contexts.push_back(std::make_shared<deepspeed_aio_thread_t>(i, _aio_config));
}

Expand Down Expand Up @@ -56,7 +56,7 @@ const bool deepspeed_io_handle_t::get_single_submit() const { return _single_sub

const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; }

const int deepspeed_io_handle_t::get_thread_count() const { return _num_threads; }
const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; }

int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate)
{
Expand Down Expand Up @@ -192,9 +192,9 @@ bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op,
const long long int num_bytes)
{
const auto op_string = read_op ? "Read" : "Write";
if (num_bytes % get_thread_count()) {
if (num_bytes % get_intra_op_parallelism()) {
std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
<< " not divisible by thread count = " << get_thread_count() << std::endl;
<< " not divisible by thread count = " << get_intra_op_parallelism() << std::endl;
return false;
}

Expand All @@ -209,8 +209,9 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_create_io_op_desc(
const long long int file_num_bytes,
const bool validate)
{
bool is_managed = _pinned_tensor_mgr->is_managed(buffer);
return std::make_shared<cpu_op_desc_t>(
read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate);
read_op, buffer, is_managed, fd, filename, file_num_bytes, _intra_op_parallelism, validate);
}

int deepspeed_io_handle_t::pread(const torch::Tensor& buffer,
Expand All @@ -229,8 +230,8 @@ int deepspeed_io_handle_t::pread(const torch::Tensor& buffer,
std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes
<< " != " << num_file_bytes << std::endl;
}
assert(static_cast<long long int>(buffer.nbytes()) == num_file_bytes);
assert((num_file_bytes % _num_threads) == 0);
assert(buffer_bytes == num_file_bytes);
assert((num_file_bytes % _intra_op_parallelism) == 0);

if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; }

Expand All @@ -252,7 +253,7 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer,
const bool async)
{
const auto num_write_bytes = static_cast<long long int>(buffer.nbytes());
assert((num_write_bytes % _num_threads) == 0);
assert((num_write_bytes % _intra_op_parallelism) == 0);

if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }

Expand Down Expand Up @@ -288,7 +289,7 @@ int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char*
return pwrite(buffer, filename, false, true);
}

at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const size_t num_elem,
at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const long long int num_elem,
const torch::Tensor& example_tensor)
{
return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type());
Expand Down
9 changes: 5 additions & 4 deletions csrc/aio/py_lib/deepspeed_py_io_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct deepspeed_io_handle_t {
std::unique_ptr<struct aio_context> _aio_ctxt;
const bool _single_submit;
const bool _overlap_events;
const int _num_threads;
const int _intra_op_parallelism;
deepspeed_aio_config_t _aio_config;

std::vector<std::shared_ptr<struct deepspeed_aio_thread_t>> _thread_contexts;
Expand All @@ -28,15 +28,15 @@ struct deepspeed_io_handle_t {
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads);
const int intra_op_parallelism);

virtual ~deepspeed_io_handle_t() = 0;

const int get_block_size() const;
const int get_queue_depth() const;
const bool get_single_submit() const;
const bool get_overlap_events() const;
const int get_thread_count() const;
const int get_intra_op_parallelism() const;

int read(torch::Tensor& buffer, const char* filename, const bool validate);

Expand All @@ -61,7 +61,8 @@ struct deepspeed_io_handle_t {
int async_pwrite(const torch::Tensor& buffer, const char* filename);

// TODO: Make API's args to be shape and dtype.
torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);
torch::Tensor new_cpu_locked_tensor(const long long int num_elem,
const torch::Tensor& example_tensor);

bool free_cpu_locked_tensor(torch::Tensor&);

Expand Down
4 changes: 2 additions & 2 deletions csrc/aio/py_lib/py_ds_aio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"queue_depth"_a = 128,
"single_submit"_a = false,
"overlap_events"_a = false,
"num_threads"_a = 1)
"intra_op_parallelism"_a = 1)
tjruwase marked this conversation as resolved.
Show resolved Hide resolved

.def("get_block_size", &deepspeed_aio_handle_t::get_block_size)
.def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth)
.def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit)
.def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events)
.def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count)
.def("get_intra_op_parallelism", &deepspeed_aio_handle_t::get_intra_op_parallelism)

.def("read",
&deepspeed_aio_handle_t::read,
Expand Down
14 changes: 11 additions & 3 deletions csrc/gds/py_lib/deepspeed_gds_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,20 @@ void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer)

gds_op_desc_t::gds_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate)
: io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate)
: io_op_desc_t(read_op,
buffer,
is_managed,
fd,
filename,
file_num_bytes,
intra_op_parallelism,
validate)
{
_contiguous_buffer = _buffer.contiguous();
const int64_t device = _buffer.get_device();
Expand All @@ -123,7 +131,7 @@ void gds_op_desc_t::run(const int tid,
std::unique_ptr<aio_context>& aio_ctxt,
deepspeed_aio_config_t* aio_config)
{
assert(tid < _num_threads);
assert(tid < _intra_op_parallelism);
check_cudaruntimecall(cudaSetDevice(_buffer.get_device()));
int64_t buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr;
const auto file_offset = _num_bytes_per_thread * tid;
Expand Down
3 changes: 2 additions & 1 deletion csrc/gds/py_lib/deepspeed_gds_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ struct gds_op_desc_t : io_op_desc_t {

gds_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const bool is_managed,
const int fd,
const char* filename,
const long long int file_num_bytes,
const int num_threads,
const int intra_op_parallelism,
const bool validate);

void run(const int tid,
Expand Down
Loading
Loading