add if-else use_new_executor branch (#6897)
* add if-else use_new_executor branch

* no pre-commit reformat

* add () for use_new_executor

* align old executor and new executor loss

* fix: when outs[loss].shape == (), the value can't be fetched via outs[loss][-1]

* tiny format fix

* restore the way loss is calculated by new exe

* old executor doesn't need to divide by accumulate_steps

* dataloader_from_generator add param num_workers=1
Wennie396 authored Sep 19, 2023
1 parent 27e25e6 commit 321faf3
Showing 1 changed file with 131 additions and 54 deletions.
model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py
@@ -31,6 +31,17 @@
 from ppfleetx.utils.log import convert_timestamp_to_data, get_timestamp, logger
 from ppfleetx.utils.version import version_check
 
+def use_new_executor():
+    new_executor_micro_batching = os.environ.get(
+        'FLAGS_new_executor_micro_batching', None
+    )
+    return new_executor_micro_batching in [
+        1,
+        '1',
+        True,
+        'True',
+        'true',
+    ]
 
 class AutoEngine(BasicEngine):
     def __init__(self, configs, module=None, mode="train"):
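
Editor's note: os.environ.get always returns a string (or None), so in practice only the string entries '1', 'True', and 'true' can match; the int/bool entries are defensive. Assuming the helper above, toggling between executors could look like this (hypothetical usage, not part of the commit):

import os

# Select the new executor's micro-batching path before building the engine.
os.environ['FLAGS_new_executor_micro_batching'] = 'true'
assert use_new_executor()

# Unset (or set to any other value) to fall back to the old executor path.
del os.environ['FLAGS_new_executor_micro_batching']
assert not use_new_executor()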
@@ -152,7 +163,10 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loade
 
         total_train_batch = self._max_steps if self._run_mode == "step" else len(train_data_loader)
         total_train_step = self._max_steps if self._run_mode == "step" else total_train_batch * self._num_train_epochs
-        total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
+        if use_new_executor():
+            total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
+        else:
+            total_eval_batch = valid_data_loader._steps if valid_data_loader is not None else 0
         valid_data_loader = valid_data_loader if valid_data_loader is not None else None
         eval_finished_step = 0
 
@@ -163,26 +177,41 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loade
             if step < self._load_recovery["step"]:
                 continue
 
-            batches = self._validate_batch(batch)
 
             fetch_list = None
             if self._strategy.amp.enable:
                 # fetch_list = ["find_infinite_scale.tmp_0", "loss_scaling_0"]
                 fetch_list = []
 
             final_loss = None
-            for micro_batch in batches:
-                with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
-                    outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
-                # pp: some devices don't have loss in outs
-                if "loss" in outs:
-                    if final_loss is None:
-                        final_loss = np.sum(outs["loss"])
-                    else:
-                        final_loss += np.sum(outs["loss"])
-
-            if final_loss is not None and self._accumulate_steps > 1:
-                final_loss /= self._accumulate_steps
+            if use_new_executor():
+                batches = self._validate_batch(batch)
+                for micro_batch in batches:
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])
+
+                if final_loss is not None and self._accumulate_steps > 1:
+                    final_loss /= self._accumulate_steps
+            else:
+                if self._pp_degree == 1 and self._accumulate_steps > 1:  # gradient merge
+                    local_steps = self._accumulate_steps
+                else:
+                    local_steps = 1
+                for _ in range(local_steps):
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])
 
             if final_loss is not None:
                 train_losses.append(final_loss)
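
Editor's note: to see why the commit message says the old executor doesn't need to divide by accumulate_steps, here is a minimal sketch with made-up numbers. It assumes the old executor's run() already returns a gradient-merge-averaged loss, while the new executor returns one loss per micro-batch:

import numpy as np

accumulate_steps = 4
micro_losses = np.array([2.0, 4.0, 6.0, 8.0])  # hypothetical per-micro-batch losses

# New-executor branch: sum the per-micro-batch losses, then divide.
new_exe_loss = micro_losses.sum() / accumulate_steps

# Old-executor branch: the engine is assumed to average internally,
# so the loop's accumulated output is used as-is.
old_exe_loss = micro_losses.mean()

assert new_exe_loss == old_exe_loss  # both report 5.0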
@@ -267,27 +296,49 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):
 
         train_data_loader, valid_data_loader = None, None
         if train_dataset:
-            train_data_loader = self._auto_engine.dataloader(
-                dataset=train_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=train_dataset.collate_fn,
-                num_workers=1,
-                sample_split=train_dataset.sample_split,
-                mode="train",
-            )
+            if use_new_executor():
+                train_data_loader = self._auto_engine.dataloader(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
+            else:
+                train_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
         if valid_dataset and self._eval_freq <= self._max_steps:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
 
         for epoch_index in range(start_epoch, epoch):
             train_epoch_start = get_timestamp()
@@ -320,6 +371,8 @@
                 convert_timestamp_to_data(get_timestamp() - train_start)
             )
         )
+        if valid_data_loader and not use_new_executor():
+            valid_data_loader._inner_dataloader.reset()
 
         if self.profiler:
             self._profiler_done()
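
Editor's note: the same dataloader branch recurs in fit(), evaluate(), and predict() below. A condensed sketch of the pattern, with names taken from the diff but the helper itself purely illustrative:

def build_loader(engine, dataset, mode):
    # Shared arguments passed to both loader kinds, as they appear in the diff.
    kwargs = dict(
        dataset=dataset,
        batch_size=engine._global_batch_size,
        steps_per_epoch=engine._max_steps,
        epochs=engine._num_train_epochs,
        collate_fn=dataset.collate_fn,
        sample_split=dataset.sample_split,
        mode=mode,
    )
    if use_new_executor():
        # New executor: standard dataloader with one worker.
        return engine._auto_engine.dataloader(num_workers=1, **kwargs)
    # Old executor: generator-based loader; its inner dataloader is reset
    # after training (see the hunk above) so it can be iterated again.
    return engine._auto_engine.dataloader_from_generator(**kwargs)

Note that in fit() the dataloader_from_generator calls omit num_workers, while the evaluate() and predict() calls below pass num_workers=1, matching the commit-message bullet about that parameter.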
@@ -328,16 +381,28 @@ def evaluate(self, epoch=1, valid_dataset=None):
 
         valid_data_loader = None
         if valid_dataset:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
 
         for epoch_index in range(epoch):
             eval_epoch_start = get_timestamp()
@@ -388,16 +453,28 @@ def predict(self, epoch=1, test_dataset=None):
 
         test_data_loader = None
         if test_dataset:
-            test_data_loader = self._auto_engine.dataloader(
-                dataset=test_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=test_dataset.collate_fn,
-                num_workers=1,
-                sample_split=test_dataset.sample_split,
-                mode="predict",
-            )
+            if use_new_executor():
+                test_data_loader = self._auto_engine.dataloader(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )
+            else:
+                test_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )
 
         test_start = get_timestamp()
         test_losses = []