Commit
Merge pull request #291 from Tencent/develop
Update README: add SuperPod multi-node results.
feifeibear authored Dec 21, 2021
2 parents bd40a88 + a5a8e6a commit bc9683f
Showing 11 changed files with 106 additions and 30 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -25,6 +25,16 @@ We also evaluated PatrickStar v0.4.3 on a single node of A100 SuperPod. It is ab

Detail benchmark results on WeChat AI data center as well as NVIDIA SuperPod are posted on this [Google Doc](https://docs.google.com/spreadsheets/d/136CWc_jA_2zC4h1r-6dzD4PrOvp6aw6uCDchEyQv6sE/edit?usp=sharing).


PatrickStar also scales to multiple machines (nodes) on SuperPod.
We succeeded in training GPT3-175B on 32 GPUs. As far as we know, this is the first work
to run GPT3 on such a small GPU cluster.
Microsoft used 10,000 V100 GPUs to pretrain GPT3.
Now you can fine-tune it, or even pretrain your own model, on 32 A100 GPUs. Amazing!

![alt perf](./doc/m_node_superpod.png "performance testing result on multiple nodes of SuperPod")


We've also trained the [CLUE-GPT2](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) model with PatrickStar; the loss and accuracy curves are shown below:

![CLUE-GPT2](./doc/clue-gpt2-loss-n-acc.png)
Binary file added doc/m_node_superpod.png
Binary file modified doc/one_node_perf_a100.png
32 changes: 25 additions & 7 deletions examples/benchmark/process_logs.py
@@ -29,6 +29,8 @@

import os
import sys
import numpy as np
from scipy.stats import t


def is_run_this_file(path, file, res_dict, file_dict):
@@ -48,6 +50,8 @@ def is_run_this_file(path, file, res_dict, file_dict):

f = open(path + "/" + file)
is_run = True

perf_list = np.array([])
if not os.path.isdir(file):
fn_list = file.split(".")[1].split("_")
for i in range(len(fn_list)):
@@ -62,17 +66,31 @@
if "Tflops" in line and "WARM" not in line:
sline = line.split()
perf = float(sline[-2])
if key not in res_dict:
res_dict[key] = perf
file_dict[key] = file
else:
if res_dict[key] < perf:
res_dict[key] = perf
file_dict[key] = file

perf_list = np.append(perf_list, perf)

is_run = False
if "RuntimeError" in line:
return False

if len(perf_list) == 0:
return False

# calculate CI of perf_list
perf_list = perf_list[1:-1]
m = perf_list.mean()
s = perf_list.std()
dof = len(perf_list) - 1
confidence = 0.95
t_crit = np.abs(t.ppf((1 - confidence) / 2, dof))
ic_perf = (
-s * t_crit / np.sqrt(len(perf_list)),
+s * t_crit / np.sqrt(len(perf_list)),
)

res_dict[key] = (*ic_perf, m)
file_dict[key] = file

return is_run


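For reference, a minimal standalone sketch of the confidence-interval computation that `process_logs.py` now applies to the per-step Tflops samples. The function name `throughput_ci` is illustrative and not part of the repository, and the sample standard deviation here uses `ddof=1` whereas the script relies on NumPy's default:

```python
import numpy as np
from scipy.stats import t


def throughput_ci(samples, confidence=0.95):
    """Mean and two-sided Student-t confidence interval of Tflops samples.

    Mirrors the logic added to process_logs.py: the first and last
    measurements are dropped before the interval is built.
    """
    perf = np.asarray(samples, dtype=float)[1:-1]  # drop first/last step
    m = perf.mean()
    s = perf.std(ddof=1)  # sample standard deviation
    dof = len(perf) - 1
    t_crit = abs(t.ppf((1 - confidence) / 2, dof))
    half_width = t_crit * s / np.sqrt(len(perf))
    return m, (m - half_width, m + half_width)


if __name__ == "__main__":
    print(throughput_ci([10.1, 11.8, 12.0, 11.9, 12.2, 11.7, 9.5]))
```
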
10 changes: 10 additions & 0 deletions examples/model_builder.py
@@ -176,6 +176,16 @@ def model_config(model_name):
SEQ_LEN = 1024
NUM_LAYER = 96
NUM_HEAD = 96
elif model_name == "GPT_220B":
HIDDEN_DIM = 12288
SEQ_LEN = 1024
NUM_LAYER = 120
NUM_HEAD = 96
elif model_name == "GPT_250B":
HIDDEN_DIM = 12288
SEQ_LEN = 1024
NUM_LAYER = 137
NUM_HEAD = 96
elif model_name == "GPT_310B":
HIDDEN_DIM = 16384
SEQ_LEN = 1024
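As a sanity check on the new config names, a rough back-of-the-envelope parameter count using the common `12 * num_layers * hidden_dim**2` approximation for decoder-only transformers (embeddings and biases ignored); this is an illustrative calculation, not code from the repository:

```python
def approx_gpt_params(hidden_dim: int, num_layers: int) -> float:
    """Very rough decoder-only transformer size: 12 * L * H^2 (no embeddings)."""
    return 12 * num_layers * hidden_dim ** 2


for name, hidden, layers in [("GPT_220B", 12288, 120), ("GPT_250B", 12288, 137)]:
    print(f"{name}: ~{approx_gpt_params(hidden, layers) / 1e9:.0f}B parameters")
# GPT_220B: ~217B parameters
# GPT_250B: ~248B parameters
```
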
2 changes: 1 addition & 1 deletion examples/pretrain_bert_demo.py
@@ -237,7 +237,7 @@ def test_transformer_model_helper(
is_ckp=use_ckp,
is_fp16=use_fp16,
dist_plan=dist_plan,
num_steps=5,
num_steps=20,
)
print("*" * 20 + " LOSS " + "*" * 20)
print(f"{loss_list}")
23 changes: 20 additions & 3 deletions examples/run_transformers.sh
@@ -28,16 +28,25 @@ export MEM_PROF=${MEM_PROF:-0}
# asyn memory monitor for mem sampler
export AMM=${AMM:-1}
# mem saving comm
export MSC=${MSC:-0}
export MSC=${MSC:-1}
# mem caching comm
export CACHE=${CACHE:-1}
# async move
export ASYNC_MOVE=${ASYNC_MOVE:-0}
# linear tiling comm
export TILING=${TILING:-0}
# hybrid adam
export HYB=${HYB:-1}

export LOCAL_WORLD_SIZE=${LOCAL_WORLD_SIZE:-1}
export CS_SEARCH=${CS_SEARCH:-0}

export NNODES=${NNODES:-1}
export NODE_RANK=${NODE_RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
export MASTER_PORT=${MASTER_PORT:-"12345"}
export SUFFIX=${SUFFIX:-""}

if [[ ${TILING} == 1 ]]; then
TILING_FLAG="--with_tiling_linear"
else
@@ -109,13 +118,20 @@
fi

let CHUNK_SIZE=${CS}*1024*1024
export HYBRID_ADAM_FLAG="--use_hybrid_adam"

if [[ ${HYB} == 1 ]]; then
export HYBRID_ADAM_FLAG="--use_hybrid_adam"
else
export HYBRID_ADAM_FLAG=""
fi



LOG_DIR="./logs_${MODEL_NAME}"
mkdir -p ${LOG_DIR}

GIT_VER=`git rev-parse --short=5 HEAD`
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}"
LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_hyb_${HYB}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}_node_${NNODES}_${SUFFIX}"

is_run_flag=`python ./benchmark/is_run_this_file.py --path "${LOG_DIR}" --file "${LOG_FILE}"`
echo is_run_flag $is_run_flag
@@ -183,6 +199,7 @@ python -m torch.distributed.launch --nproc_per_node=1 \
done
else
env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \
--nnodes=${NNODES} --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
pretrain_bert_demo.py \
--default_chunk_size=${CHUNK_SIZE} \
${cmd_opts} \
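The new `NNODES`, `NODE_RANK`, `MASTER_ADDR`, and `MASTER_PORT` variables are passed straight through to `torch.distributed.launch`. A small sketch of how the launcher derives the world size and the global ranks handled by one node from these values; this is illustrative arithmetic only, the actual derivation happens inside PyTorch:

```python
def global_ranks(nnodes: int, nproc_per_node: int, node_rank: int):
    """World size and global rank range for one node under torch.distributed.launch."""
    world_size = nnodes * nproc_per_node
    first = node_rank * nproc_per_node
    return world_size, list(range(first, first + nproc_per_node))


# e.g. 4 SuperPod nodes x 8 A100 GPUs, launched on the node with NODE_RANK=1
world_size, ranks = global_ranks(nnodes=4, nproc_per_node=8, node_rank=1)
print(world_size, ranks)  # 32 [8, 9, 10, 11, 12, 13, 14, 15]
```
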
18 changes: 11 additions & 7 deletions patrickstar/core/chunk_list.py
@@ -34,7 +34,8 @@
from patrickstar.core.const import ChunkType
from patrickstar.core.memtracer import RuntimeMemTracer
from patrickstar.profiler import profiler
from patrickstar.utils import logger, get_rank, get_world_size
from patrickstar.utils import logger, get_rank, get_world_size, log_dist
import logging
import patrickstar.utils.global_timer as global_timer
from .chunk_data import Chunk
from .comm import CommInfo
@@ -216,23 +217,26 @@ def prepare_device(self, target_device: torch.device, need_bytes: int):
target_device.type
)

logger.debug(
log_dist(
f"prepare_target: device {target_device} need_bytes {need_bytes / 1e6} MB, "
f"ava_chunk_mem_size {ava_chunk_mem_size / 1e6} MB, "
f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB."
f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB.",
level=logging.DEBUG,
)

# TODO(jiaruifang) Situation where there is no space.
# This condition is not good enough, we need to check if both CPU and GPU
# don't have enough space.
if ava_chunk_mem_size < need_bytes:
logger.warning(
f"{target_device} has not enough space for {need_bytes} elements"
log_dist(
f"{target_device} has not enough space for {need_bytes} elements",
level=logging.WARNING,
)
logger.warning(
log_dist(
f"{target_device} has not enough space for {need_bytes / 1e6} MB. "
f"Device used Chunk Memory is {self.get_chunk_memory_used(target_device) / 1e6} MB. "
f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB"
f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB",
level=logging.WARNING,
)
if self._time_profile:
global_timer.my_timer.finish_profile("CHUNK_LIST_prepare_device")
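The `logger.debug` / `logger.warning` calls above are replaced with `log_dist`, which, judging from the calls shown here and in `eviction_policy.py`, accepts the message, an optional rank list, and a logging level. A minimal sketch of what such a rank-aware helper could look like; the actual `patrickstar.utils.log_dist` implementation may differ:

```python
import logging

import torch.distributed as dist

logger = logging.getLogger("patrickstar")


def log_dist(message, ranks=None, level=logging.INFO):
    """Log `message` only on the listed ranks (default: rank 0)."""
    ranks = [0] if ranks is None else ranks
    my_rank = dist.get_rank() if dist.is_initialized() else 0
    if my_rank in ranks:
        logger.log(level, f"[rank {my_rank}] {message}")
```
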
13 changes: 10 additions & 3 deletions patrickstar/core/client.py
@@ -74,7 +74,9 @@ def __init__(self, rank: int, default_chunk_size: int, config=None):
tracer_config = default_tracer_config
opt_config = default_opt_config

self.mem_tracer = RuntimeMemTracer(self.local_rank, tracer_config)
self.mem_tracer = RuntimeMemTracer(
self.local_rank, tracer_config, opt_config["with_mem_saving_comm"]
)
self.opt_config = opt_config

self.chunk_eviction_strategy = LatestAccessChunkEvictionPolicy(
Expand Down Expand Up @@ -396,6 +398,7 @@ def _fetch_remote_chunks(
# If the gpu owns the chunk (local rank), access it.
# If the gpu does not own the chunk (remote chunk), allocate memory.
if src_rank == rank:
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
self.chunk_list.access_chunk(chunk_id, compute_device)
else:
self.chunk_list.try_best_allocate_payload(
Expand Down Expand Up @@ -447,6 +450,7 @@ def _fetch_remote_chunks(

# Use collective communication to achieve the most efficient communication.
# However, it is memory consuming: world_size chunks are on the GPU simultaneously.
self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)
self.chunk_list.access_chunk(local_chunk_id, compute_device)
self.chunk_list[local_chunk_id].pin()
allgather_payload_buff = []
Expand Down Expand Up @@ -493,6 +497,7 @@ def _fetch_remote_chunks(
global_timer.my_timer.finish_profile("CLIENT_fetch_remote_chunks")

def _access_tensor_in_chunk(self, param, access_type, compute_device, chunk_id):
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
self.chunk_list.access_chunk(chunk_id, compute_device)
# 2. Locate the param on the chunk.
tensor_id = param.ps_attr.get_tensor_id(access_type)
Expand Down Expand Up @@ -584,7 +589,7 @@ def access_dist(
local_chunk_id = chunk_id

# collect the time a chunk has to be placed on compute-device
self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)
# self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device)

ret = self._access_tensor_in_chunk(param, access_type, compute_device, chunk_id)
if self._time_profile:
Expand Down Expand Up @@ -640,7 +645,7 @@ def access(
chunk_id = self.chunk_tensor_index.get_chunk_id(param, access_type)

# collect the time a chunk has to be placed on compute-device
self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)
# self.chunk_eviction_strategy.trace_access(chunk_id, compute_device)

if chunk_id is None:
raise RuntimeError(
Expand Down Expand Up @@ -763,6 +768,7 @@ def release_dist(
break
if do_allreduce:
# move the chunk_id to GPU
self.chunk_eviction_strategy.trace_access(chunk_id, self.device)
self.chunk_list.access_chunk(chunk_id, self.device)
if self._time_profile:
global_timer.my_timer.start_profile(
Expand Down Expand Up @@ -818,6 +824,7 @@ def release_dist(
assert self.chunk_list[local_chunk_id].payload is not None
input_list = []
for i in chunk_id_list:
self.chunk_eviction_strategy.trace_access(i, self.device)
self.chunk_list.access_chunk(i, self.device)
self.chunk_list[i].pin()
input_list.append(self.chunk_list[i].payload)
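This commit moves the `trace_access` bookkeeping out of `access` / `access_dist` and places it immediately before every `chunk_list.access_chunk` call, so the eviction policy records exactly the chunks that are actually materialized. A hypothetical helper illustrating that pairing; it is not part of the repository:

```python
def traced_access_chunk(client, chunk_id, device):
    """Record the access for the eviction policy, then materialize the chunk.

    Mirrors the pattern used in _fetch_remote_chunks, _access_tensor_in_chunk
    and release_dist after this commit.
    """
    client.chunk_eviction_strategy.trace_access(chunk_id, device)
    client.chunk_list.access_chunk(chunk_id, device)
```
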
11 changes: 8 additions & 3 deletions patrickstar/core/eviction_policy.py
@@ -32,7 +32,8 @@
from queue import PriorityQueue
from patrickstar.core.memtracer import Metronome
from patrickstar.core.const import ChunkState
from patrickstar.utils import logger
from patrickstar.utils import log_dist
import logging


class ChunkEvictionPolicyBase(ABC):
@@ -112,6 +113,8 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device):
chunk.get_device() is not None
and chunk.get_device().type == target_device.type
and chunk.get_state() != ChunkState.COMPUTE
and chunk.get_state() != ChunkState.RELEASED
and chunk.get_state() != ChunkState.FREE
and not chunk.is_pin()
):
# The next moment when this chunk was accessed.
@@ -133,10 +136,12 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device):

# Raise error when failed to make enough room.
if moved_bytes < need_bytes:
logger.warning(
log_dist(
f"device {target_device} still needs {need_bytes / 1e6} MB, "
f"but there is not enough space on it, only {moved_bytes / 1e6} MB available. "
f"movable_chunk_info {movable_chunk_info}"
f"movable_chunk_info {movable_chunk_info}",
[0],
logging.WARNING,
)
return moved_list

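For context, a condensed sketch of the selection loop that `derive_eviction_list` implements after this change: chunks that are not in COMPUTE, RELEASED, or FREE state and not pinned are ordered by their next access moment, and the ones accessed furthest in the future are evicted first until enough bytes are freed. The names and attributes below are simplified stand-ins, not the repository's classes:

```python
from queue import PriorityQueue


def pick_eviction_candidates(chunks, need_bytes):
    """chunks: iterable of objects with .next_access, .movable, .size, .chunk_id."""
    q = PriorityQueue()
    for c in chunks:
        if c.movable:  # stand-in for: not COMPUTE/RELEASED/FREE and not pinned
            # Negate so the chunk accessed furthest in the future pops first.
            q.put((-c.next_access, c.chunk_id, c.size))

    moved, moved_bytes = [], 0
    while moved_bytes < need_bytes and not q.empty():
        _, chunk_id, size = q.get()
        moved.append(chunk_id)
        moved_bytes += size
    return moved, moved_bytes  # caller warns if moved_bytes < need_bytes
```
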
17 changes: 11 additions & 6 deletions patrickstar/core/memtracer/memtracer.py
@@ -37,9 +37,9 @@
log_dist,
get_memory_info,
get_sys_memory_used,
get_world_size,
get_local_world_size,
logger,
get_world_size,
)
from patrickstar.core.memtracer.metronome import Metronome
from concurrent.futures import ThreadPoolExecutor
@@ -95,7 +95,9 @@ class RuntimeMemTracer(object):
Chunkable Memory: memory that can be used to store chunks.
"""

def __init__(self, local_rank: int = 0, config=None):
def __init__(
self, local_rank: int = 0, config=None, with_mem_saving_comm: bool = False
):
self.local_rank = local_rank
self.metronome = Metronome()
self.gpu_chunk_available_mem = 0
@@ -104,7 +106,7 @@ def __init__(self, local_rank: int = 0, config=None):
self.gpu_chunk_used_mem = 0
self.cpu_chunk_used_mem = 0
self.cpu_chunk_used_mem_pinned = 0

self.with_mem_saving_comm = with_mem_saving_comm
if config is not None:
self._overall_gpu_mem_ratio = config.get("overall_gpu_mem_ratio", 0.8)
self._overall_cpu_mem_ratio = config.get("overall_cpu_mem_ratio", 0.8)
@@ -395,7 +397,10 @@ def available_chunk_mem(self, device_type):
else:
return self._overall_cpu_mem
elif device_type == "cuda":
world_size = get_world_size()
if self.with_mem_saving_comm:
msc_factor = 1
else:
msc_factor = get_world_size()
if self.metronome.training_stage() == TrainingStage.ADAM:
return self._overall_gpu_mem - 4 * self._default_chunk_size * 4
elif self.metronome.training_stage() == TrainingStage.FWD:
@@ -409,7 +414,7 @@
)
return (
min(next_mom_ava_mem, cur_mom_ava_mem)
- world_size * 2 * self._default_chunk_size
- msc_factor * 2 * self._default_chunk_size
)
elif self.metronome.training_stage() == TrainingStage.BWD:
next_mom = self.metronome.next_moment()
@@ -422,5 +427,5 @@
)
return (
min(next_mom_ava_mem, cur_mom_ava_mem)
- world_size * 2 * self._default_chunk_size
- msc_factor * 2 * self._default_chunk_size
)
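The effect of the new `with_mem_saving_comm` flag on the GPU chunk-memory budget, in simplified form: without memory-saving communication, `world_size` chunk buffers coexist on the GPU during FWD/BWD, so the reserve grows with the number of ranks; with it, only a constant reserve is kept. A hedged sketch of that calculation, simplified from `available_chunk_mem` and expressed in the same units as `default_chunk_size`:

```python
def gpu_chunk_reserve(default_chunk_size: int, world_size: int,
                      with_mem_saving_comm: bool) -> int:
    """Reserve subtracted from the chunkable GPU memory during FWD/BWD.

    Mirrors the msc_factor logic: mem-saving comm keeps the reserve constant,
    otherwise it scales with the number of ranks.
    """
    msc_factor = 1 if with_mem_saving_comm else world_size
    return msc_factor * 2 * default_chunk_size


# e.g. a 64M-element chunk size on 32 ranks
print(gpu_chunk_reserve(64 * 1024 * 1024, 32, False))  # 4294967296
print(gpu_chunk_reserve(64 * 1024 * 1024, 32, True))   # 134217728
```
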
