[Layer] Enable pipeline parallel feature. #221

Merged
merged 19 commits on Feb 19, 2024

7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -132,6 +132,13 @@ else()
list(APPEND 3RDPART_LIB_LIST "xdnn_static")
endif()

# pipeline parallel feature
option(WITH_PIPELINE_PARALLEL "Build with pipeline parallel" OFF)
if(WITH_PIPELINE_PARALLEL)
message(STATUS "Notice: Building with pipeline parallel.")
add_definitions(-DPIPELINE_PARALLEL=true)
endif()

# Enable AVX512_FP16 optimization
# add_definitions(-DAVX512_FP32_WEIGHT_ONLY_FP16=true)
add_definitions(-DAVX512_FP16_WEIGHT_ONLY_FP16=true)
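
As a usage note: configuring the build with -DWITH_PIPELINE_PARALLEL=ON defines PIPELINE_PARALLEL for every translation unit, so sources can gate pipeline-specific code with a plain preprocessor check. A minimal standalone sketch of that gating (the printed strings are purely illustrative, not output from this repository):

#include <cstdio>

int main() {
#ifdef PIPELINE_PARALLEL
    // Compiled only when CMake was configured with -DWITH_PIPELINE_PARALLEL=ON.
    printf("pipeline parallel build\n");
#else
    // Default build: tensor parallelism only, no pipeline stages.
    printf("tensor parallel only build\n");
#endif
    return 0;
}
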
52 changes: 42 additions & 10 deletions src/comm_helper/comm_helper.cpp
@@ -18,29 +18,45 @@

static ccl::communicator *pcomm;

extern "C" int init(int *rank, int *size) {
// world_color is initialized to the pipeline parallel stage count (ppSize)
// and is re-assigned below to the MPI world_color, which equals ppRank
extern "C" int init(int *world_size, int *world_rank, int *world_color) {
ccl::init();

MPI_Init(NULL, NULL);
MPI_Comm_size(MPI_COMM_WORLD, size);
MPI_Comm_rank(MPI_COMM_WORLD, rank);
MPI_Comm_size(MPI_COMM_WORLD, world_size);
MPI_Comm_rank(MPI_COMM_WORLD, world_rank);

// world_color = world_rank / tpSize = world_rank / (world_size / ppSize)
// e.g.: world_color = 0~7 / (8 / 4), with XFT_PIPELINE_STAGES = ppSize = 4 and tpSize = 2:
// world_rank = 0, 1, -> world_color = ppRank = 0, 0, -> tpRank = 0, 1;
//              2, 3,                           1, 1,             0, 1;
//              4, 5,                           2, 2,             0, 1;
//              6, 7;                           3, 3;             0, 1;
*world_color = *world_rank / (*world_size / *world_color);
MPI_Comm row_comm;
MPI_Comm_split(MPI_COMM_WORLD, *world_color, *world_rank, &row_comm);

int row_size, row_rank;
MPI_Comm_size(row_comm, &row_size);
MPI_Comm_rank(row_comm, &row_rank);

ccl::shared_ptr_class<ccl::kvs> kvs;
ccl::kvs::address_type mainAddr;

if (*rank == 0) {
if (row_rank == 0) {
kvs = ccl::create_main_kvs();
mainAddr = kvs->get_address();
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
} else {
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, MPI_COMM_WORLD);
MPI_Bcast((void *)mainAddr.data(), mainAddr.size(), MPI_BYTE, 0, row_comm);
kvs = ccl::create_kvs(mainAddr);
}

pcomm = new ccl::communicator(ccl::create_communicator(*size, *rank, kvs));
pcomm = new ccl::communicator(ccl::create_communicator(row_size, row_rank, kvs));

*rank = pcomm->rank();
*size = pcomm->size();
*world_size = pcomm->size();
*world_rank = pcomm->rank();

#ifdef USE_SHM
char myHostname[MPI_MAX_PROCESSOR_NAME];
@@ -53,7 +69,7 @@ extern "C" int init(int *rank, int *size) {
MPI_COMM_WORLD);

int sameHostnames = 1;
for (int i = 1; i < *size; i++) {
for (int i = 1; i < *world_size; i++) {
if (strcmp(myHostname, &all_hostnames[i * MPI_MAX_PROCESSOR_NAME]) != 0) {
sameHostnames = 0;
break;
@@ -89,4 +105,20 @@ extern "C" void broadcast(int *buf, size_t count) {
extern "C" void allgatherv(
const float *sendBuf, size_t count, float *recvBuf, const std::vector<long unsigned int> &recvCounts) {
ccl::allgatherv(sendBuf, count, recvBuf, recvCounts, *pcomm).wait();
}

extern "C" void worldSendFP32(const float *buf, int count, int dest, int tag) {
MPI_Send((const void *)buf, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);
}

extern "C" void worldRecvFP32(float *buf, int count, int source, int tag) {
MPI_Recv((void *)buf, count, MPI_FLOAT, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}

extern "C" void worldSendINT32(const int32_t *buf, int count, int dest, int tag) {
MPI_Send((const void *)buf, count, MPI_INT32_T, dest, tag, MPI_COMM_WORLD);
}

extern "C" void worldRecvINT32(int32_t *buf, int count, int source, int tag) {
MPI_Recv((void *)buf, count, MPI_INT32_T, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
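
To make the grouping above concrete, here is a small standalone sketch (not part of the patch) that reproduces the mapping from the comment for world_size = 8 and ppSize = 4, and shows how a caller might address the rank holding the same tensor-parallel slice on the next pipeline stage before calling worldSendFP32. It relies only on the relation world_rank = ppRank * tpSize + tpRank implied by the split; the peer computation is an assumption about how stages could exchange activations, not code from this PR.

#include <cstdio>

int main() {
    const int worldSize = 8, ppSize = 4;   // e.g. XFT_PIPELINE_STAGES = 4
    const int tpSize = worldSize / ppSize; // 2 ranks per pipeline stage

    for (int worldRank = 0; worldRank < worldSize; ++worldRank) {
        int ppRank = worldRank / tpSize; // world_color passed to MPI_Comm_split
        int tpRank = worldRank % tpSize; // rank inside the split (row) communicator
        // Rank holding the same TP slice on the next stage, if one exists.
        int peer = (ppRank + 1 < ppSize) ? (ppRank + 1) * tpSize + tpRank : -1;
        printf("world_rank=%d -> ppRank=%d tpRank=%d next-stage peer=%d\n",
                worldRank, ppRank, tpRank, peer);
        // A sender on stage ppRank could then call, for example:
        //   worldSendFP32(buf, count, peer, /*tag=*/0);
        // since worldSendFP32/worldRecvFP32 address ranks in MPI_COMM_WORLD.
    }
    return 0;
}
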
12 changes: 11 additions & 1 deletion src/common/transformer_ctx.h
@@ -84,6 +84,12 @@ struct DecoderContext {
// # of splits (the same as NUMA node number in the system)
const int numSplit;

// For pipeline parallel and tensor parallel config
int ppSize = 1; // pipeline parallel stage size
int ppRank = 0; // pipeline parallel stage rank
int tpSize = 1; // tensor parallel size
int tpRank = 0; // tensor parallel rank

enum ActivationType { RELU, GELU, SWIGLU, SILU };
ActivationType actType;

@@ -105,7 +111,7 @@ struct DecoderContext {
public:
DecoderContext(int _layers, int _hiddenSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
int _splitIdx, int _splits, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr, int numThreads = 0)
: layers(_layers)
, hiddenSize(_hiddenSize)
, intermediateSize(_imSize)
@@ -119,6 +125,10 @@
, ropeParamsPtr(_ropeParamsPtr)
, splitIdx(_splitIdx)
, numSplit(_splits)
, ppSize(_ppSize)
, ppRank(_ppRank)
, tpSize(_splits)
, tpRank(_splitIdx)
, epsilon(epsilon) {
if (attHeadNum != 0) {
this->attHeadSize = hiddenSize / attHeadNum;
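
For reference, a hypothetical construction call showing where the new _ppSize/_ppRank arguments sit (after _splits and before the RopeParams pointer); every numeric value below is a placeholder for illustration, not a configuration taken from this repository:

#include "transformer_ctx.h"

void buildContextExample() {
    // Values that would normally come from the communicator setup; placeholders here.
    int tpSize = 2, tpRank = 0, ppSize = 4, ppRank = 0;

    DecoderContext *ctx = new DecoderContext(
            /*_layers*/ 32, /*_hiddenSize*/ 4096, /*_attHeadNum*/ 32, /*_kvHeadNum*/ 32,
            /*_imSize*/ 11008, /*act*/ "silu", /*epsilon*/ 1e-6f, /*_vocabSize*/ 32000,
            /*_embeddingSize*/ 4096, /*_maxPositions*/ 2048, /*_maxPosEmbed*/ 2048,
            /*_maxSeqLength*/ 2048, /*_splitIdx*/ tpRank, /*_splits*/ tpSize,
            /*_ppSize*/ ppSize, /*_ppRank*/ ppRank);
    // tpSize/tpRank inside the context are mirrored from _splits/_splitIdx by the constructor.
    delete ctx;
}
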
6 changes: 4 additions & 2 deletions src/layers/attention.h
@@ -273,6 +273,7 @@ class Attention {
imBuffer.Assign(inputBuffer.Data(), inputBuffer.Rows(), inputBuffer.Cols(), inputBuffer.Stride());
inputBuffer.Assign(tmp, rows, cols, stride);
}

// TODO: refine the logic (and support large inputSeqLen when pastSeqLen > 0)
if constexpr (std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>) {
if (pastSeqLen == 0) {
@@ -284,8 +285,9 @@
if (ctx->inputSeqLen >= 1024 && pastSeqLen == 0)
flashAttention(
ctx, qkvGroupMatMul, outBuffer, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
else
else {
fusedAttention(ctx, query, key, value, imBuffer, presentKey, presentValue, attnMask, pastSeqLen);
}
}
t4.release();

@@ -375,7 +377,7 @@
// to make sure it works better (the logic here is trying to make sure each head of BMM result [seq * seq] in cache)
// WARN: reserve field in context is used to make it effective for all layers, do not change it in other places
int &mBlockSize = ctx->reserved1;
if (layerId == 0) {
if (layerId % (ctx->layers / ctx->ppSize) == 0) {
// TODO: if pastSeqLen > 0 and inputSeqLen large.
if (pastSeqLen == 0) {
const int l2CacheSize = 2 * 1024 * 1024; // TODO: get it dynamically
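
The changed condition above deserves a short illustration. Assuming layerId is the global layer index and each pipeline stage owns a contiguous block of ctx->layers / ctx->ppSize layers, the first layer handled by every stage now recomputes the shared mBlockSize, instead of only global layer 0 doing so (a layer that other stages never execute). A standalone sketch of which layers trigger the recomputation under that assumption:

#include <cstdio>

int main() {
    const int layers = 32, ppSize = 4;          // assumed: 8 consecutive layers per stage
    const int layersPerStage = layers / ppSize;

    for (int layerId = 0; layerId < layers; ++layerId) {
        // Mirrors `layerId % (ctx->layers / ctx->ppSize) == 0` from the diff.
        if (layerId % layersPerStage == 0)
            printf("layer %2d (stage %d) recomputes mBlockSize\n",
                    layerId, layerId / layersPerStage);
    }
    return 0;
}
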
7 changes: 7 additions & 0 deletions src/models/CMakeLists.txt
@@ -18,3 +18,10 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} MODEL_SRCS)

add_library(models OBJECT ${MODEL_SRCS})
add_dependencies(models utils)

if(WITH_PIPELINE_PARALLEL)
find_package(MPI REQUIRED)
include_directories(${MPI_INCLUDE_PATH})
add_definitions(${MPI_CXX_COMPILE_FLAGS})
target_link_libraries(models ${MPI_CXX_LIBRARIES})
endif()
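
One consequence of this block: because find_package(MPI REQUIRED) is used, enabling WITH_PIPELINE_PARALLEL makes an MPI installation a hard configure-time dependency for the models target, while the default OFF configuration leaves the build unchanged.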