diff --git a/src/all_gather/ths_op/all_gather_gemm_kernel.cc b/src/all_gather/ths_op/all_gather_gemm_kernel.cc index 7391a10d9e..eeed1ea924 100644 --- a/src/all_gather/ths_op/all_gather_gemm_kernel.cc +++ b/src/all_gather/ths_op/all_gather_gemm_kernel.cc @@ -25,7 +25,6 @@ #include "flux/op_registry.h" #include "flux/runtime_config.h" #include "flux/ths_op/ths_op.h" -#include "flux/ths_op/topo_utils.h" #include "flux/ths_op/util.h" #include "flux/args/all_gather.h" #include "flux/utils.h" @@ -162,7 +161,6 @@ class AGKernel : public torch::CustomClassHolder { << "invalid nnodes: world_size[" << world_size << "] %% nnodes[" << nnodes << "] != 0"; FLUX_CHECK(!(transpose_weight == true && is_fp8_gemm == true)) << "FP8 GEMM does not support transpose weight"; - _ensure_topo_initialized(); this->ring_mode = get_ring_mode(ring_mode_); // input buffer @@ -812,12 +810,6 @@ class AGKernel : public torch::CustomClassHolder { CU_STREAM_WRITE_VALUE_DEFAULT)); } - void - _ensure_topo_initialized() { - if (!topo_utils::is_topo_initialized()) { - topo_utils::initialize_topo(const_cast(this->tp_group)); - } - } void copy_all_to_all(torch::Tensor input, at::cuda::CUDAStream stream) { @@ -889,39 +881,6 @@ class AGKernel : public torch::CustomClassHolder { void copy_ring_push_2d_pcie(torch::Tensor input, at::cuda::CUDAStream stream) { - // [0, numa_world_size) stages: 0 <- 1 <- 2 <- 3 <- 4 <- 5 <- 6 <- 7 <- 0 - // [numa_world_size, world_size) stages: 0 <- 1 <- 2 <-3 <- 0 && 4 <- 5 <- 6 <- 7 <- 4 - int to_rank = (rank - 1 + world_size) % world_size; // always recv data from rank prev - int numa_world_size = topo_utils::topo_numa_local_world_size(); - FLUX_CHECK_DIV(this->local_world_size, numa_world_size); - int numa_nodes = this->local_world_size / numa_world_size; - FLUX_CHECK_EQ(numa_nodes, 2) << " world_size " << this->local_world_size - << " with numa_world_size " << numa_world_size; - int nnode = rank / numa_world_size; - for (int i = 0; i < world_size - 1; i++) { // with inner and intra numa node - int send_segment = (rank + i) % world_size; - if (i >= numa_world_size && rank % numa_world_size == 0) { - send_segment = (send_segment + numa_world_size) % world_size; - to_rank = (rank - 1 + numa_world_size) % numa_world_size + nnode * numa_world_size; - } - for (int j = 0; j < SPLIT; ++j) { - auto split_offset = j * split_chunk_size; - if (i != 0 && !(i >= numa_world_size && - rank % numa_world_size == 0)) { // for i == 0 it is always ready - // previous rank recv done - wait_ready(rank, send_segment, j, stream); - } - void *from_ptr = this->input_ptrs[rank]; - void *to_ptr = this->input_ptrs[to_rank]; - CUDA_CHECK(cudaMemcpyAsync( - ptr_offset(to_ptr, send_segment * chunk_size + split_offset), - ptr_offset(from_ptr, send_segment * chunk_size + split_offset), - split_chunk_size, - cudaMemcpyDeviceToDevice, - stream)); - set_ready(to_rank, send_segment, j, stream); - } - } } void diff --git a/src/all_gather/ths_op/all_gather_types.h b/src/all_gather/ths_op/all_gather_types.h index 6b0db3c8b5..219ab399b3 100644 --- a/src/all_gather/ths_op/all_gather_types.h +++ b/src/all_gather/ths_op/all_gather_types.h @@ -19,22 +19,7 @@ static const int intra_numa_world_size = 4; static AGRingMode get_ring_mode(AGRingMode ring_mode) { - if (ring_mode == AGRingMode::Auto) { // auto detect. with nvlink use ring mode. - if (topo_utils::has_nvswitch()) { return AGRingMode::All2All; - } - - if (topo_utils::has_heterogeneous_pcie()) { - if (topo_utils::topo_numa_local_world_size() != intra_numa_world_size) { - std::cerr << "warning: only NUMA world_size==" << intra_numa_world_size - << " is optimized for\n"; - return AGRingMode::Ring1D; // PCI-e ring mode with no optimization - } - return AGRingMode::Ring2D; - } - return AGRingMode::Ring1D; - } - return ring_mode; } } // namespace bytedance::flux diff --git a/src/reduce_scatter/ths_op/gemm_reduce_scatter.cc b/src/reduce_scatter/ths_op/gemm_reduce_scatter.cc index cd551eae66..290d11ce0b 100644 --- a/src/reduce_scatter/ths_op/gemm_reduce_scatter.cc +++ b/src/reduce_scatter/ths_op/gemm_reduce_scatter.cc @@ -159,20 +159,7 @@ class GemmRS : public torch::CustomClassHolder { bool has_nvlink() { - topo_utils::initialize_topo(const_cast(this->tp_group)); - this->sub_world_size = topo_utils::topo_numa_local_world_size(); - static int has_nvlink_env = get_int_from_env("FLUX_FORCE_NVLINK", -1); - if (has_nvlink_env == -1) { - if (topo_utils::has_nvswitch()) { - return true; - } else { - if (topo_utils::has_heterogeneous_nvlink()) { - this->sub_world_size = topo_utils::topo_nvlink_local_world_size(); - } - return false; - } - } - return has_nvlink_env; + return true; } bool