
[Layer] Enable pipeline parallel feature. #221


Merged
changqi1 merged 19 commits into intel:main on Feb 19, 2024

Conversation

Contributor

@changqi1 changqi1 commented Feb 7, 2024

Usage:

  1. Build with cmake .. -DWITH_PIPELINE_PARALLEL=ON to add MPI support.
  2. Set the XFT_PIPELINE_STAGE macro to define the number of pipeline parallel stages.
Pipeline parallel and tensor parallel introduction:
  1) MPI_Instances = 16, XFT_PIPELINE_STAGE = 4  =>  ctx->ppSize = 4, ctx->tpSize = 4
  2) TP is synced by oneCCL (row_comm) or shared_memory
  3) PP is synced by MPI over MPI_COMM_WORLD
  World Rank:      => Row Rank:       => Rank:  tp0 tp1 tp2 tp3
  [ 0,  1,  2,  3,    [ 0, 1, 2, 3];      pp0 [  0,  1,  2,  3];
    4,  5,  6,  7,    [ 0, 1, 2, 3];      pp1 [  0,  1,  2,  3];
    8,  9, 10, 11,    [ 0, 1, 2, 3];      pp2 [  0,  1,  2,  3];
   12, 13, 14, 15];   [ 0, 1, 2, 3];      pp3 [  0,  1,  2,  3];
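
For reference, a minimal sketch (not the PR's exact code) of how a world rank could be decomposed into ppRank/tpRank and how the TP row communicator could be created; ppSize is assumed to come from XFT_PIPELINE_STAGE:

#include <mpi.h>

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);

    int worldRank, worldSize;
    MPI_Comm_rank(MPI_COMM_WORLD, &worldRank);
    MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

    int ppSize = 4;                   // e.g. from XFT_PIPELINE_STAGE
    int tpSize = worldSize / ppSize;  // 16 / 4 = 4 in the example above
    int ppRank = worldRank / tpSize;  // which pipeline stage this rank belongs to
    int tpRank = worldRank % tpSize;  // row rank inside the stage

    // Ranks with the same color (= ppRank) land in the same row communicator;
    // TP AllReduce runs inside row_comm, PP send/recv stays on MPI_COMM_WORLD.
    MPI_Comm row_comm;
    MPI_Comm_split(MPI_COMM_WORLD, /*color=*/ppRank, /*key=*/tpRank, &row_comm);

    MPI_Comm_free(&row_comm);
    MPI_Finalize();
    return 0;
}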

                                      Prompts
                                         │
            ┌──────────────────┬─────────┴────────┬──────────────────┐
            │                  │                  │                  │
            ▼                  ▼                  ▼                  ▼
       Embedding(PP0)     Embedding(PP0)     Embedding(PP0)     Embedding(PP0)
            │                  │                  │                  │
  PP0       │                  │                  │                  │
  ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐
  │ TP0     │          TP1     │          TP2     │          TP3     │    layer0-7  │
  │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐     │
  │ │ OMP            │ │ OMP            │ │ OMP            │ │ OMP            │     │
  │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │     │
  │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│     │
  │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘     │
  │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐     │
  │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘     │
  └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘
  PP1       │ MPI Send/Recv    │                  │                  │
  ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐
  │ TP0     │          TP1     │           TP2    │            TP3   │   layer8-15  │
  │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐     │
  │ │ OMP            │ │ OMP            │ │ OMP            │ │ OMP            │     │
  │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │     │
  │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│     │
  │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘     │
  │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐     │
  │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘     │
  └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘
  PP2       │ MPI Send/Recv    │                  │                  │
  ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐
  │ TP0     │          TP1     │           TP2    │            TP3   │  layer16-23  │
  │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐     │
  │ │ OMP            │ │ OMP            │ │ OMP            │ │ OMP            │     │
  │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │     │
  │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│     │
  │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘     │
  │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐     │
  │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘     │
  └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘
  PP3       │ MPI Send/Recv    │                  │                  │
  ┌─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┐
  │ TP0     │          TP1     │           TP2    │            TP3   │  layer24-31  │
  │ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐ ┌───────▼────────┐     │
  │ │ OMP            │ │ OMP            │ │ OMP            │ │ OMP            │     │
  │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │ │ │ │ │ │ │ │    │     │
  │ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│ │ ▼ ▼ ▼ ▼ ▼ ▼ ...│     │
  │ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ └───────┬────────┘     │
  │ ┌───────┼──────────────────┼─────AllReduce────┼──────────────────┼────────┐     │
  │ └───────┼──────────────────┼──────────────────┼──────────────────┼────────┘     │
  └─────────┼──────────────────┼──────────────────┼──────────────────┼──────────────┘
            │                  │                  │                  │
            ▼                  ▼                  ▼                  ▼
       Predictor(PP3)     Predictor(PP3)     Predictor(PP3)     Predictor(PP3)
            │ MPI Send/Recv    │                  │                  │
            ▼                  ▼                  ▼                  ▼
       Searchers(PP0)     Searchers(PP0)     Searchers(PP0)     Searchers(PP0)
            │
            ▼
         Output
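
To make the data flow above concrete, here is a minimal sketch of one decoding step from the point of view of a single rank; stageStep, hidden, and nextTokens are placeholder names and the per-stage layer forward is elided, but the send/recv partners and tags follow the same convention as the snippets discussed below (the receive tag is the receiver's world rank):

#include <mpi.h>
#include <cstdint>
#include <vector>

// Sketch only: one pipeline step seen by a single rank.
void stageStep(std::vector<float> &hidden, std::vector<int32_t> &nextTokens,
               int ppRank, int ppSize, int tpRank, int tpSize, int worldRank) {
    int prevRank = (ppRank - 1) * tpSize + tpRank; // same TP column, previous stage
    int nextRank = (ppRank + 1) * tpSize + tpRank; // same TP column, next stage

    if (ppRank > 0) // receive activations produced by the previous stage
        MPI_Recv(hidden.data(), static_cast<int>(hidden.size()), MPI_FLOAT, prevRank,
                 /*tag=*/worldRank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

    // ... run this stage's layers here; TP ranks AllReduce inside their row communicator ...

    if (ppRank < ppSize - 1) { // hand the activations to the next stage
        MPI_Send(hidden.data(), static_cast<int>(hidden.size()), MPI_FLOAT, nextRank,
                 /*tag=*/nextRank, MPI_COMM_WORLD);
    } else { // last stage: the predictor produced nextTokens, return them to stage 0
        int embeddingRank = 0 * tpSize + tpRank;
        MPI_Send(nextTokens.data(), static_cast<int>(nextTokens.size()), MPI_INT32_T,
                 embeddingRank, /*tag=*/worldRank, MPI_COMM_WORLD);
    }
}
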
# pp=1, tp=2
$ XFT_PIPELINE_STAGE=1 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=2, tp=1
$ XFT_PIPELINE_STAGE=2 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16
# pp=1, tp=4
$ XFT_PIPELINE_STAGE=1 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=2, tp=2
$ XFT_PIPELINE_STAGE=2 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=4, tp=1
$ XFT_PIPELINE_STAGE=4 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16
# pp=1, tp=8
$ XFT_PIPELINE_STAGE=1 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C  0-11 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 12-23 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 24-35 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 36-47 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=2, tp=4
$ XFT_PIPELINE_STAGE=2 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C  0-11 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 12-23 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 24-35 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 36-47 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=4, tp=2
$ XFT_PIPELINE_STAGE=4 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C  0-11 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 12-23 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 24-35 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 36-47 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

# pp=8, tp=1
$ XFT_PIPELINE_STAGE=8 OMP_NUM_THREADS=12 mpirun  \
    -n 1 numactl --all -C 48-59 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 60-71 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 72-83 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 84-95 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C  0-11 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 12-23 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 24-35 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16 :  \
    -n 1 numactl --all -C 36-47 -m 1 ./example --model /data/qwen-1.8b-chat-xft/ --token /data/qwen-1.8b-chat-hf/tokenizer_config.json --dtype fp16 --loop 1 --input_len 16 --output_len 16

@changqi1 changqi1 marked this pull request as draft February 7, 2024 03:33
@changqi1 changqi1 marked this pull request as ready for review February 7, 2024 14:13
@changqi1 changqi1 requested a review from pujiang2018 February 7, 2024 14:15
@intelyoungway

Amazing work!

@@ -14,7 +14,12 @@
# ============================================================================
cmake_minimum_required(VERSION 3.15.1)

find_package(MPI REQUIRED)
Contributor

What if oneCCL is not present in the user's environment?

Contributor Author

No, oneCCL is already present in my environment, but the model src reports that the MPI library is missing.

Contributor

It should be decoupled now, right? The code in src/models does not depend on oneCCL or MPI.


int layers_per_pp_stage = layers / ctx->ppSize;
int start_layer = ctx->ppRank * layers_per_pp_stage;
for (int i = start_layer; i < start_layer + layers_per_pp_stage; ++i) {
Contributor
@Duyi-Wang Duyi-Wang Feb 8, 2024

When layers is not divisible by ppSize, does it mean that a few layers (layers % ppSize) will not be processed? There is a warning but no termination if layers is not divisible by ppSize.

Contributor Author

The code above has already reported an error, so this code will not be executed.

Contributor Author

This is set by the user; normally it is evenly divisible.

Contributor

std::cerr only prints a message; it does not terminate the flow. It seems the non-divisible case could also be supported? The subsequent code does not seem to use layers_per_pp_stage to limit how many layers each ppRank computes?

Contributor Author

Done
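
For completeness, one way the non-divisible case could be covered if it were supported: give the first (layers % ppSize) stages one extra layer each so nothing is dropped. This is only an illustrative sketch (assignLayers is a hypothetical helper), not the PR's final behavior:

#include <algorithm>

void assignLayers(int layers, int ppSize, int ppRank) {
    int base = layers / ppSize;
    int extra = layers % ppSize;                       // remainder layers
    int layersHere = base + (ppRank < extra ? 1 : 0);  // first 'extra' stages take one more
    int startLayer = ppRank * base + std::min(ppRank, extra);
    for (int i = startLayer; i < startLayer + layersHere; ++i) {
        // build / forward layer i on this pipeline stage
    }
}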

MPI_Recv(embBuf, batchSize * inputSeqLen * ctx->hiddenSize, MPI_FLOAT, prev_world_rank, curr_world_rank,
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}

Contributor

This will reintroduce the MPI dependency into xft.so. It should be included in comm_helper.so and referenced through the Messenger.

Contributor Author

Error: different scope when dynamically loading the .so file.


public:
static void initPipeline() {
char *xft_pipeline_value = getenv("XFT_PIPELINE_STAGES");
Contributor

Check if MPI_rank is divisible by ppStages?

Contributor Author

The check is done in common_decoder.h.
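
For reference, a minimal sketch of that kind of check, assuming the environment variable name used in the usage examples above (the snippet here reads XFT_PIPELINE_STAGES); readPipelineStages is a hypothetical helper, not the PR's exact code:

#include <cstdio>
#include <cstdlib>

// Read the pipeline stage count and verify the MPI world size is a multiple of it.
static int readPipelineStages(int worldSize) {
    int ppSize = 1;
    if (const char *env = std::getenv("XFT_PIPELINE_STAGE")) ppSize = std::atoi(env);
    if (ppSize < 1 || worldSize % ppSize != 0) {
        std::fprintf(stderr, "Invalid pipeline stage number %d for %d ranks, falling back to 1.\n",
                     ppSize, worldSize);
        ppSize = 1;
    }
    return ppSize; // tpSize = worldSize / ppSize
}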

int embedding_world_rank = 0 * ctx->tpSize + ctx->tpRank;
int predictor_world_rank = (ctx->ppSize - 1) * ctx->tpSize + ctx->tpRank;
MPI_Send(this->nextTokens.data(), batchSize, MPI_INT32_T, embedding_world_rank, predictor_world_rank,
MPI_COMM_WORLD);
Contributor

Use the Messenger and comm_helper.so to decouple the MPI dependency.
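
One possible shape for that decoupling, sketched with hypothetical method names (worldSendFP32/worldRecvFP32): only comm_helper.so would be compiled against MPI, and the model code would go through the dynamically loaded helper instead of calling MPI directly:

#include <mpi.h>
#include <cstdint>

// Hypothetical wrappers living in comm_helper.so; xft.so would call them through
// the Messenger instead of using MPI_Send/MPI_Recv itself.
class MessengerSketch {
public:
    void worldSendFP32(const float *buf, int count, int dest, int tag) {
        MPI_Send(buf, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);
    }
    void worldRecvFP32(float *buf, int count, int src, int tag) {
        MPI_Recv(buf, count, MPI_FLOAT, src, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
    void worldSendINT32(const int32_t *buf, int count, int dest, int tag) {
        MPI_Send(buf, count, MPI_INT32_T, dest, tag, MPI_COMM_WORLD);
    }
};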

@intel intel deleted a comment from Duyi-Wang Feb 8, 2024
@changqi1 changqi1 marked this pull request as draft February 8, 2024 07:29
@changqi1 changqi1 marked this pull request as ready for review February 18, 2024 05:11
@@ -176,16 +201,21 @@ class Messenger {
private:
int size;
int rank;
int color;
Contributor

I don't know if 'color' is a common concept. Is it easy for others to understand? Maybe add a comment?


@pujiang2018
Contributor

Most of the code is clear, but MPI decoupling is needed.

@changqi1
Contributor Author

changqi1 commented Feb 18, 2024

Most of the code is clear, but MPI decoupling is needed.

I have used the compile macro PIPELINE_PARALLEL to decouple MPI.
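
For illustration, such a guard might wrap the MPI receive shown earlier roughly like this (a sketch only; the exact guarded sites are in the PR diff):

#ifdef PIPELINE_PARALLEL
    if (ctx->ppRank > 0) {
        MPI_Recv(embBuf, batchSize * inputSeqLen * ctx->hiddenSize, MPI_FLOAT, prev_world_rank,
                 curr_world_rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
#endif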

@changqi1 changqi1 merged commit eea16a5 into intel:main Feb 19, 2024