Add Compressedbackend for Onebit optimizers #5473

Merged: 19 commits, Jun 5, 2024
Changes from 1 commit
Commit: format
Liangliang-Ma committed Apr 26, 2024
commit c88355546cc79afd2c3c76f1899d82da2fb33e08
45 changes: 19 additions & 26 deletions csrc/xpu/packbits/packing.cpp
@@ -3,31 +3,25 @@

 // DeepSpeed Team

-#include <ipex.h>
 #include <torch/extension.h>
 #include <iostream>
 #include <sycl/sycl.hpp>
+#include <ipex.h>

 using namespace sycl;
 using namespace xpu;

-void packbitskernel(const bool* input,
-                    uint8_t * output,
-                    const int input_size,
-                    id<1> item_ct1)
+void packbitskernel(const bool* input, uint8_t* output, const int input_size, id<1> item_ct1)
 {
     int i = item_ct1;
-    for (int j = 0; j < 8; ++j)
-    {
+    for (int j = 0; j < 8; ++j) {
         int k = i * 8 + j;
         int bit = k < input_size && input[k] != 0;
         output[i] |= bit << (7 - j);
     }
 }

-void unpackbitskernel(const uint8_t* input,
-                      float * output,
-                      id<1> item_ct1)
+void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
 {
     int i = item_ct1;
     output[i] = (input[i / 8] >> (7 - i % 8)) & 1;
@@ -43,7 +37,7 @@ sycl::queue get_current_queue(at::Device device)

 at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
 {
-    /*
+    /*
Review comment (Collaborator):

    @Liangliang-Ma the function documentation needs to be moved to line 39 right before the function def line.

Reply (Contributor, Author):

    Updated.

     pack bool tensor into uint8 tensor. Every eight bool elements get packed into one uint8
     Arguments:
         tensor: A bool tensor that get packed.
@@ -57,21 +51,21 @@ at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
     auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
     at::Tensor packed = torch::empty({packed_size}, unit8_options);

-    bool* input = (bool *)tensor.data_ptr();
-    uint8_t * output = (uint8_t *)packed.data_ptr();
+    bool* input = (bool*)tensor.data_ptr();
+    uint8_t* output = (uint8_t*)packed.data_ptr();

-    auto event = q.submit([&](sycl::handler&cgh) {
-        cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
-            packbitskernel(input, output, input_size, item_ct1);
-            });
-        });
+    auto event = q.submit([&](sycl::handler& cgh) {
+        cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
+            packbitskernel(input, output, input_size, item_ct1);
+        });
+    });

     return packed;
 }

 at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
 {
-    /*
+    /*
     unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float
     Arguments:
         tensor: A uint8 tensor that get unpacked.
@@ -82,16 +76,15 @@ at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
     sycl::queue q = get_current_queue(device);

     auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
-    at::Tensor unpacked = torch::empty({input_size*8}, float_options);
+    at::Tensor unpacked = torch::empty({input_size * 8}, float_options);

-    uint8_t* input = (uint8_t *)tensor.data_ptr();
-    float * output = (float *)unpacked.data_ptr();
+    uint8_t* input = (uint8_t*)tensor.data_ptr();
+    float* output = (float*)unpacked.data_ptr();

-    auto event = q.submit([&](sycl::handler&cgh) {
-        cgh.parallel_for<>(range(input_size*8), [=](id<1> item_ct1) {
-            unpackbitskernel(input, output, item_ct1);
-            });
-        });
+    auto event = q.submit([&](sycl::handler& cgh) {
+        cgh.parallel_for<>(range(input_size * 8),
+                           [=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
+    });

     return unpacked;
 }
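For reference, the two SYCL kernels implement MSB-first bit packing: byte i of the output collects eight consecutive input elements, earliest element in the highest bit, and unpacking reverses the indexing. A minimal PyTorch sketch of the same round trip (an illustration of the semantics only, not the XPU code path; the helper names are made up):

    import torch

    def packbits_ref(bits: torch.Tensor) -> torch.Tensor:
        # bits: 1-D bool tensor, zero-padded up to a multiple of 8
        n = bits.numel()
        padded = torch.zeros((n + 7) // 8 * 8, dtype=torch.uint8)
        padded[:n] = bits.to(torch.uint8)
        # weights 128..1 put element j of each group of 8 into bit (7 - j),
        # matching output[i] |= bit << (7 - j) in packbitskernel
        weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
        return (padded.view(-1, 8) * weights).sum(dim=1, dtype=torch.uint8)

    def unpackbits_ref(packed: torch.Tensor) -> torch.Tensor:
        # inverse indexing: (input[i / 8] >> (7 - i % 8)) & 1, emitted as float
        shifts = torch.arange(7, -1, -1)
        return ((packed.to(torch.int64).unsqueeze(1) >> shifts) & 1).flatten().float()

    bits = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0], dtype=torch.bool)
    assert packbits_ref(bits).item() == 0b10110010
    assert torch.equal(unpackbits_ref(packbits_ref(bits)).bool(), bits)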
11 changes: 6 additions & 5 deletions deepspeed/runtime/comm/compressed.py
@@ -46,7 +46,7 @@ def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
             dist.send(sendbuf, group=group, dst=root)

     def pack(self, buffer, size):
-        buffer = buffer.ravel().sign_().add_(1).bool() # convert buffer to bool, element set to True if its value >=0
+        buffer = buffer.ravel().sign_().add_(1).bool()  # convert buffer to bool, element set to True if its value >=0
         packed = self.packer.packbits(buffer, buffer.numel(), self.rank)
         return packed.reshape(size, -1)
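The chained ops in pack encode the sign decision: sign_() maps each element to {-1, 0, 1}, add_(1) to {0, 1, 2}, and bool() to {False, True}, so exactly the elements >= 0 come out True, as the comment says. A tiny worked example (the values are made up):

    import torch

    x = torch.tensor([0.7, -1.2, 0.0, 3.4])
    bits = x.ravel().sign_().add_(1).bool()
    # sign_: [1., -1., 0., 1.] -> add_(1): [2., 0., 1., 2.] -> bool: [True, False, True, True]
    print(bits)  # tensor([ True, False,  True,  True])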

@@ -72,16 +72,16 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
         worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

         sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)
-
-
+
         recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
                                    dtype=sign_list_packed_tmp[0].dtype,
                                    device=sign_list_packed_tmp.device)

         sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

         recvbuf_scale = [
-            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name()) for _ in range(self.size)
+            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
+            for _ in range(self.size)
         ]

         # communication phase 1
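The worker_error.set_(...) line at the top of this hunk is the error-feedback step of 1-bit compression: sign().add_(1).bool().float().add_(-0.5).mul_(2.0) reconstructs a {-1.0, +1.0} tensor, and the residual against the scaled reconstruction is stored so it can be folded back into the next step's buffer rather than lost. A standalone sketch of that loop (illustrative only; the scale formula is an assumption, and the real method interleaves the communication phases shown in this diff):

    import torch

    def one_bit_compress_step(grad: torch.Tensor, error: torch.Tensor):
        compensated = grad + error  # fold in the residual from the previous step
        scale = compensated.norm() / compensated.numel() ** 0.5  # assumed per-tensor scale
        signs = compensated.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)  # {-1., +1.}
        error = compensated - scale * signs  # what the compression lost this step
        return scale, signs, error

    error = torch.zeros(4)
    for grad in (torch.tensor([0.5, -2.0, 0.1, 1.0]),
                 torch.tensor([0.4, -1.0, 0.2, 0.9])):
        scale, signs, error = one_bit_compress_step(grad, error)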
@@ -126,7 +126,8 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
         flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

         buffer_m.data.copy_(
-            self.unpack(flattened_recvbuf_sign_server, self.size, torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)
+            self.unpack(flattened_recvbuf_sign_server, self.size,
+                        torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

         if original_size != worker_error_size:
             buffer_m = buffer_m[0:original_size]
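On the receive side, the reformatted copy_ call above is the decompression step: the packed server signs are unpacked back to per-element values and multiplied by the matching server scales to rebuild the averaged tensor. Conceptually (a sketch with the {0, 1} -> {-1, +1} mapping written out explicitly; the backend's unpack helper may fold parts of this in):

    import torch

    # hypothetical inputs: bit values as produced by unpackbits (0.0 or 1.0),
    # plus one scale per chunk
    unpacked_bits = torch.tensor([1.0, 0.0, 1.0, 1.0])  # 1 = non-negative, 0 = negative
    server_scale = torch.tensor(0.8)

    signs = unpacked_bits.mul(2.0).sub(1.0)  # {0, 1} -> {-1, +1}
    recovered = signs * server_scale         # tensor([ 0.8, -0.8,  0.8,  0.8])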
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/onebit/adam.py
@@ -78,7 +78,7 @@ def __init__(self,
         self.deepspeed = deepspeed
         self.adam_freeze_key = False
         self.initialize = False
-        self.freeze_step = 5
+        self.freeze_step = freeze_step
         self.cuda_aware = cuda_aware
         self.using_pipeline = False

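Unlike the formatting-only hunks, this last change is a functional fix: __init__ previously hardcoded self.freeze_step = 5, silently ignoring the freeze_step argument that sets how many warmup steps run as plain Adam before 1-bit compressed communication begins. With the fix, a user-supplied value is honored, e.g. via a DeepSpeed config (a sketch; the values are illustrative, and "compressed" as the comm_backend_name for this PR's new backend is an assumption):

    # Python dict form of a DeepSpeed JSON config
    ds_config = {
        "train_batch_size": 8,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {
                "lr": 1e-4,
                "freeze_step": 400,  # was effectively pinned to 5 before the fix
                "cuda_aware": False,
                "comm_backend_name": "compressed",  # assumed name for the new backend
            },
        },
    }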