forked from PaddlePaddle/PaddleRec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathto_static.py
100 lines (86 loc) · 3.6 KB
/
to_static.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import os
import paddle.nn as nn
import time
import logging
import sys
import importlib
# Absolute path of the directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
#sys.path.append(__dir__)
# Put the parent directory on sys.path; presumably this is what lets the
# `utils.*` imports below resolve when the script is run directly.
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
from utils.save_load import load_model, save_model, save_jit_model
from paddle.io import DistributedBatchSampler, DataLoader
import argparse
# Emit INFO-and-above records formatted as "<time> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the command-line arguments of the to-static export script.

    Returns:
        argparse.Namespace: with the attributes

        * ``config_yaml`` -- the ``-m/--config_yaml`` value after it has been
          passed through ``get_abs_model``.
        * ``opt`` -- list of ``-o/--opt`` strings, or ``None`` when the flag
          is absent; ``main`` interprets them as ``key=value`` overrides.
        * ``abs_dir`` -- absolute directory of the path given to ``-m``.
    """
    parser = argparse.ArgumentParser(description='paddle-rec run')
    # The config path is dereferenced unconditionally below, so let argparse
    # reject a missing -m with a usage message instead of letting
    # os.path.abspath(None) fail with an opaque TypeError.
    parser.add_argument("-m", "--config_yaml", type=str, required=True)
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    args = parser.parse_args()
    # Record the directory first: config_yaml is rewritten on the next line.
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args
def main(args):
    """Export a dygraph model as a static-graph (jit) model.

    Loads the YAML config, applies the ``-o key=value`` overrides, builds the
    dygraph model, loads weights from ``runner.model_init_path``, converts the
    model with ``paddle.jit.to_static`` and saves it under
    ``runner.model_save_path`` with the ``tostatic`` prefix.

    Args:
        args: namespace produced by ``parse_args``; ``config_yaml``,
            ``abs_dir`` and ``opt`` are read.

    Raises:
        ValueError: if an ``-o`` entry is not of the form ``key=value``, or
            its value cannot be converted to the numeric type of the config
            entry it overrides.
    """
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir
    # modify config from command
    if args.opt:
        for parameter in args.opt:
            # Split on the first "=" only so that values which themselves
            # contain "=" are kept intact.
            key, sep, value = parameter.strip().partition("=")
            if not sep:
                raise ValueError(
                    "invalid -o override {!r}: expected key=value".format(
                        parameter))
            # Coerce the string to the type of the entry being overridden.
            # bool is tested first because bool is a subclass of int.
            current = config.get(key)
            if isinstance(current, bool):
                value = value.lower() == "true"
            elif isinstance(current, int):
                value = int(value)
            elif isinstance(current, float):
                value = float(value)
            config[key] = value

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    end_epoch = config.get("runner.infer_end_epoch", 0)
    CE = config.get("runner.CE", False)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, train_data_dir, epochs, print_interval,
               model_save_path))
    logger.info("**************common.configs**********")

    # Select the device; the returned place object is not needed here.
    paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)
    if not CE:
        # Save into a sub-directory named after (infer_end_epoch - 1).
        # NOTE(review): with the default infer_end_epoch of 0 this becomes
        # "-1" -- confirm that is intended.
        model_save_path = os.path.join(model_save_path, str(end_epoch - 1))

    # NOTE(review): model_init_path is None when the config omits it; verify
    # that load_model accepts that.
    load_model(model_init_path, dy_model)
    # example dnn model forward
    # Input signature: a list of 26 int64 tensors of shape [None, 1] followed
    # by one float32 tensor of shape [None, 13].
    dy_model = paddle.jit.to_static(
        dy_model,
        input_spec=[[
            paddle.static.InputSpec(
                shape=[None, 1], dtype='int64') for _ in range(26)
        ], paddle.static.InputSpec(
            shape=[None, 13], dtype='float32')])
    save_jit_model(dy_model, model_save_path, prefix='tostatic')
if __name__ == '__main__':
    # Script entry point: parse the CLI arguments and run the export.
    main(parse_args())