Upgrade faster ernie by using fused transformers (#1308)
* add faster-ernie-1.0 model params

* add amp for faster ernie seq cls task

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
joey12300 and ZeyuChen authored Nov 13, 2021
1 parent 30eb43c commit 2d7e781
Showing 2 changed files with 22 additions and 9 deletions.
20 changes: 16 additions & 4 deletions examples/experimental/faster_ernie/seq_cls/train.py
@@ -17,6 +17,7 @@
 import os
 import random
 import time
+import distutils.util
 
 import numpy as np
 import paddle
@@ -39,6 +40,8 @@
 parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
 parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
 parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--use_amp", type=distutils.util.strtobool, default=False, help="Enable mixed precision training.")
+parser.add_argument("--scale_loss", type=float, default=2**15, help="The value of scale_loss for fp16.")
 args = parser.parse_args()
 # yapf: enable

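Note on the new flags: --use_amp is parsed with distutils.util.strtobool rather than argparse's type=bool, presumably because bool("False") is truthy and would silently enable AMP. A tiny illustration of strtobool's behavior (not part of this commit):

import distutils.util

# strtobool maps common truthy/falsy strings to 1/0 and raises ValueError otherwise.
print(distutils.util.strtobool("True"))   # 1
print(distutils.util.strtobool("False"))  # 0
print(distutils.util.strtobool("1"))      # 1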
@@ -121,15 +124,20 @@ def do_train():
 
     criterion = paddle.nn.loss.CrossEntropyLoss()
     metric = paddle.metric.Accuracy()
+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
 
     global_step = 0
     tic_train = time.time()
     for epoch in range(1, args.epochs + 1):
         for step, batch in enumerate(train_data_loader, start=1):
             texts, labels = batch["text"], batch["label"]
             texts = to_tensor(texts)
-            logits, predictions = model(texts)
-            loss = criterion(logits, labels)
+            with paddle.amp.auto_cast(
+                    args.use_amp,
+                    custom_white_list=["fused_feedforward", "fused_attention"]):
+                logits, predictions = model(texts)
+                loss = criterion(logits, labels)
             probs = F.softmax(logits, axis=1)
             correct = metric.compute(logits, labels)
             metric.update(correct)
@@ -142,8 +150,12 @@ def do_train():
                     % (global_step, epoch, step, loss, acc,
                        10 / (time.time() - tic_train)))
                 tic_train = time.time()
-            loss.backward()
-            optimizer.step()
+            if args.use_amp:
+                scaler.scale(loss).backward()
+                scaler.minimize(optimizer, loss)
+            else:
+                loss.backward()
+                optimizer.step()
             lr_scheduler.step()
             optimizer.clear_grad()
             if global_step % 100 == 0:
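For context, the train.py edits above follow PaddlePaddle's standard mixed-precision recipe: the forward pass and loss are wrapped in paddle.amp.auto_cast (with the fused ops whitelisted so they run in fp16), and the backward/update path goes through a GradScaler. Below is a minimal, self-contained sketch of that recipe using a placeholder linear model and random data, not the FasterErnie model from this commit:

import paddle

# Placeholder model/optimizer/data; only the AMP pattern mirrors the diff above.
# Note: fp16 and the fused kernels generally require a CUDA build of PaddlePaddle.
model = paddle.nn.Linear(16, 2)
optimizer = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=2**15)

for _ in range(3):
    x = paddle.randn([8, 16])
    labels = paddle.randint(0, 2, [8])
    # Forward pass in mixed precision; fused ops are whitelisted as in train.py.
    with paddle.amp.auto_cast(
            True, custom_white_list=["fused_feedforward", "fused_attention"]):
        logits = model(x)
        loss = paddle.nn.functional.cross_entropy(logits, labels)
    # Scale the loss before backward, then let the scaler drive the optimizer step.
    scaler.scale(loss).backward()
    scaler.minimize(optimizer, loss)
    optimizer.clear_grad()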
11 changes: 6 additions & 5 deletions paddlenlp/experimental/ernie_model.py
@@ -17,6 +17,7 @@
 import paddle
 import paddle.fluid.core as core
 import paddle.nn as nn
+from paddle.incubate.nn import FusedTransformerEncoderLayer
 
 from paddlenlp.experimental import FasterTokenizer, FasterPretrainedModel
 from paddlenlp.transformers.model_utils import register_base_model
@@ -105,7 +106,7 @@ class FasterErniePretrainedModel(FasterPretrainedModel):
     pretrained_resource_files_map = {
         "model_state": {
             "ernie-1.0":
-            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams",
+            "https://paddlenlp.bj.bcebos.com/models/transformers/faster_ernie/faster_ernie_v1_chn_base.pdparams",
             "ernie-tiny":
             "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_tiny/ernie_tiny.pdparams",
             "ernie-2.0-en":
@@ -245,14 +246,14 @@ def __init__(
         self.embeddings = ErnieEmbeddings(
             vocab_size, hidden_size, hidden_dropout_prob,
             max_position_embeddings, type_vocab_size, pad_token_id, weight_attr)
-        encoder_layer = nn.TransformerEncoderLayer(
+        encoder_layer = FusedTransformerEncoderLayer(
             hidden_size,
             num_attention_heads,
             intermediate_size,
-            dropout=hidden_dropout_prob,
+            dropout_rate=hidden_dropout_prob,
             activation=hidden_act,
-            attn_dropout=attention_probs_dropout_prob,
-            act_dropout=0,
+            attn_dropout_rate=attention_probs_dropout_prob,
+            act_dropout_rate=0,
             weight_attr=weight_attr, )
         self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
         self.pooler = ErniePooler(hidden_size, weight_attr)
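The model change swaps nn.TransformerEncoderLayer for paddle.incubate.nn.FusedTransformerEncoderLayer, which takes the same positional arguments but renames the dropout keywords (dropout_rate, attn_dropout_rate, act_dropout_rate). A minimal sketch of building and stacking the fused layer, with illustrative sizes rather than the real ERNIE config values (the fused kernels generally require a CUDA build of PaddlePaddle):

import paddle
import paddle.nn as nn
from paddle.incubate.nn import FusedTransformerEncoderLayer

# Illustrative sizes; FasterErnieModel passes its config values (e.g. hidden_size=768, 12 heads).
hidden_size, num_heads, ffn_size, num_layers = 128, 4, 512, 2

encoder_layer = FusedTransformerEncoderLayer(
    hidden_size,
    num_heads,
    ffn_size,
    dropout_rate=0.1,
    activation="gelu",
    attn_dropout_rate=0.1,
    act_dropout_rate=0)

# The fused layer drops into nn.TransformerEncoder just like the non-fused one.
encoder = nn.TransformerEncoder(encoder_layer, num_layers)
out = encoder(paddle.randn([2, 8, hidden_size]))  # [batch_size, seq_len, hidden_size]
print(out.shape)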
