
[ConvBert] support static export #1643

Merged
merged 5 commits into from
Mar 4, 2022
Conversation

JunnYu
Member

@JunnYu JunnYu commented Jan 26, 2022

PR types

Bug fixes

PR changes

Models

Description

Support static-graph export for ConvBert.

  • Use paddle.shape instead of tensor.shape to obtain a tensor's shape, so that dynamic dimensions are resolved at runtime rather than fixed at export time.
  • F.unfold cannot handle variable-length inputs, so it is replaced with a for loop plus slicing.
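The loop-plus-slice replacement for F.unfold can be sketched as follows. This is a minimal NumPy illustration of the idea (the actual PR operates on Paddle tensors inside the model; the function name and shapes here are hypothetical): each window offset is sliced out of a padded sequence in a loop, then the offsets are stacked into a window axis.

```python
import numpy as np

def unfold_1d(x, kernel_size, pad_value=0.0):
    """Extract sliding windows of length `kernel_size` along the
    sequence axis using a loop plus slicing, mimicking what an
    unfold over a 1-D convolution window would produce.
    x: array of shape [batch, seq_len, hidden]
    returns: array of shape [batch, seq_len, kernel_size, hidden]
    """
    batch, seq_len, hidden = x.shape
    pad = (kernel_size - 1) // 2
    # pad the sequence dimension so every position has a full window
    padded = np.pad(x, ((0, 0), (pad, pad), (0, 0)),
                    constant_values=pad_value)
    windows = []
    for i in range(kernel_size):
        # slice out the i-th offset of every window; seq_len can vary
        # between calls, which a static unfold cannot accommodate
        windows.append(padded[:, i:i + seq_len, :])
    # stack the offsets into a new window axis
    return np.stack(windows, axis=2)

x = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)
out = unfold_1d(x, kernel_size=3)
print(out.shape)  # (2, 5, 3, 3)
```

Because the slices are taken relative to the runtime seq_len, this formulation stays valid for variable-length inputs in the exported static graph.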

Export code

from paddlenlp.transformers import ConvBertModel, ConvBertTokenizer
import paddle
import os
import argparse


def get_args(add_help=True):
    """get_args
    Parse all args using argparse lib
    Args:
        add_help: Whether to add -h option on args
    Returns:
        An object which contains many parameters used for inference.
    """
    parser = argparse.ArgumentParser(
        description='PaddleNLP ConvBert static-graph export', add_help=add_help)
    parser.add_argument(
        "--model_path",
        default="convbert-base",
        type=str,
        help="Path of the trained model to be exported.", )
    parser.add_argument(
        '--save_inference_dir',
        default='./convbert_infer',
        type=str,
        help='Path where the exported inference model is saved.')

    args = parser.parse_args()
    return args


def export(args):
    # build model
    model = ConvBertModel.from_pretrained(args.model_path)
    tokenizer = ConvBertTokenizer.from_pretrained(args.model_path)
    model.eval()

    # decorate model with jit.save
    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64")  # token_type_ids
        ])
    # save inference model
    paddle.jit.save(model, os.path.join(args.save_inference_dir, "inference"))
    tokenizer.save_pretrained(args.save_inference_dir)
    print(
        f"inference model and tokenizer have been saved into {args.save_inference_dir}"
    )


if __name__ == "__main__":
    args = get_args()
    export(args)

@yingyibiao yingyibiao self-assigned this Jan 26, 2022
Contributor

@yingyibiao yingyibiao left a comment


LGTM

@yingyibiao yingyibiao merged commit 8139863 into PaddlePaddle:develop Mar 4, 2022
2 participants