fix the _allgather_base backend issue (issue 11) (#12)
With torch==2.3.0, the following error is raised:

    Could not retrieve or create the backend 2 for device type cuda

The topo-detection utils appear to conflict with the backend lookup, so the topo-utils calls are removed.
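For context, backend id 2 is NCCL in torch's c10d: every collective resolves a backend for the tensor's device type through the process group, and the message above is what torch 2.3 raises when that lookup fails for cuda. Below is a minimal, hypothetical sketch of the failing lookup, not a verbatim repro of this issue; the process-group setup (env:// rendezvous variables) is assumed:

    import torch
    import torch.distributed as dist

    # Hypothetical sketch: torch resolves a per-device backend for each
    # collective. Backend id 2 is NCCL; if the process group ends up with
    # no usable binding for the cuda device type, the lookup raises:
    #   RuntimeError: Could not retrieve or create the backend 2 for device type cuda
    dist.init_process_group(backend="nccl")  # assumes env:// variables are set
    pg = dist.group.WORLD
    backend = pg._get_backend(torch.device("cuda"))  # the lookup that can fail

    x = torch.ones(4, device="cuda")
    out = torch.empty(4 * dist.get_world_size(), device="cuda")
    dist.all_gather_into_tensor(out, x)  # public successor of _allgather_base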
zheng-ningxin authored Jul 1, 2024
1 parent 775e061 commit 66e2716
Showing 3 changed files with 1 addition and 70 deletions.
src/all_gather/ths_op/all_gather_gemm_kernel.cc (0 additions, 41 deletions)
@@ -25,7 +25,6 @@
 #include "flux/op_registry.h"
 #include "flux/runtime_config.h"
 #include "flux/ths_op/ths_op.h"
-#include "flux/ths_op/topo_utils.h"
 #include "flux/ths_op/util.h"
 #include "flux/args/all_gather.h"
 #include "flux/utils.h"
@@ -162,7 +161,6 @@ class AGKernel : public torch::CustomClassHolder {
         << "invalid nnodes: world_size[" << world_size << "] %% nnodes[" << nnodes << "] != 0";
     FLUX_CHECK(!(transpose_weight == true && is_fp8_gemm == true))
         << "FP8 GEMM does not support transpose weight";
-    _ensure_topo_initialized();
     this->ring_mode = get_ring_mode(ring_mode_);
 
     // input buffer
@@ -812,12 +810,6 @@ class AGKernel : public torch::CustomClassHolder {
         CU_STREAM_WRITE_VALUE_DEFAULT));
   }
 
-  void
-  _ensure_topo_initialized() {
-    if (!topo_utils::is_topo_initialized()) {
-      topo_utils::initialize_topo(const_cast<c10d::ProcessGroup &>(this->tp_group));
-    }
-  }
 
   void
   copy_all_to_all(torch::Tensor input, at::cuda::CUDAStream stream) {
@@ -889,39 +881,6 @@ class AGKernel : public torch::CustomClassHolder {
 
   void
   copy_ring_push_2d_pcie(torch::Tensor input, at::cuda::CUDAStream stream) {
-    // [0, numa_world_size) stages: 0 <- 1 <- 2 <- 3 <- 4 <- 5 <- 6 <- 7 <- 0
-    // [numa_world_size, world_size) stages: 0 <- 1 <- 2 <- 3 <- 0 && 4 <- 5 <- 6 <- 7 <- 4
-    int to_rank = (rank - 1 + world_size) % world_size;  // always recv data from rank prev
-    int numa_world_size = topo_utils::topo_numa_local_world_size();
-    FLUX_CHECK_DIV(this->local_world_size, numa_world_size);
-    int numa_nodes = this->local_world_size / numa_world_size;
-    FLUX_CHECK_EQ(numa_nodes, 2) << " world_size " << this->local_world_size
-                                 << " with numa_world_size " << numa_world_size;
-    int nnode = rank / numa_world_size;
-    for (int i = 0; i < world_size - 1; i++) {  // with inner and intra numa node
-      int send_segment = (rank + i) % world_size;
-      if (i >= numa_world_size && rank % numa_world_size == 0) {
-        send_segment = (send_segment + numa_world_size) % world_size;
-        to_rank = (rank - 1 + numa_world_size) % numa_world_size + nnode * numa_world_size;
-      }
-      for (int j = 0; j < SPLIT; ++j) {
-        auto split_offset = j * split_chunk_size;
-        if (i != 0 && !(i >= numa_world_size &&
-                        rank % numa_world_size == 0)) {  // for i == 0 it is always ready
-          // previous rank recv done
-          wait_ready(rank, send_segment, j, stream);
-        }
-        void *from_ptr = this->input_ptrs[rank];
-        void *to_ptr = this->input_ptrs[to_rank];
-        CUDA_CHECK(cudaMemcpyAsync(
-            ptr_offset(to_ptr, send_segment * chunk_size + split_offset),
-            ptr_offset(from_ptr, send_segment * chunk_size + split_offset),
-            split_chunk_size,
-            cudaMemcpyDeviceToDevice,
-            stream));
-        set_ready(to_rank, send_segment, j, stream);
-      }
-    }
   }
 
   void
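Per the deleted comments, the first numa_world_size stages push segments around the full 8-rank ring, after which NUMA-boundary ranks fold back into their local 4-rank ring. The following standalone Python sketch replays the per-stage send schedule of the deleted body, with world_size=8 and numa_world_size=4 assumed and the splits and ready-flag handshake elided:

    # Replay the send schedule of the deleted copy_ring_push_2d_pcie.
    # Assumed sizes: 8 ranks total, 4 ranks per NUMA node.
    world_size = 8
    numa_world_size = 4

    for rank in range(world_size):
        to_rank = (rank - 1 + world_size) % world_size  # push to the previous rank
        nnode = rank // numa_world_size                 # this rank's NUMA node
        for i in range(world_size - 1):
            send_segment = (rank + i) % world_size
            if i >= numa_world_size and rank % numa_world_size == 0:
                # NUMA-boundary ranks skip the remote segments and stay on
                # their local ring for the remaining stages.
                send_segment = (send_segment + numa_world_size) % world_size
                to_rank = ((rank - 1 + numa_world_size) % numa_world_size
                           + nnode * numa_world_size)
            print(f"stage {i}: rank {rank} sends segment {send_segment} to rank {to_rank}")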
src/all_gather/ths_op/all_gather_types.h (0 additions, 15 deletions)
@@ -19,22 +19,7 @@ static const int intra_numa_world_size = 4;
 
 static AGRingMode
 get_ring_mode(AGRingMode ring_mode) {
-  if (ring_mode == AGRingMode::Auto) {  // auto detect. with nvlink use ring mode.
-    if (topo_utils::has_nvswitch()) {
-      return AGRingMode::All2All;
-    }
-
-    if (topo_utils::has_heterogeneous_pcie()) {
-      if (topo_utils::topo_numa_local_world_size() != intra_numa_world_size) {
-        std::cerr << "warning: only NUMA world_size==" << intra_numa_world_size
-                  << " is optimized for\n";
-        return AGRingMode::Ring1D;  // PCI-e ring mode with no optimization
-      }
-      return AGRingMode::Ring2D;
-    }
-    return AGRingMode::Ring1D;
-  }
   return ring_mode;
 }
 
 }  // namespace bytedance::flux
src/reduce_scatter/ths_op/gemm_reduce_scatter.cc (1 addition, 14 deletions)
@@ -159,20 +159,7 @@ class GemmRS : public torch::CustomClassHolder {
 
   bool
   has_nvlink() {
-    topo_utils::initialize_topo(const_cast<c10d::ProcessGroup &>(this->tp_group));
-    this->sub_world_size = topo_utils::topo_numa_local_world_size();
-    static int has_nvlink_env = get_int_from_env("FLUX_FORCE_NVLINK", -1);
-    if (has_nvlink_env == -1) {
-      if (topo_utils::has_nvswitch()) {
-        return true;
-      } else {
-        if (topo_utils::has_heterogeneous_nvlink()) {
-          this->sub_world_size = topo_utils::topo_nvlink_local_world_size();
-        }
-        return false;
-      }
-    }
-    return has_nvlink_env;
+    return true;
   }
 
   bool
