Cherry pick from incubate/fleety (PaddlePaddle#67719)
* add logger (PaddlePaddle#67447)

* sort parameters (PaddlePaddle#67622)

* add get_padding tensor method (PaddlePaddle#67565)

* add get_padding tensor method

* polish

* fix style

* add sharding v2 check

* make sync_rotate_logger lazy

---------

Co-authored-by: ShenLiang <1422485404@qq.com>
sneaxiy and ForFishes authored Aug 28, 2024
1 parent 54fee11 commit 2f4212d
Showing 5 changed files with 110 additions and 3 deletions.
@@ -36,6 +36,10 @@
fused_parameters,
)

g_sharding_v2_check_zero_padding = int(
os.getenv("FLAGS_sharding_v2_check_zero_padding", "0")
)


def _is_trainable(param):
return not param.stop_gradient
@@ -720,8 +724,15 @@ def _build_comm_buffers(self, acc_steps, group_size=256 * 1024 * 1024):
and os.getenv("XPU_PADDLE_FUSE_SHARDING_BUFFER") is not None
):
group_size = 2**62

# NOTE(shenliang03): If comm_overlap is not used, the parameter list is sorted by data type
# to reduce communication overhead.
all_params = self._parameter_list
if not self.comm_overlap:
all_params = sorted(all_params, key=lambda x: str(x.dtype))

comm_group = self._hcg.get_sharding_parallel_group()
var_groups = assign_group_by_size(self._parameter_list, group_size)
var_groups = assign_group_by_size(all_params, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx,
@@ -792,11 +803,24 @@ def reduce_gradients(self, parameter_list, hcg):
for param in comm_buffer.params:
comm_buffer._copy_grad_to_buffer(param)

if g_sharding_v2_check_zero_padding:
self._check_padding_zero()

for comm_buffer in self._comm_buffer_list:
if not self.comm_overlap:
comm_buffer._comm_grads()

comm_buffer.scale_grads()

def _check_padding_zero(self):
for comm_buffer in self._comm_buffer_list:
for k, v in comm_buffer._sharding_param_grad_view.items():
pad_tensor = v._get_padding()
if pad_tensor is not None:
assert paddle.all(
pad_tensor == 0
).item(), f"The padding of Tensor {k} is not zero"

def _forward_pre_hook_function(self, tasks):
def __impl__(x, y):
for task in tasks:
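The `_build_comm_buffers` change above only sorts `self._parameter_list` by `str(x.dtype)` before handing it to `assign_group_by_size`, so the fused communication buffers tend to be dtype-homogeneous and fewer mixed-dtype groups (and hence extra collectives) are needed. Below is a minimal, self-contained sketch of that grouping idea; it is not Paddle's `assign_group_by_size` implementation, and `FakeParam`, `group_by_dtype_then_size`, and the sizes used are illustrative only.

from collections import OrderedDict, namedtuple

# Stand-in for a parameter: a name, a dtype string, and a size in bytes.
FakeParam = namedtuple("FakeParam", ["name", "dtype", "nbytes"])


def group_by_dtype_then_size(params, group_size=256 * 1024 * 1024):
    """Greedy packing: sort by dtype, then fill groups up to group_size bytes.

    Keeping each group homogeneous in dtype lets it be fused into one
    contiguous buffer and reduced with a single collective call.
    """
    ordered = sorted(params, key=lambda p: str(p.dtype))
    groups, current, current_bytes, idx = OrderedDict(), [], 0, 0
    for p in ordered:
        if current and current_bytes + p.nbytes > group_size:
            groups[idx] = current
            idx, current, current_bytes = idx + 1, [], 0
        current.append(p)
        current_bytes += p.nbytes
    if current:
        groups[idx] = current
    return groups


# With a small illustrative group size, the two fp16 tensors pack together
# and the fp32 tensors land in their own groups.
params = [
    FakeParam("w1", "float32", 4 * 10**6),
    FakeParam("w2", "float16", 2 * 10**6),
    FakeParam("w3", "float32", 4 * 10**6),
    FakeParam("w4", "float16", 2 * 10**6),
]
for idx, group in group_by_dtype_then_size(params, group_size=4 * 10**6).items():
    print(idx, [(p.name, p.dtype) for p in group])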
@@ -32,7 +32,7 @@
fused_allreduce_gradients,
unwrap_optimizer,
)
from ...utils.log_util import logger
from ...utils.log_util import logger, sync_rotate_logger
from ...utils.mix_precision_utils import MixPrecisionOptimizer

__all__ = []
@@ -45,6 +45,7 @@ def __init__(self, clip, hcg):
self.not_sharding_stage1 = True

def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
sync_rotate_logger().info("Starting to calculate global norm.")
# sharding first
sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
dp_flag = self._hcg.get_data_parallel_world_size() > 1
@@ -94,6 +95,8 @@ def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
group=self._hcg.get_pipe_parallel_group(),
)

sync_rotate_logger().info("Finished calculating global norm.")

@no_grad()
def _dygraph_clip(self, params_grads):
sum_square_dist_fp16 = []
@@ -376,6 +379,8 @@ def _filter_fn(self, param, strategy):
return False

def _step(self, parameters_list):
sync_rotate_logger().info("Starting hybridoptimizer step")

mp_group = self._hcg.get_model_parallel_group()
src_rank = self._hcg.get_model_parallel_group_src_rank()
params = None
@@ -411,11 +416,15 @@ def syc_grad(p):
p.grad, src_rank, mp_group, mp_configs.sync_mode
)

sync_rotate_logger().info("Starting mp grad sync")

# Grad sync before opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
for p in params:
syc_grad(p)

sync_rotate_logger().info("Finished mp grad sync")

self._inner_opt.step()

def syc_param(p):
@@ -477,6 +486,7 @@ def syc_moment(p):
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
for p in params:
syc_moment(p)
sync_rotate_logger().info("Finishing hybridoptimizer step")

def _hybrid_sync_grad(self, parameter_list):
dp_parameter_list = parameter_list
@@ -32,7 +32,7 @@
broadcast_sep_parameters,
broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from ..utils.log_util import logger, sync_rotate_logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

@@ -900,6 +900,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
def _forward_step(
self, input_tensor, micro_dataset, chunk_id=None, step_id=None
):
sync_rotate_logger().info("Before forward_step")
if self._enable_timer:
self.timers("forward_step").start()
if self.is_pipeline_first_stage():
@@ -957,6 +958,7 @@ def _forward_step(
self.micro_batch_id += 1
if self._enable_timer:
self.timers("forward_step").stop()
sync_rotate_logger().info("After forward_step")
if self.is_pipeline_last_stage() and self._compute_loss:
return backward_loss_tensor
return output_tensor
@@ -966,6 +968,7 @@ def _backward_step(
):
if self._enable_timer:
self.timers("backward_step").start()
sync_rotate_logger().info("Before backward_step")
with paddle.amp.auto_cast(enable=False):
self.callbacks.on_location(
PipelineParallelMicroStepLocations.BACKWARD_BEGIN,
@@ -1012,6 +1015,8 @@ def _backward_step(
output_tensor_grad=output_tensor_grad,
step_id=step_id,
)

sync_rotate_logger().info("After backward_step")
return input_tensor_grad

def _check_micro_batch_data_valid(self, micro_batch_data):
@@ -1355,6 +1360,7 @@ def forward_backward_pipeline(
static_scheduler=False,
return_micro_batch_loss=False,
):
sync_rotate_logger().info("start forward_backward_pipeline")
# use interleave scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
@@ -1908,6 +1914,8 @@ def _process_bwd_buffer(step_id, tensor):
self._p2p_helper.clear_meta_cache()

self.timer_printer()
sync_rotate_logger().info("end forward_backward_pipeline")

return train_loss

def train_batch(
45 changes: 45 additions & 0 deletions python/paddle/distributed/fleet/utils/log_util.py
@@ -13,7 +13,11 @@
# limitations under the License.

import logging
import os
from distutils.util import strtobool
from logging.handlers import RotatingFileHandler

import paddle
from paddle.distributed.utils.log_utils import get_logger

logger = get_logger("INFO", __name__)
@@ -70,3 +74,44 @@ def layer_to_str(base, *args, **kwargs):
name += ", ".join(f"{key}={value}" for key, value in kwargs.items())
name += ")"
return name


class DistributedLogger(logging.Logger):
def __init__(self, name, level=logging.NOTSET):
super().__init__(name, level)

def info(self, msg, *args, **kwargs):
if strtobool(os.getenv('FLAGS_distributed_debug_logger', '0')):
paddle.device.synchronize()
super().info(f"Distributed Debug: {msg}", *args, **kwargs)


def get_rotate_file_logger(log_level, name='root'):
distributed_logger = DistributedLogger(name + '_rotate', level=log_level)
distributed_logger.propagate = False

device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
log_dir = os.path.join(os.getcwd(), "hybrid_parallel")
os.makedirs(log_dir, exist_ok=True)

path = os.path.join(log_dir, f"worker_{device_id}.log")
handler = RotatingFileHandler(
path, maxBytes=2 * 1024 * 1024 * 1024, backupCount=3 # 2GB
)

log_format = logging.Formatter(
'[%(asctime)-15s] [%(levelname)8s] %(filename)s:%(lineno)s - %(message)s'
)
handler.setFormatter(log_format)
distributed_logger.addHandler(handler)
return distributed_logger


g_sync_rotate_logger = None


def sync_rotate_logger():
global g_sync_rotate_logger
if g_sync_rotate_logger is None:
g_sync_rotate_logger = get_rotate_file_logger("INFO", __name__)
return g_sync_rotate_logger
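
The rotating logger above is created lazily: `g_sync_rotate_logger` stays `None` until the first `sync_rotate_logger()` call, so merely importing `log_util` does not create the `hybrid_parallel/worker_<id>.log` file or its handler, matching the "make sync_rotate_logger lazy" item in the commit message. A small usage sketch follows; it only relies on what the code above defines, and the flag values shown are examples (in a real job `FLAGS_selected_gpus` is normally set by the distributed launcher).

import os

# When this flag is truthy, DistributedLogger.info() calls
# paddle.device.synchronize() before logging, so each "Distributed Debug"
# line reflects device work that has actually completed on this rank.
os.environ["FLAGS_distributed_debug_logger"] = "1"
# Each rank writes to hybrid_parallel/worker_<FLAGS_selected_gpus>.log.
os.environ.setdefault("FLAGS_selected_gpus", "0")

from paddle.distributed.fleet.utils.log_util import sync_rotate_logger

# The first call builds the DistributedLogger and its RotatingFileHandler
# (2 GB per file, 3 backups); later calls return the same instance.
log = sync_rotate_logger()
log.info("Starting some collective-heavy region")
# ... collectives / optimizer step ...
log.info("Finished the region")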
20 changes: 20 additions & 0 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -216,6 +216,26 @@ def __init__(
        # share param buffer
        self._share_param_buffer()

    def _get_padding(self):
        if self._param_begin < self._param_end and self._slice_grad is not None:
            padding_start = self._index + self._param._numel()
            padding_end = self._index + self._padded_size
            padding_start = max(self._param_begin, padding_start)
            padding_end = min(self._param_end, padding_end)

            if padding_start >= padding_end:
                return None

            padding = padding_end - padding_start
            grad_numel = self._slice_grad._numel()
            assert grad_numel >= padding, f"{grad_numel} vs {padding}"
            padding_grad = self._slice_grad._slice(
                grad_numel - padding, grad_numel
            )
            return padding_grad
        else:
            return None

    def _slice_grad_from_buffer(self):
        assert self._grad_buffer is not None
        if self._param_begin < self._param_end:
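`_get_padding` returns the slice of this rank's local gradient shard that covers only padding: each parameter occupies `_padded_size` slots in the fused buffer but holds `_param._numel()` real elements, and the method clips the padding interval `[_index + _param._numel(), _index + _padded_size)` to the rank's `[_param_begin, _param_end)` range. `_check_padding_zero` in the sharding optimizer (enabled with `FLAGS_sharding_v2_check_zero_padding=1`) then asserts that this region is all zeros after gradients are copied into the buffer. Below is a toy, self-contained sketch of the same invariant; the buffer layout and sizes are made up for illustration and do not use the classes above.

import paddle

# Toy layout: one parameter with 10 real elements, padded to 16 slots
# in a fused gradient buffer.
numel, padded_size = 10, 16

# Only the parameter region of the buffer holds gradient data; the tail
# [numel, padded_size) is padding and must remain zero.
param_grad = paddle.rand([numel], dtype="float32")
grad_buffer = paddle.concat(
    [param_grad, paddle.zeros([padded_size - numel], dtype="float32")]
)

# This is, in spirit, what _get_padding() returns for the local shard.
padding = grad_buffer[numel:padded_size]

# And this is the invariant _check_padding_zero() asserts.
assert paddle.all(
    padding == 0
).item(), "The padding of the fused buffer is not zero"
print("zero-padding check passed")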
