Skip to content

Commit 4fa6445

Browse files
authored
IPEX Tensor Parallel (#2435)
1 parent f4ee125 commit 4fa6445

26 files changed

+1717
-3
lines changed

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
[submodule "third_party/libxsmm"]
88
path = third_party/libxsmm
99
url = https://github.com/libxsmm/libxsmm.git
10+
[submodule "third_party/oneCCL"]
11+
path = third_party/oneCCL
12+
url = https://github.com/oneapi-src/oneCCL
1013
[submodule "third_party/sleef"]
1114
path = third_party/sleef
1215
url = https://github.com/shibatch/sleef.git
16+

cmake/Modules/FindoneCCL.cmake

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# - Try to find oneCCL
2+
#
3+
# The following are set after configuration is done:
4+
# ONECCL_FOUND : set to true if oneCCL is found.
5+
# ONECCL_INCLUDE_DIRS : path to oneCCL include dir.
6+
# ONECCL_LIBRARIES : list of libraries for oneCCL
7+
#
8+
# and the following imported targets:
9+
#
10+
# oneCCL
11+
12+
IF (NOT ONECCL_FOUND)
13+
SET(ONECCL_FOUND OFF)
14+
SET(ONECCL_LIBRARIES)
15+
SET(ONECCL_INCLUDE_DIRS)
16+
17+
SET(ONECCL_ROOT "${PROJECT_SOURCE_DIR}/third_party/oneCCL")
18+
19+
IF(BUILD_NO_ONECCL_PACKAGE)
20+
ADD_SUBDIRECTORY(${ONECCL_ROOT} oneCCL EXCLUDE_FROM_ALL)
21+
ELSE()
22+
ADD_SUBDIRECTORY(${ONECCL_ROOT} build)
23+
ENDIF()
24+
25+
IF(NOT TARGET ccl)
26+
MESSAGE(FATAL_ERROR "Failed to find oneCCL target")
27+
ENDIF()
28+
add_library(oneCCL ALIAS ccl)
29+
30+
GET_TARGET_PROPERTY(INCLUDE_DIRS oneCCL INCLUDE_DIRECTORIES)
31+
SET(ONECCL_INCLUDE_DIRS ${INCLUDE_DIRS})
32+
SET(ONECCL_LIBRARIES oneCCL)
33+
34+
find_package_handle_standard_args(oneCCL FOUND_VAR ONECCL_FOUND REQUIRED_VARS ONECCL_LIBRARIES ONECCL_INCLUDE_DIRS)
35+
36+
ENDIF(NOT ONECCL_FOUND)

cmake/cpu/BuildFlags.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ if(env_cxx_standard GREATER -1)
1313
endif()
1414
set(CMAKE_CXX_STANDARD 17)
1515
set(CMAKE_C_STANDARD 11)
16-
set(CMAKE_CXX_EXTENSIONS OFF)
16+
# The oneCCL build only supports the GNU C++ standard (gnu++), so extensions stay ON.
17+
set(CMAKE_CXX_EXTENSIONS ON)
1718

1819
if(MSVC)
1920
set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)

cmake/cpu/Options.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ if(WIN32)
1313
set(USE_LIBXSMM ON)
1414
endif()
1515

16+
option(USE_CCL "Enable oneCCL in IPEX" ON)
17+
option(USE_SHM "Enable shared memory communication in IPEX" ON)
18+
if(WIN32)
19+
set(USE_SHM OFF)
20+
endif()
21+
# TODO: USE_SHM is intended to be forced OFF when USE_CCL is OFF, but that logic is not implemented here yet.
22+
23+
1624
cmake_dependent_option(BUILD_STATIC_ONEMKL "Static link with oneMKL" OFF "BUILD_WITH_XPU" ON)
1725

1826
function (print_cpu_config_summary)
@@ -49,6 +57,8 @@ function (print_cpu_config_summary)
4957
message(STATUS " IPEX_DISP_OP : ${IPEX_DISP_OP}")
5058
message(STATUS " BUILD_XSMM_VIA_CMAKE : ${BUILD_LIBXSMM_VIA_CMAKE}")
5159
message(STATUS " USE_LIBXSMM : ${USE_LIBXSMM}")
60+
message(STATUS " USE_CCL : ${USE_CCL}")
61+
message(STATUS " USE_SHM : ${USE_SHM}")
5262
message(STATUS "")
5363
message(STATUS "********************************")
5464
endfunction()

csrc/cpu/CMakeLists.txt

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ set(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
1111
set(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
1212

1313
#find_package(TorchCCL REQUIRED)
14+
# Find OneCCL Lib
15+
set(DEPENDS_LIB)
16+
if(USE_CCL)
17+
include(${IPEX_ROOT_DIR}/cmake/Modules/FindoneCCL.cmake)
18+
# Find OneCCL Lib
19+
link_directories(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/oneCCL/deps/mpi/lib)
20+
find_package(oneCCL REQUIRED)
21+
list(APPEND DEPENDS_LIB oneCCL)
22+
list(APPEND DEPENDS_LIB mpi)
23+
endif()
1424

1525
# TODO: Once llga is merged into oneDNN, use oneDNN directly as the third_party of IPEX
1626
# use the oneDNN in llga temporarily: third_party/llga/third_party/oneDNN
@@ -34,6 +44,14 @@ if(USE_LIBXSMM)
3444
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_LIBXSMM")
3545
endif(USE_LIBXSMM)
3646

47+
if(USE_CCL)
48+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_CCL")
49+
endif(USE_CCL)
50+
51+
if(USE_SHM)
52+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SHM")
53+
endif(USE_SHM)
54+
3755
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_IPEX_MAIN_LIB")
3856

3957
# ---[ Main build
@@ -73,6 +91,7 @@ add_subdirectory(${IPEX_CPU_ROOT_DIR}/isa)
7391
add_subdirectory(${IPEX_CPU_ROOT_DIR}/toolkit)
7492
add_subdirectory(${IPEX_CPU_ROOT_DIR}/runtime)
7593
add_subdirectory(${IPEX_CPU_ROOT_DIR}/utils)
94+
add_subdirectory(${IPEX_CPU_ROOT_DIR}/comm)
7695

7796
add_subdirectory(${IPEX_CPU_ROOT_DIR}/jit)
7897

@@ -84,7 +103,8 @@ if(USE_LIBXSMM)
84103
endif(USE_LIBXSMM)
85104

86105
set(IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_DYNDISP_SRCS} ${IPEX_CPU_CPP_ISA_SRCS_GEN} ${IPEX_CPU_CPP_UTILS_SRCS} ${IPEX_CPU_CPP_QUANTIZATION_SRCS} ${IPEX_CPU_CPP_JIT_SRCS} ${IPEX_JIT_COMMON_CPP_SRCS}
87-
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} ${IPEX_CPU_CPP_TPP_SRCS})
106+
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS}
107+
${IPEX_CPU_CPP_TPP_SRCS} ${IPEX_CPU_CPP_COMM_SRCS})
88108

89109
list(REMOVE_ITEM IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_ISA_SRCS_ORIGIN})
90110

@@ -123,6 +143,7 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_GENERATED_INCLUDE}
123143

124144
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/include)
125145
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${PYTHON_INCLUDE_DIR})
146+
target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC ${DEPENDS_LIB})
126147

127148
include(${IPEX_ROOT_DIR}/cmake/ClangFormat.cmake)
128149
if(CLANG_FORMAT)
@@ -221,6 +242,7 @@ if(BUILD_STRIPPED_BIN)
221242
set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES LINK_FLAGS_RELEASE -s)
222243
endif()
223244

245+
224246
install(TARGETS ${PLUGIN_NAME_CPU}
225247
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
226248
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "CollectiveCommunicationPrimitive.h"
2+
#include <ATen/FunctionalTensorWrapper.h>
3+
#include <torch/all.h>
4+
#include <torch/csrc/autograd/function.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
IPEX_DEFINE_DISPATCH(all_reduce_add_kernel_stub);
10+
IPEX_DEFINE_DISPATCH(allgather_kernel_stub);
11+
12+
at::Tensor all_reduce_add(at::Tensor t_in) {
13+
RECORD_FUNCTION("ipex::all_reduce_add", c10::ArrayRef<c10::IValue>({}));
14+
return all_reduce_add_kernel_stub(kCPU, t_in);
15+
}
16+
17+
at::Tensor allgather(
18+
at::Tensor t_in,
19+
std::vector<int64_t> cols_per_rank,
20+
int64_t world_size) {
21+
RECORD_FUNCTION("ipex::allgather", c10::ArrayRef<c10::IValue>({}));
22+
return allgather_kernel_stub(kCPU, t_in, cols_per_rank, world_size);
23+
}
24+
25+
} // namespace cpu
26+
} // namespace torch_ipex
27+
28+
namespace {
29+
30+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
31+
m.def("all_reduce_add(Tensor(a!) t_in)-> (Tensor)");
32+
m.impl(
33+
"all_reduce_add", c10::DispatchKey::CPU, torch_ipex::cpu::all_reduce_add);
34+
m.def("allgather(Tensor input, int[] output, int world_size) -> (Tensor)");
35+
m.impl("allgather", c10::DispatchKey::CPU, torch_ipex::cpu::allgather);
36+
}
37+
} // namespace
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <dyndisp/DispatchStub.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
namespace {
10+
11+
at::Tensor all_reduce_add(at::Tensor& t_in);
12+
at::Tensor allgather(
13+
at::Tensor t_in,
14+
std::vector<int64_t> cols_per_rank,
15+
int64_t world_size);
16+
int64_t get_world_size(const at::Tensor dummy_input);
17+
int64_t get_rank(const at::Tensor dummy_input);
18+
} // namespace
19+
20+
using all_reduce_add_fn = at::Tensor (*)(at::Tensor& t_in);
21+
using allgather_fn = at::Tensor (*)(
22+
at::Tensor t_in,
23+
std::vector<int64_t> cols_per_rank,
24+
int64_t world_size);
25+
26+
IPEX_DECLARE_DISPATCH(all_reduce_add_fn, all_reduce_add_kernel_stub);
27+
IPEX_DECLARE_DISPATCH(allgather_fn, allgather_kernel_stub);
28+
29+
} // namespace cpu
30+
} // namespace torch_ipex

csrc/cpu/aten/ShmAllReduceAdd.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
#include "ShmAllReduceAdd.h"
3+
#include <ATen/FunctionalTensorWrapper.h>
4+
#include <torch/all.h>
5+
#include <torch/csrc/autograd/function.h>
6+
7+
namespace torch_ipex {
8+
namespace cpu {
9+
10+
IPEX_DEFINE_DISPATCH(shm_all_reduce_add_kernel_stub);
11+
12+
at::Tensor shm_all_reduce_add_forward_cpu(
13+
at::Tensor& t_in,
14+
at::Tensor& t_address,
15+
at::Tensor& t_state,
16+
at::Tensor& t_blockState,
17+
int64_t shm_block_size,
18+
int64_t rank,
19+
int64_t world_size) {
20+
return shm_all_reduce_add_kernel_stub(
21+
kCPU,
22+
t_in,
23+
t_address,
24+
t_state,
25+
t_blockState,
26+
shm_block_size,
27+
rank,
28+
world_size);
29+
}
30+
31+
} // namespace cpu
32+
} // namespace torch_ipex

csrc/cpu/aten/ShmAllReduceAdd.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <dyndisp/DispatchStub.h>
5+
6+
namespace torch_ipex {
7+
namespace cpu {
8+
9+
namespace {
10+
11+
at::Tensor shm_all_reduce_add(
12+
at::Tensor& t_in,
13+
at::Tensor& t_address,
14+
at::Tensor& t_state,
15+
at::Tensor& t_blockState,
16+
int64_t shm_block_size,
17+
int64_t rank,
18+
int64_t world_size);
19+
}
20+
21+
using shm_all_reduce_add_kernel_fn = at::Tensor (*)(
22+
at::Tensor& t_in,
23+
at::Tensor& t_address,
24+
at::Tensor& t_state,
25+
at::Tensor& t_blockState,
26+
int64_t shm_block_size,
27+
int64_t rank,
28+
int64_t world_size);
29+
30+
IPEX_DECLARE_DISPATCH(
31+
shm_all_reduce_add_kernel_fn,
32+
shm_all_reduce_add_kernel_stub);
33+
34+
} // namespace cpu
35+
} // namespace torch_ipex
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/Tensor.h>
3+
#include <aten/CollectiveCommunicationPrimitive.h>
4+
#include <comm/messager.h>
5+
#include <torch/csrc/autograd/function.h>
6+
7+
namespace torch_ipex {
8+
namespace cpu {
9+
10+
namespace {
11+
at::Tensor all_reduce_add_kernel_impl(at::Tensor& t_in) {
12+
Messenger::getInstance().reduceAdd(t_in);
13+
return t_in;
14+
}
15+
16+
at::Tensor allgather_kernel_impl(
17+
at::Tensor t_in,
18+
std::vector<int64_t> cols_per_rank,
19+
int64_t world_size) {
20+
std::vector<at::Tensor> output_tensors;
21+
auto shape = t_in.contiguous().sizes();
22+
for (int64_t rank = 0; rank < world_size; rank++) {
23+
std::vector<int64_t> t_out_shape(shape.begin(), shape.end() - 1);
24+
t_out_shape.push_back(cols_per_rank[rank + 1] - cols_per_rank[rank]);
25+
output_tensors.push_back(at::empty(t_out_shape, t_in.options()));
26+
}
27+
28+
return Messenger::getInstance().allgather(t_in, output_tensors);
29+
}
30+
31+
} // anonymous namespace
32+
33+
IPEX_REGISTER_DISPATCH(all_reduce_add_kernel_stub, &all_reduce_add_kernel_impl);
34+
35+
IPEX_REGISTER_DISPATCH(allgather_kernel_stub, &allgather_kernel_impl);
36+
37+
} // namespace cpu
38+
} // namespace torch_ipex

0 commit comments

Comments
 (0)