2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -27,6 +27,7 @@ message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel
}

message ShardingConfig {
@@ -46,6 +47,7 @@ message ShardingConfig {
  // Optimizer sharding. Temporary plan that may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel
}

message HybridConfig {
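For context, a minimal sketch of how these new enable_tuning flags could be switched on from user code, assuming the usual fleet.DistributedStrategy pattern in which the config dicts map onto the proto fields above (the empty checkpoint list is a placeholder):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# Recompute: dict keys map onto RecomputeConfig; enable_tuning is the new field.
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": [], "enable_tuning": True}

# Sharding: dict keys map onto ShardingConfig; enable_tuning is the new field.
strategy.sharding = True
strategy.sharding_configs = {"stage": 2, "enable_tuning": True}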
141 changes: 122 additions & 19 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -58,7 +58,8 @@ def __init__(self,
inputs_spec=None,
labels_spec=None,
cluster=None,
strategy=None):
strategy=None,
user_tuning_config=None):
self.model = model
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
@@ -68,6 +69,7 @@ def __init__(self,
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._user_tuning_config = user_tuning_config

self._executor = None
self._cur_rank = paddle.distributed.get_rank()
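A rough usage sketch for the new user_tuning_config argument, based only on the keys consumed by _optimization_tuning below ("batch_size", "dataset", "run_after_tuning"); the model, input specs, and dataset are placeholders:

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.engine import Engine

class RandomDataset(paddle.io.Dataset):
    # placeholder dataset; the tuner profiles candidate configs with it
    def __getitem__(self, idx):
        return (np.random.rand(64).astype("float32"),
                np.random.randint(0, 10, [1]).astype("int64"))

    def __len__(self):
        return 128

strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": [], "enable_tuning": True}

engine = Engine(
    paddle.nn.Linear(64, 10),                      # placeholder model
    inputs_spec=[InputSpec([None, 64], "float32", "x")],
    labels_spec=[InputSpec([None, 1], "int64", "label")],
    strategy=strategy,
    user_tuning_config={
        "batch_size": 8,               # required by _optimization_tuning
        "dataset": RandomDataset(),    # required by _optimization_tuning
        "run_after_tuning": True,      # adopt the tuned strategy before training
    })
# Tuning itself is triggered from _prepare_single_mode("train") once the
# train program is prepared.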
@@ -127,19 +129,21 @@ def prepare(self,
self._prepare_single_mode("train")

def _prepare_single_mode(self, mode):
self._modes = [mode]
self._build(self._modes[0])
# Do auto parallel process
for mode in self._modes:
# Do the planning process
self._plan(mode)
for mode in self._modes:
# Do the parallel process
self._parallel(mode, self._all_ranks)

# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True

self._build(mode)
# Do the planning process
self._plan(mode)

# Do the Optimization tuning
if self._user_tuning_config and mode == "train":
self._optimization_tuning(mode)

# Do the parallel process
self._parallel(mode, self._all_ranks)

# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True

def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
@@ -174,6 +178,7 @@ def _build(self, mode):
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
# FIXME to support grad clip
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
Expand Down Expand Up @@ -204,12 +209,41 @@ def _build(self, mode):
"metrics": metrics
}

self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode

def _optimization_tuning(self, mode):

self.mode = mode
assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size."
assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
batch_size = self._user_tuning_config["batch_size"]
dataset = self._user_tuning_config["dataset"]
dataset.dp_world_size = self._dp_world_size
dataset.dp_rank = self._dp_rank

from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
self._dist_contexts[mode],
dataset,
self.inputs_spec,
self.labels_spec,
batch_size=batch_size,
rank=self._cur_rank)

self._optimization_tuner.tune()

if self._user_tuning_config["run_after_tuning"]:
# update the strategy
self._dist_contexts[
mode]._strategy = self._optimization_tuner.get_best_config()
else:
return

def _plan(self, mode):
if self._planned_mode is None:
self._planned_mode = mode
@@ -219,6 +253,18 @@ def _plan(self, mode):
self._planners[mode] = Planner(mode, self._dist_contexts[mode])
self._planners[mode].plan()

# infer data parallel info
inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
block = self._dist_contexts[mode].serial_main_program.global_block()
feed_list = []
for var in inputs_var + labels_var:
if var.name in block.vars:
feed_list.append(block.vars[var.name])

self._dp_world_size, self._dp_rank = self._get_data_parallel_info(
feed_list[0], self._dist_contexts[mode])

def _parallel(self, mode, all_ranks):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
@@ -317,6 +363,40 @@ def _initialize(self, mode):
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)

if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']:
# from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
def cast_parameters_to_fp16(place,
program,
scope=None,
to_fp16_var_names=None):
"""
                Traverse all parameters in the whole model and cast them to the FP16 data type,
                while keeping the parameters of batch norm layers in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
"""
from paddle.framework import core
import numpy as np
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())

var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
if param.dtype == core.VarDesc.VarType.FP16:
param_t = var_scope.find_var(
param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)

cast_parameters_to_fp16(self._place, prune_startup_prog)

def fit(self,
train_data,
batch_size=1,
@@ -342,7 +422,6 @@ def fit(self,
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)

for epoch in range(epochs):
train_logs = {"epoch": epoch}
for step, _ in enumerate(train_dataloader):
Expand Down Expand Up @@ -457,8 +536,6 @@ def _create_dataloader(self,
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name])
dp_world_size, dp_rank = self._get_data_parallel_info(
feed_list[0], dist_context)

# remove the first three ops if multi run fit/evaluate/predict
op_size = len(dist_main_block.ops)
@@ -477,8 +554,8 @@ def _create_dataloader(self,
batch_size,
epochs,
steps_per_epoch,
data_parallel_world_size=dp_world_size,
data_parallel_rank=dp_rank)
data_parallel_world_size=self._dp_world_size,
data_parallel_rank=self._dp_rank)

# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
@@ -561,6 +638,32 @@ def _get_data_parallel_info(self, var, dist_context):

return None, None

def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here

config = self.strategy.recompute_configs

        # extract ckpts by specific model
if isinstance(self.model, paddle.nn.Layer):
if hasattr(
self.model, "model"
) and self.model.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.model.gpt.checkpoints
else:
exact_ckpts = config["checkpoints"]

# modify strategy
if self.strategy.recompute:
config["checkpoints"] = exact_ckpts[:]
self.strategy.recompute_configs = config
            logs = {
                'Model Class': self.model.model.__class__.__name__
                if hasattr(self.model, "model") else self.model.__class__.__name__,
                'Applied Recompute ckpts': exact_ckpts
            }
self._logger.info(logs)

def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
1 change: 0 additions & 1 deletion python/paddle/distributed/auto_parallel/parallelizer.py
@@ -48,7 +48,6 @@
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .planner import Planner
from paddle.distributed.passes import new_pass, PassContext

_logger = get_logger(logging.INFO)

12 changes: 10 additions & 2 deletions python/paddle/distributed/auto_parallel/process_group.py
@@ -42,7 +42,13 @@ def get_world_process_group():
return _g_process_group_map[0]


def new_process_group(ranks):
def clear_all_process_groups():
global _g_process_group_map
_g_process_group_map = {}
_g_process_group_map[0] = ProcessGroup(0, [])


def new_process_group(ranks, group_id=None):
global _g_process_group_map
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
@@ -54,7 +60,9 @@ def new_process_group(ranks):
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
    if group_id is None:
group_id = _new_ring_id() + num_groups + 1

new_pg = ProcessGroup(group_id, ranks)
_g_process_group_map[group_id] = new_pg
return new_pg
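A small sketch of how the new helpers might be exercised, for example by a tuner that rebuilds communication groups between trials; it uses only the module-level functions shown above, and the explicit group_id value is arbitrary:

from paddle.distributed.auto_parallel.process_group import (
    clear_all_process_groups, new_process_group)

# Drop every previously registered group, keeping only the empty world group (id 0).
clear_all_process_groups()

# Default path: the ring id is still allocated from _new_ring_id() plus an offset.
pg_auto = new_process_group([0, 1])

# New path: reuse a known ring id instead of allocating a fresh one.
pg_fixed = new_process_group([0, 1, 2, 3], group_id=7)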
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/tuner/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .profiler import profiler

__all__ = []