3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third-party/nccl"]
path = third-party/nccl
url = https://github.com/NVIDIA/nccl.git
@@ -25,7 +25,7 @@ This document introduces the Hybrid Expert Parallel (Hybrid-EP) implementation t

### Hardware Optimizations
- **TMA Instructions**: Leverage Tensor Memory Accelerator instructions for minimal SM overhead
- **RDMA Integration**: High-efficiency inter-node communication (coming soon)*
- **RDMA Integration**: High-efficiency inter-node communication
- **Pipeline Architecture**: Warp-level pipeline parallelism within execution blocks

### Supported Data Types
@@ -120,13 +120,13 @@ This document introduces the Hybrid Expert Parallel (Hybrid-EP) implementation t
```
csrc/hybrid_ep/
├── hybrid_ep.cu # Main CUDA implementation
├── hybrid_ep.cuh # Header definitions
├── internode.cu # Main RDMA CUDA implementation
├── pybind_hybrid_ep.cu # PyBind bindings
├── config.cuh # Config definitions required by hybrid-EP kernels
├── utils.cuh # Utility helpers and macros
├── allocator/ # Allocator for memory accessible by remote ranks
├── backend/ # Core Hybrid-EP kernel implementations
│ └── hybrid_ep_backend.cuh
│ ├── hybrid_ep_backend.cuh
│ └── utils.cuh # Utility helpers and macros
├── executor/ # Kernel runner
├── extension/ # Useful extensions
└── jit/ # JIT compiler
@@ -148,7 +148,7 @@ Follow the same build process as the main branch. No additional dependencies req
Refer to `tests/test_hybrid_ep.py` for comprehensive usage examples including:
- Multi-node configuration
- Intra-node testing scenarios
- Inter-node testing will come soon
- Inter-node testing scenarios
- Performance benchmarking setups

### Important Configuration Note
@@ -207,7 +207,6 @@ Here are important parameter settings in `csrc/hybrid_ep/config.cuh`. You can mo

### 🚧 Upcoming Features
- **Low Latency Mode**: Enhanced performance for latency-critical workloads
- **RDMA Integration**: High-performance inter-node communication

### ⚠️ Current Limitations
- RDMA functionality not yet available (under final testing)

Large diffs are not rendered by default.

115 changes: 85 additions & 30 deletions deep_ep/backend/utils.cuh → csrc/hybrid_ep/backend/utils.cuh
@@ -11,6 +11,8 @@
#include <sstream>
#include <algorithm>
#include <type_traits>
#include <linux/types.h>
#define MAX_NUM_OF_RANKS_PER_NODE 72

enum class TOKEN_DATA_TYPE { UINT16, UINT8 };

@@ -35,42 +37,95 @@ inline int get_token_data_type_size(TOKEN_DATA_TYPE token_data_type) {
return 0;
}

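Editorial aside, not part of the diff: `get_token_data_type_size` is what lets buffer sizes follow the configured token data type. A minimal sizing sketch with invented token count and hidden dimension, mirroring how `dispatch_preprocess` later multiplies a tensor's `numel()` by this element size:

```cpp
#include <cstddef>

// Illustrative only: byte size of a token buffer for a given data type.
// num_tokens and hidden_dim are caller-provided; the example values in the
// comment below are made up.
inline size_t token_buffer_bytes(TOKEN_DATA_TYPE dtype,
                                 size_t num_tokens, size_t hidden_dim) {
    return num_tokens * hidden_dim *
           static_cast<size_t>(get_token_data_type_size(dtype));
}

// e.g. token_buffer_bytes(TOKEN_DATA_TYPE::UINT8, 4096, 7168)
//   or token_buffer_bytes(TOKEN_DATA_TYPE::UINT16, 4096, 7168)
```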
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
struct dispatch_memory_region_info_t {
__be32 token_lkey;
__be32 token_rkey;
__be32 prob_lkey;
__be32 prob_rkey;
__be32 scaling_factor_lkey;
__be32 scaling_factor_rkey;
__be32 flag_lkey;
__be32 flag_rkey;
uint64_t token_laddr;
uint64_t token_raddr;
uint64_t prob_laddr;
uint64_t prob_raddr;
uint64_t scaling_factor_laddr;
uint64_t scaling_factor_raddr;
uint64_t flag_laddr;
uint64_t flag_raddr;
} __attribute__((__aligned__(8)));

struct combine_memory_region_info_t {
__be32 token_lkey;
__be32 token_rkey;
__be32 prob_lkey;
__be32 prob_rkey;
__be32 flag_lkey;
__be32 flag_rkey;
uint64_t token_laddr;
uint64_t token_raddr;
uint64_t prob_laddr;
uint64_t prob_raddr;
uint64_t flag_laddr;
uint64_t flag_raddr;
} __attribute__((__aligned__(8)));
#endif
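Editorial aside, not part of the diff: the lkey/rkey and laddr/raddr pairs above follow the usual RDMA pattern of a local registration key plus a remote key and address learned from the peer. A hedged sketch of how the local half of `dispatch_memory_region_info_t` could be filled with plain ibverbs; the actual code may register memory through DOCA GPUNetIO helpers instead, and the `htobe32` byte-swap is only an assumption inferred from the `__be32` field type:

```cpp
#include <endian.h>
#include <stdexcept>
#include <infiniband/verbs.h>

// Hypothetical helper: register the local token buffer and record its key and
// address. The remote half (token_rkey / token_raddr) would come from an
// out-of-band exchange with the peer rank, which is not shown here.
inline void fill_local_token_mr(dispatch_memory_region_info_t& info,
                                struct ibv_pd* pd, void* token_buf, size_t bytes) {
    struct ibv_mr* mr = ibv_reg_mr(pd, token_buf, bytes,
                                   IBV_ACCESS_LOCAL_WRITE |
                                   IBV_ACCESS_REMOTE_WRITE |
                                   IBV_ACCESS_REMOTE_READ);
    if (mr == nullptr) throw std::runtime_error("ibv_reg_mr failed");
    info.token_lkey  = htobe32(mr->lkey);   // big-endian storage is assumed
    info.token_laddr = reinterpret_cast<uint64_t>(mr->addr);
}
```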

struct DispatchBuffers {
TOKEN_DATA_TYPE data_type;
// Input buffers from attn, only used in inter-node case
void * attn_input_token = nullptr;
void * attn_input_prob = nullptr;
void * attn_input_flags = nullptr;
void * attn_input_scaling_factor = nullptr;
// Output buffers to experts
void *expert_output_token;
void **expert_output_token_all_ranks;
float *expert_output_prob;
float **expert_output_prob_all_ranks;
float *expert_output_scaling_factor;
float **expert_output_scaling_factor_all_ranks;
// Local temp buffer for dispatch kernel.
void *rdma_inter_node_group_token;
float *rdma_inter_node_group_prob;
float *rdma_inter_node_group_scaling_factor;
uint64_t *rdma_inter_node_group_flags;
void * expert_output_token = nullptr;
void ** expert_output_token_all_ranks = nullptr;
float * expert_output_prob = nullptr;
float ** expert_output_prob_all_ranks = nullptr;
float * expert_output_scaling_factor = nullptr;
float ** expert_output_scaling_factor_all_ranks = nullptr;
// RDMA buffers for dispatch kernel.
void * rdma_inter_node_group_token = nullptr;
float * rdma_inter_node_group_prob = nullptr;
float * rdma_inter_node_group_scaling_factor = nullptr;
uint64_t * rdma_inter_node_group_flags = nullptr;
// Misc flags
uint32_t *intra_node_write_completion_flags;
uint64_t *expected_rdma_flag_value;
uint32_t *expected_intra_node_flag_value;
uint32_t * intra_node_write_completion_flags = nullptr;
uint64_t * expected_rdma_flag_value = nullptr;
uint32_t * expected_intra_node_flag_value = nullptr;
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
// qp info and mr info
struct doca_gpu_dev_verbs_qp ** d_qps_gpu = nullptr;
struct dispatch_memory_region_info_t * mr_info = nullptr;
#endif
};
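Another aside, not part of the diff: the `*_all_ranks` members are per-rank pointer arrays, letting a kernel address peer ranks' intra-node buffers directly. A hedged sketch of one way such an array could be assembled with CUDA IPC; the real allocator lives in `csrc/hybrid_ep/allocator/` and may use a different mechanism, and `CUDA_CHECK` is assumed to be the error-checking macro used elsewhere in this PR:

```cpp
#include <cuda_runtime.h>

// Hypothetical: map each peer rank's buffer into this process (the handle
// exchange, e.g. over MPI, is not shown), then publish the pointer table on
// the device so a kernel can index it by rank, the way
// expert_output_token_all_ranks is used.
inline void** build_rank_pointer_array(void* local_buf, int local_rank, int num_ranks,
                                       const cudaIpcMemHandle_t* peer_handles) {
    void* h_ptrs[MAX_NUM_OF_RANKS_PER_NODE] = {};
    for (int r = 0; r < num_ranks; ++r) {
        if (r == local_rank) {
            h_ptrs[r] = local_buf;  // our own allocation, no IPC needed
        } else {
            CUDA_CHECK(cudaIpcOpenMemHandle(&h_ptrs[r], peer_handles[r],
                                            cudaIpcMemLazyEnablePeerAccess));
        }
    }
    void** d_ptrs = nullptr;
    CUDA_CHECK(cudaMalloc(&d_ptrs, sizeof(void*) * num_ranks));
    CUDA_CHECK(cudaMemcpy(d_ptrs, h_ptrs, sizeof(void*) * num_ranks,
                          cudaMemcpyHostToDevice));
    return d_ptrs;
}
```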

struct CombineBuffers {
// Input buffers from experts
uint16_t *expert_input_token;
uint16_t **expert_input_token_all_ranks;
float *expert_input_prob;
float **expert_input_prob_all_ranks;
// Local temp buffer for combine kernel.
uint16_t *rdma_intra_node_red_token;
float *rdma_intra_node_red_prob;
uint16_t *rdma_inter_node_group_token;
float *rdma_inter_node_group_prob;
uint64_t *rdma_inter_node_group_flags;
uint16_t * expert_input_token = nullptr;
uint16_t ** expert_input_token_all_ranks = nullptr;
float * expert_input_prob = nullptr;
float ** expert_input_prob_all_ranks = nullptr;
// Output buffers to attn, only used in inter-node case
void * attn_output_flags = nullptr;
// RDMA buffers for combine kernel.
uint16_t * rdma_intra_node_red_token = nullptr;
float * rdma_intra_node_red_prob = nullptr;
uint16_t * rdma_inter_node_group_token = nullptr;
float * rdma_inter_node_group_prob = nullptr;
uint64_t * rdma_inter_node_group_flags = nullptr;
// Misc flags
uint32_t *intra_node_write_completion_flags;
uint64_t *expected_rdma_flag_value;
uint32_t *expected_intra_node_flag_value;
uint32_t * intra_node_write_completion_flags = nullptr;
uint64_t * expected_rdma_flag_value = nullptr;
uint32_t * expected_intra_node_flag_value = nullptr;
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
// qp info and mr info
struct doca_gpu_dev_verbs_qp ** d_qps_gpu = nullptr;
struct combine_memory_region_info_t * mr_info = nullptr;
#endif
};

__device__ __forceinline__ bool elect_sync(uint32_t membermask) {
@@ -173,7 +228,7 @@ inline void print_ptr_info(void* p) {
cudaPointerAttributes attr{};
cudaError_t err = cudaPointerGetAttributes(&attr, p);
if (err != cudaSuccess) {
printf("cudaPointerGetAttributes failed: %s\n", cudaGetErrorString(err));
fprintf(stderr, "cudaPointerGetAttributes failed: %s\n", cudaGetErrorString(err));
return;
}
cudaMemoryType memory_type;
@@ -189,14 +244,14 @@
case cudaMemoryTypeManaged: memory_type_str = "Managed"; break;
default: memory_type_str = "Unregistered/Unknown"; break;
}
printf("type=%s, device=%d\n", memory_type_str.c_str(), attr.device);
fprintf(stderr, "type=%s, device=%d\n", memory_type_str.c_str(), attr.device);

// If this is a device/managed pointer, try to query its allocation range (base + size)
if (memory_type == cudaMemoryTypeDevice || memory_type == cudaMemoryTypeManaged) {
cuInit(0);
CUdeviceptr base = 0;
size_t size = 0;
CUresult r = cuMemGetAddressRange(&base, &size, reinterpret_cast<CUdeviceptr>(p));
printf("alloc_base=%p, alloc_size=%zu bytes\n", reinterpret_cast<void*>(base), size);
fprintf(stderr, "alloc_base=%p, alloc_size=%zu bytes\n", reinterpret_cast<void*>(base), size);
}
}
2 changes: 2 additions & 0 deletions csrc/hybrid_ep/config.cuh
@@ -18,6 +18,8 @@ struct BufferConfig {
int num_of_nodes;
TOKEN_DATA_TYPE token_data_type;
int num_of_blocks_preprocessing_api;
int num_of_blocks_dispatch_api;
int num_of_blocks_combine_api;
int num_of_tokens_per_chunk_dispatch_api;
int num_of_tokens_per_chunk_combine_api;

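Not part of the diff: for readers skimming the new block-count fields, a hedged sketch of how a `BufferConfig` might be populated. All values are invented, and the struct has additional members outside this hunk that are left value-initialized:

```cpp
#include "config.cuh"

// Illustrative only; every value below is made up.
inline BufferConfig make_example_buffer_config() {
    BufferConfig cfg{};
    cfg.num_of_nodes = 2;
    cfg.token_data_type = TOKEN_DATA_TYPE::UINT8;
    cfg.num_of_blocks_preprocessing_api = 32;
    cfg.num_of_blocks_dispatch_api = 32;        // new in this change
    cfg.num_of_blocks_combine_api = 32;         // new in this change
    cfg.num_of_tokens_per_chunk_dispatch_api = 128;
    cfg.num_of_tokens_per_chunk_combine_api = 128;
    return cfg;
}
```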
51 changes: 40 additions & 11 deletions csrc/hybrid_ep/executor/executor.cu
@@ -3,7 +3,7 @@

#include "executor.cuh"

Executor::Executor(int local_rank, int node_rank, std::string base_path) : local_rank(local_rank), node_rank(node_rank), kernel_cache(local_rank, base_path) {}
Executor::Executor(int local_rank, int node_rank, std::string base_path, bool load_cached_kernels) : local_rank(local_rank), node_rank(node_rank), kernel_cache(node_rank, local_rank, base_path, load_cached_kernels) {}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Executor::metadata_preprocess_core(
@@ -31,7 +31,7 @@ Executor::metadata_preprocess_core(
torch::empty({num_of_tokens_per_rank, config.num_of_nodes - 1},
torch::dtype(torch::kBool).device(torch::kCUDA));
auto num_of_tokens_for_experts =
torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto local_expert_routing_map = torch::empty(
{num_of_tokens_per_rank * config.num_of_ranks_per_node * config.num_of_nodes, config.num_of_experts_per_rank},
torch::dtype(torch::kBool).device(torch::kCUDA));
@@ -49,8 +49,26 @@
}

void Executor::dispatch_preprocess(HybridEpConfigInstance config, DispatchBuffers& dispatch_buffers, DispatchArgs& args) {
// Empty now, will be filled with D2D in the inter-node case
nvtxRangePushA("dispatch_preprocess in hybrid-ep");
if(config.num_of_nodes > 1) {
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
auto sizeof_token_data_type = get_token_data_type_size(config.token_data_type);
CUDA_CHECK(cudaMemcpyAsync(dispatch_buffers.attn_input_token, args.hidden.data_ptr(), args.hidden.numel() * sizeof_token_data_type, cudaMemcpyDeviceToDevice, args.stream));
if(config.forward_dispatch_api) {
CUDA_CHECK(cudaMemcpyAsync(dispatch_buffers.attn_input_prob, args.probs.data_ptr(), args.probs.numel() * sizeof(float), cudaMemcpyDeviceToDevice, args.stream));
}
if(config.token_data_type == TOKEN_DATA_TYPE::UINT8) {
CUDA_CHECK(cudaMemcpyAsync(dispatch_buffers.attn_input_scaling_factor, args.scaling_factor.data_ptr(), args.scaling_factor.numel() * sizeof(float), cudaMemcpyDeviceToDevice, args.stream));
}
#else
throw std::runtime_error("Multi-node support is not enabled in this build.");
#endif
} else {
// Set the tensor pointers to the dispatch buffers.
dispatch_buffers.attn_input_token = args.hidden.data_ptr();
dispatch_buffers.attn_input_prob = (config.forward_dispatch_api) ? args.probs.data_ptr() : nullptr;
dispatch_buffers.attn_input_scaling_factor = (config.token_data_type == TOKEN_DATA_TYPE::UINT8) ? args.scaling_factor.data_ptr() : nullptr;
}
nvtxRangePop(); // End of dispatch_preprocess nvtx range
}

Expand All @@ -63,9 +81,9 @@ void Executor::dispatch_core(HybridEpConfigInstance config, DispatchBuffers& dis

hybrid_ep::dispatch_kernel_param_t<DType> param;
// Setup input pointers
param.attn_input_token = reinterpret_cast<DType*>(args.hidden.data_ptr());
param.attn_input_prob = (config.forward_dispatch_api) ? reinterpret_cast<float*>(args.probs.data_ptr()) : nullptr;
param.attn_input_token_scaling_factor = (config.token_data_type == TOKEN_DATA_TYPE::UINT8) ? reinterpret_cast<float*>(args.scaling_factor.data_ptr()) : nullptr;
param.attn_input_token = reinterpret_cast<DType*>(dispatch_buffers.attn_input_token);
param.attn_input_prob = reinterpret_cast<float*>(dispatch_buffers.attn_input_prob);
param.attn_input_token_scaling_factor = reinterpret_cast<float*>(dispatch_buffers.attn_input_scaling_factor);

// Setup output pointers
for (int i = 0; i < config.num_of_ranks_per_node; i++) {
@@ -95,10 +113,13 @@ void Executor::dispatch_core(HybridEpConfigInstance config, DispatchBuffers& dis
param.num_of_tokens_per_rank = args.num_of_tokens_per_rank;
param.expected_rdma_flag_value = dispatch_buffers.expected_rdma_flag_value;
param.expected_intra_node_flag_value = dispatch_buffers.expected_intra_node_flag_value;
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
param.d_qps_gpu = dispatch_buffers.d_qps_gpu;
param.mr_info = dispatch_buffers.mr_info;
#endif

// Launch kernel
kernel_cache.run_dispatch_kernel<DType>(config, param, args.stream);

nvtxRangePop(); // End of dispatch_core nvtx range
}

@@ -153,7 +174,10 @@
// otherwise, we will compute the num_permuted_tokens by summing the tokens_per_expert.
if (num_permuted_tokens < 0) {
if (args.use_host_meta) {
tokens_per_expert = tokens_per_expert.cpu();
auto host_opts = tokens_per_expert.options().device(torch::kCPU).pinned_memory(true);
torch::Tensor tokens_per_expert_pinned = torch::empty(tokens_per_expert.sizes(), host_opts);
tokens_per_expert_pinned.copy_(tokens_per_expert, /*non_blocking=*/false);
tokens_per_expert = tokens_per_expert_pinned;
}
num_permuted_tokens = tokens_per_expert.sum().item<int64_t>();
}
@@ -252,6 +276,7 @@ void Executor::combine_preprocess(HybridEpConfigInstance config, CombineBuffers&
cudaMemcpyDeviceToDevice, args.stream));
}
}

nvtxRangePop(); // End of combine_preprocess nvtx range
}

@@ -268,8 +293,8 @@ void Executor::combine_core(HybridEpConfigInstance config, CombineBuffers& combi
}

// Setup output pointers
param.attn_output_token = args.combined_tokens;
param.attn_output_prob = (config.backward_combine_api) ? args.combined_probs : nullptr;
param.attn_output_token = reinterpret_cast<uint16_t*>(args.combined_tokens);
param.attn_output_prob = (config.backward_combine_api) ? reinterpret_cast<float*>(args.combined_probs) : nullptr;

// Setup local buffer pointers
param.rdma_intra_node_red_token =
@@ -293,6 +318,10 @@
param.expected_rdma_flag_value = combine_buffers.expected_rdma_flag_value;
param.expected_intra_node_flag_value =
combine_buffers.expected_intra_node_flag_value;
#ifdef HYBRID_EP_BUILD_MULTINODE_ENABLE
param.d_qps_gpu = combine_buffers.d_qps_gpu;
param.mr_info = combine_buffers.mr_info;
#endif

// Launch kernel
kernel_cache.run_combine_kernel(config, param, args.stream);
@@ -301,6 +330,6 @@

void Executor::combine_postprocess(HybridEpConfigInstance config, CombineBuffers& combine_buffers, CombineArgs& args) {
nvtxRangePushA("combine_postprocess in hybrid-ep");
// TODO: Implement the combine postprocessing
// No postprocess is needed for the combine kernel now.
nvtxRangePop(); // End of combine_postprocess nvtx range
}
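One last aside, not part of the diff: the `Executor` constructor gains a `load_cached_kernels` flag and now also forwards `node_rank` into the kernel cache. A trivial hedged construction sketch; the cache directory is an arbitrary example path:

```cpp
#include "executor.cuh"

// Hypothetical call site for the updated constructor signature.
void example(int local_rank, int node_rank) {
    Executor executor(local_rank, node_rank,
                      /*base_path=*/"/tmp/hybrid_ep_jit_cache",
                      /*load_cached_kernels=*/true);
    // executor.dispatch_preprocess(...), executor.dispatch_core(...), etc.
}
```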
2 changes: 1 addition & 1 deletion csrc/hybrid_ep/executor/executor.cuh
@@ -13,7 +13,7 @@

class Executor {
public:
Executor(int local_rank, int node_rank, std::string base_path);
Executor(int local_rank, int node_rank, std::string base_path, bool load_cached_kernels);

struct DispatchArgs {
// Input tensors