7 changes: 6 additions & 1 deletion README.md
@@ -31,7 +31,6 @@ DeepEP (AMD version) depends on [rocSHMEM](https://github.com/ROCm/rocSHMEM). Pl
git clone https://github.com/ROCm/DeepEP
cd DeepEP


# To use DeepEP with MPI, please proceed with these commands
# Export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md)
export OMPI_DIR=<ompi_dir>
@@ -41,6 +40,12 @@ python3 setup.py --variant rocm build develop --user
# Then install DeepEP using this command
python3 setup.py --variant rocm --disable-mpi build develop --user



# To use DeepEP without MPI, please make sure rocSHMEM was built with this flag -DUSE_EXTERNAL_MPI=OFF
# Then install DeepEP using this command
python3 setup.py --variant rocm --disable-mpi build develop

# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
4 changes: 2 additions & 2 deletions csrc/config.hpp
@@ -7,13 +7,13 @@
namespace deep_ep {

template <typename dtype_t>
dtype_t cell_div(dtype_t a, dtype_t b) {
dtype_t ceil_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}

template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
return cell_div<dtype_t>(a, b) * b;
return ceil_div<dtype_t>(a, b) * b;
}

struct Config {
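For reference, the renamed helper keeps the same integer ceiling-division behaviour that `align` builds on. A minimal standalone sketch of the two templates above, with illustrative values in the comments:

#include <cassert>

template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
    // Round a / b up to the next integer, e.g. ceil_div(10, 4) == 3.
    return (a + b - 1) / b;
}

template <typename dtype_t>
dtype_t align(dtype_t a, dtype_t b) {
    // Round a up to the next multiple of b, e.g. align(10, 4) == 12.
    return ceil_div<dtype_t>(a, b) * b;
}

int main() {
    assert(ceil_div(10, 4) == 3);
    assert(align(10, 4) == 12);
    assert(align(16, 4) == 16);  // Values already aligned are unchanged.
    return 0;
}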
1,264 changes: 879 additions & 385 deletions csrc/deep_ep.cpp

Large diffs are not rendered by default.

237 changes: 173 additions & 64 deletions csrc/deep_ep.hpp
@@ -8,6 +8,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <torch/types.h>

#include <tuple>
#include <vector>

@@ -35,21 +36,24 @@ struct Buffer {
void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** buffer_ptrs_gpu = nullptr;

void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
void** nvl_buffer_ptrs_gpu = nullptr;
// NVSHMEM Buffer
int64_t num_rdma_bytes;
void* rdma_buffer_ptr = nullptr;

// Shrink mode buffer
bool enable_shrink = false;
int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr;

// Device info and communication
int device_id;
#ifdef USE_ROCM
int gfx;
#endif
#endif
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
cudaIpcMemHandle_t pxn_ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
@@ -59,8 +63,15 @@

// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;

// Whether explicit `destroy()` is required.
bool explicitly_destroy;
// After `destroy()` is called, this flag will be set to true
bool destroyed = false;

// Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr;

// Workspace
void* workspace = nullptr;
@@ -83,9 +94,15 @@

private:
void move_fifo_slots(int num_slots = 1);

public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
Buffer(int rank,
int num_ranks,
int64_t num_nvl_bytes,
int64_t num_rdma_bytes,
bool low_latency_mode,
bool explicitly_destroy,
bool enable_shrink);

~Buffer() noexcept(false);

@@ -104,67 +121,159 @@ struct Buffer {
pybind11::bytearray get_local_ipc_handle() const;

pybind11::bytearray get_local_nvshmem_unique_id() const;

pybind11::bytearray get_local_pxn_ipc_handle() const;

torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const;

void sync(const std::vector<int>& device_ids, const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles, const std::optional<pybind11::bytearray>& root_unique_id_opt);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional<EventHandle>& previous_event,
bool async, bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::vector<int>, torch::Tensor, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank, const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens, int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix, const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
torch::Stream get_comm_stream() const;

void sync(const std::vector<int>& device_ids,
const std::vector<std::optional<pybind11::bytearray>>& all_gathered_handles,
const std::optional<pybind11::bytearray>& root_unique_id_opt);

void destroy();

std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, std::optional<EventHandle>> get_dispatch_layout(
const torch::Tensor& topk_idx,
int num_experts,
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream);

std::tuple<torch::Tensor,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::vector<int>,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<EventHandle>>
intranode_dispatch(const torch::Tensor& x,
const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx,
const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank,
const torch::Tensor& is_token_in_rank,
const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens,
const std::optional<torch::Tensor>& cached_rank_prefix_matrix,
const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
int expert_alignment,
int num_worst_tokens,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(
const torch::Tensor& x,
const std::optional<torch::Tensor>& topk_weights,
const torch::Tensor& src_idx,
const torch::Tensor& rank_prefix_matrix,
const torch::Tensor& channel_prefix_matrix,
const torch::Tensor& send_head,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream);

std::tuple<torch::Tensor,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::vector<int>,
torch::Tensor,
torch::Tensor,
std::optional<torch::Tensor>,
torch::Tensor,
std::optional<torch::Tensor>,
torch::Tensor,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<torch::Tensor>,
std::optional<EventHandle>>
internode_dispatch(const torch::Tensor& x,
const std::optional<torch::Tensor>& x_scales,
const std::optional<torch::Tensor>& topk_idx,
const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& num_tokens_per_rank,
const std::optional<torch::Tensor>& num_tokens_per_rdma_rank,
const torch::Tensor& is_token_in_rank,
const std::optional<torch::Tensor>& num_tokens_per_expert,
int cached_num_recv_tokens,
int cached_num_rdma_recv_tokens,
const std::optional<torch::Tensor>& cached_rdma_channel_prefix_matrix,
const std::optional<torch::Tensor>& cached_recv_rdma_rank_prefix_sum,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> internode_combine(
const torch::Tensor& x,
const std::optional<torch::Tensor>& topk_weights,
const std::optional<torch::Tensor>& bias_0,
const std::optional<torch::Tensor>& bias_1,
const torch::Tensor& src_meta,
const torch::Tensor& is_combined_token_in_rank,
const torch::Tensor& rdma_channel_prefix_matrix,
const torch::Tensor& rdma_rank_prefix_sum,
const torch::Tensor& gbl_channel_prefix_matrix,
const torch::Tensor& combined_rdma_head,
const torch::Tensor& combined_nvl_head,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
bool allocate_on_comm_stream);

void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool async, bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);

torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

// additional interfaces for C++
std::string get_local_ipc_handle_string() const;
std::string get_local_nvshmem_unique_id_string() const;
void sync_string(const std::vector<int>& device_ids, const std::vector<std::string>& all_gathered_handles, const std::string& root_unique_id_opt);
std::tuple<torch::Tensor,
std::optional<torch::Tensor>,
torch::Tensor,
torch::Tensor,
torch::Tensor,
std::optional<EventHandle>,
std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x,
const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
bool round_scale,
bool use_ue8m0,
bool async,
bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> low_latency_combine(
const torch::Tensor& x,
const torch::Tensor& topk_idx,
const torch::Tensor& topk_weights,
const torch::Tensor& src_info,
const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_logfmt,
bool zero_copy,
bool async,
bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);

torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;

void low_latency_update_mask_buffer(int rank_to_mask, bool mask);

void low_latency_query_mask_buffer(const torch::Tensor& mask_status);

void low_latency_clean_mask_buffer();
};

} // namespace deep_ep
} // namespace deep_ep
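
The new constructor flags, `destroy()`, and the `low_latency_*_mask_buffer` helpers form the user-visible surface of the explicit-teardown and shrink-mode changes. A rough usage sketch, assuming only the declarations in this header; the buffer sizes, rank indices, and the shape and dtype of `mask_status` are illustrative placeholders, not values confirmed by this diff:

#include <torch/torch.h>

#include "deep_ep.hpp"

void example(int rank, int num_ranks) {
    // `explicitly_destroy` defers resource teardown to an explicit destroy()
    // call; `enable_shrink` enables the mask/sync buffers used below.
    // The byte sizes here are placeholders.
    deep_ep::Buffer buffer(rank, num_ranks,
                           /*num_nvl_bytes=*/0, /*num_rdma_bytes=*/1 << 30,
                           /*low_latency_mode=*/true,
                           /*explicitly_destroy=*/true,
                           /*enable_shrink=*/true);

    // ... call sync() with the gathered IPC handles / unique id, then run
    // dispatch and combine as usual ...

    // Shrink-mode helpers: mask one rank out, query the current mask into a
    // caller-provided tensor (shape and dtype assumed here), then reset it.
    buffer.low_latency_update_mask_buffer(/*rank_to_mask=*/1, /*mask=*/true);
    auto mask_status = torch::zeros({num_ranks},
                                    torch::dtype(torch::kInt32).device(torch::kCUDA));
    buffer.low_latency_query_mask_buffer(mask_status);
    buffer.low_latency_clean_mask_buffer();

    // With explicitly_destroy enabled, release communication resources before
    // the Buffer goes out of scope.
    buffer.destroy();
}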