[Recompute] Support ernie for dygraph recompute. (PaddlePaddle#2849)
ZHUI authored Aug 10, 2022
1 parent 2a4a2fb commit 7ce81e1
Showing 4 changed files with 111 additions and 29 deletions.
54 changes: 31 additions & 23 deletions model_zoo/ernie-1.0/run_pretrain.py
@@ -27,6 +27,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
 from paddle.io import DataLoader, Dataset
 from visualdl import LogWriter
 
@@ -327,6 +328,7 @@ def do_train(args):
         model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
         model_config[
             "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
+        model_config["enable_recompute"] = args.use_recompute
         model = model_class(base_class(**model_config))
     else:
         model = model_class.from_pretrained(
@@ -462,33 +464,39 @@ def do_train(args):
             input_ids, segment_ids, input_mask, masked_lm_positions, \
                 masked_lm_labels, next_sentence_labels = batch
 
-            with paddle.amp.auto_cast(args.use_amp,
-                                      custom_black_list=[
-                                          "reduce_sum",
-                                          "c_softmax_with_cross_entropy",
-                                          "elementwise_div"
-                                      ],
-                                      level='O2'):
-
-                # Create the model for the ernie pretrain
-                prediction_scores, seq_relationship_score = model(
-                    input_ids=input_ids,
-                    token_type_ids=segment_ids,
-                    position_ids=None,
-                    attention_mask=input_mask,
-                    masked_positions=masked_lm_positions)
-
-                lm_loss, sop_loss = criterion(prediction_scores,
-                                              seq_relationship_score,
-                                              masked_lm_labels,
-                                              next_sentence_labels)
-                loss = lm_loss + sop_loss
+            with model.no_sync():
+                with paddle.amp.auto_cast(args.use_amp,
+                                          custom_black_list=[
+                                              "reduce_sum",
+                                              "c_softmax_with_cross_entropy",
+                                              "elementwise_div"
+                                          ],
+                                          level='O2'):
+
+                    # Create the model for the ernie pretrain
+                    prediction_scores, seq_relationship_score = model(
+                        input_ids=input_ids,
+                        token_type_ids=segment_ids,
+                        position_ids=None,
+                        attention_mask=input_mask,
+                        masked_positions=masked_lm_positions)
+
+                    lm_loss, sop_loss = criterion(prediction_scores,
+                                                  seq_relationship_score,
+                                                  masked_lm_labels,
+                                                  next_sentence_labels)
+                    loss = lm_loss + sop_loss
 
+                if args.use_amp:
+                    scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+
+            fused_allreduce_gradients(list(model.parameters()), None)
+
             if args.use_amp:
-                scaler.scale(loss).backward()
                 scaler.minimize(optimizer, loss)
             else:
-                loss.backward()
                 optimizer.step()
 
             optimizer.clear_grad()
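For context: the change above pairs recompute with manual gradient synchronization. Because recompute re-runs the forward pass during backward, running backward under the data-parallel wrapper's automatic allreduce hooks can trigger redundant communication, so backward is executed inside model.no_sync() and a single fused allreduce is issued afterwards. A minimal sketch of that pattern, assuming model has been wrapped by fleet.distributed_model and that loader, loss_fn and optimizer exist (these names are illustrative, not the repository's):

    # Sketch only: `model` is assumed to be wrapped by fleet.distributed_model(...),
    # and `loader`, `loss_fn`, `optimizer` are assumed to exist.
    from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

    for batch in loader:
        with model.no_sync():            # suppress the per-parameter allreduce hooks
            loss = loss_fn(model(*batch))
            loss.backward()              # recompute re-runs the forward pass here
        # one manual, fused allreduce of all gradients after backward completes
        fused_allreduce_gradients(list(model.parameters()), None)
        optimizer.step()
        optimizer.clear_grad()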
1 change: 1 addition & 0 deletions model_zoo/ernie-1.0/run_pretrain_trainer.py
@@ -351,6 +351,7 @@ def main():
         model_config[
             "attention_probs_dropout_prob"] = model_args.attention_probs_dropout_prob
         model = model_class(base_class(**model_config))
+        # model_config["enable_recompute"] = args.use_recompute
     else:
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
7 changes: 5 additions & 2 deletions paddlenlp/transformers/ernie/modeling.py
@@ -562,7 +562,8 @@ def __init__(self,
                  pad_token_id=0,
                  task_type_vocab_size=3,
                  task_id=0,
-                 use_task_id=False):
+                 use_task_id=False,
+                 enable_recompute=False):
         super(ErnieModel, self).__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
@@ -585,7 +586,9 @@ def __init__(self,
             act_dropout=0,
             weight_attr=weight_attr,
             normalize_before=False)
-        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
+        self.encoder = nn.TransformerEncoder(encoder_layer,
+                                             num_hidden_layers,
+                                             enable_recompute=enable_recompute)
         self.pooler = ErniePooler(hidden_size, weight_attr)
         self.apply(self.init_weights)
 
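As a usage sketch (not taken from the repository), the new argument can be passed directly to ErnieModel; it is forwarded to nn.TransformerEncoder, so each encoder layer's activations are recomputed during backward instead of being kept in memory. The sizes below are illustrative, and a GPU is assumed since recompute preserves CUDA RNG state by default:

    import paddle
    from paddlenlp.transformers import ErnieModel

    # Illustrative, randomly initialized configuration; real checkpoints define their own sizes.
    model = ErnieModel(vocab_size=1000,
                       hidden_size=128,
                       num_hidden_layers=2,
                       num_attention_heads=2,
                       intermediate_size=512,
                       enable_recompute=True)

    input_ids = paddle.randint(low=1, high=999, shape=[2, 64])
    sequence_output, pooled_output = model(input_ids)
    sequence_output.mean().backward()  # encoder activations are rebuilt during this backward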
78 changes: 74 additions & 4 deletions paddlenlp/transformers/model_outputs.py
@@ -20,10 +20,25 @@
 from dataclasses import fields, dataclass
 from typing import Any, List, Tuple, Optional
 from paddle.nn.layer.transformer import _convert_attention_mask
+from paddle.distributed.fleet.utils import recompute
 
 from .utils import adapt_stale_fwd_patch
 
 
+def layer_init_wrapper(func):
+
+    @functools.wraps(func)
+    def _impl(self, *args, **kwargs):
+        enable_recompute = kwargs.pop("enable_recompute", False)
+        func(self, *args, **kwargs)
+        if paddle.in_dynamic_mode():
+            self.enable_recompute = enable_recompute
+        else:
+            self.enable_recompute = False
+
+    return _impl
+
+
 def _transformer_encoder_layer_fwd(self,
                                    src,
                                    src_mask=None,
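In effect, the wrapper above lets the stock Paddle containers accept an extra enable_recompute keyword without changing their signatures: it pops the flag, delegates to the original __init__, and records the flag on the layer (honoured only in dynamic mode). A small sketch of the resulting behaviour, assuming the patches are applied as a side effect of importing paddlenlp.transformers:

    import paddle.nn as nn
    import paddlenlp.transformers  # assumption: importing the package installs the patches

    encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=4, enable_recompute=True)
    print(encoder.enable_recompute)  # True in dygraph, forced to False under static graph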
@@ -60,6 +75,46 @@ def _transformer_encoder_layer_fwd(self,
         (src, ) + outputs[::-1])  # hidden_states, cache, attentions
 
 
+def _transformer_decoder_fwd(self,
+                             tgt,
+                             memory,
+                             tgt_mask=None,
+                             memory_mask=None,
+                             cache=None):
+    tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
+    memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
+
+    output = tgt
+    new_caches = []
+    for i, mod in enumerate(self.layers):
+        if cache is None:
+            if self.enable_recompute:
+                output = recompute(mod,
+                                   output,
+                                   memory,
+                                   tgt_mask,
+                                   memory_mask,
+                                   cache=None)
+            else:
+                output = mod(output,
+                             memory,
+                             tgt_mask=tgt_mask,
+                             memory_mask=memory_mask,
+                             cache=None)
+        else:
+            output, new_cache = mod(output,
+                                    memory,
+                                    tgt_mask=tgt_mask,
+                                    memory_mask=memory_mask,
+                                    cache=cache[i])
+            new_caches.append(new_cache)
+
+    if self.norm is not None:
+        output = self.norm(output)
+
+    return output if cache is None else (output, new_caches)
+
+
 def _transformer_encoder_fwd(self,
                              src,
                              src_mask=None,
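The per-layer call above uses paddle.distributed.fleet.utils.recompute, which drops the wrapped callable's intermediate activations after the forward pass and re-executes it during backward; it is only taken when cache is None, since cached incremental decoding has to keep its states. A standalone sketch of the primitive (a GPU is assumed, as RNG-state preservation is CUDA-backed by default):

    import paddle
    from paddle.distributed.fleet.utils import recompute

    linear = paddle.nn.Linear(1024, 1024)
    x = paddle.randn([8, 1024])
    x.stop_gradient = False

    y = recompute(linear, x)  # activations of `linear` are not kept alive after this call
    y.sum().backward()        # they are recomputed here to form the gradients
    print(linear.weight.grad.shape)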
@@ -75,10 +130,16 @@ def _transformer_encoder_fwd(self,
     # NOTE: Also includes embeding output which is same as HF.
     all_hidden_states = [output] if output_hidden_states else None
     for i, mod in enumerate(self.layers):
-        layer_outputs = mod(output,
-                            src_mask=src_mask,
-                            cache=None if cache is None else cache[i],
-                            output_attentions=output_attentions)
+        if self.enable_recompute:
+            layer_outputs = recompute(mod, output, src_mask,
+                                      None if cache is None else cache[i],
+                                      output_attentions)
+        else:
+            layer_outputs = mod(output,
+                                src_mask=src_mask,
+                                cache=None if cache is None else cache[i],
+                                output_attentions=output_attentions)
+
         if isinstance(layer_outputs, tuple):
             output = layer_outputs[0]
             outputs = layer_outputs[1:]
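Since the recompute branch only changes how activations are stored, it should be numerically equivalent to the regular path. A hedged sanity-check sketch, not part of the patch; it assumes the patches are applied on import and a GPU is available, and zeroes dropout so both runs are deterministic:

    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddlenlp.transformers  # assumption: importing the package installs the patches

    def first_weight_grad(enable_recompute):
        paddle.seed(2022)
        layer = nn.TransformerEncoderLayer(d_model=128, nhead=4, dim_feedforward=512, dropout=0.0)
        encoder = nn.TransformerEncoder(layer, num_layers=2, enable_recompute=enable_recompute)
        paddle.seed(0)
        src = paddle.randn([2, 16, 128])
        src.stop_gradient = False
        out = encoder(src)
        out = out[0] if isinstance(out, (tuple, list)) else out  # patched forward may return a tuple
        out.sum().backward()
        return encoder.layers[0].linear1.weight.grad.numpy()

    print(np.allclose(first_weight_grad(False), first_weight_grad(True)))  # expected: True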
@@ -122,6 +183,12 @@ def _transformer_encoder_fwd(self,
 # patches of paddle.nn.Transformer to get all hidden_states and attentions
 paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd
 paddle.nn.TransformerEncoder.forward = _transformer_encoder_fwd
+paddle.nn.TransformerDecoder.forward = _transformer_decoder_fwd
+
+_encoder_init = paddle.nn.TransformerEncoder.__init__
+_decoder_init = paddle.nn.TransformerDecoder.__init__
+paddle.nn.TransformerEncoder.__init__ = layer_init_wrapper(_encoder_init)
+paddle.nn.TransformerDecoder.__init__ = layer_init_wrapper(_decoder_init)
 
 
 def _get_wrap_setattr(cls):
@@ -139,6 +206,9 @@ def _wrap_setattr(self, name, value):
 paddle.nn.TransformerEncoder.__setattr__ = functools.wraps(
     paddle.nn.TransformerEncoder.__setattr__)(_get_wrap_setattr(
         paddle.nn.TransformerEncoder))
+paddle.nn.TransformerDecoder.__setattr__ = functools.wraps(
+    paddle.nn.TransformerDecoder.__setattr__)(_get_wrap_setattr(
+        paddle.nn.TransformerDecoder))
 
 
 def is_tensor(x):
