Skip to content

Commit

Permalink
Support load state dict from inference model format save result (Pa…
Browse files Browse the repository at this point in the history
…ddlePaddle#26718)

* support load infer model format state dict

* add unittests

* remove keep name table

* resolve circular import

* fix compatible problem

* recover unittest

* polish doc and comment
  • Loading branch information
chenwhql authored Sep 3, 2020
1 parent bcdbac1 commit 209273e
Show file tree
Hide file tree
Showing 11 changed files with 350 additions and 97 deletions.
Empty file added paddle/http.log
Empty file.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS

from .framework import NoamDecay #DEFINE_ALIAS
Expand Down
166 changes: 96 additions & 70 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,54 @@

import os
import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
'save_dygraph',
'load_dygraph',
]


# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table and
# ensure compatibility when users still pass the keep_name_table argument
def deprecate_keep_name_table(func):
    """Adapt calls that still pass the deprecated ``keep_name_table`` flag.

    A legacy boolean given either as the second positional argument or as
    the ``keep_name_table`` keyword is converted into a ``SaveLoadConfig``
    whose ``keep_name_table`` attribute carries that value, and a
    ``DeprecationWarning`` is issued. The config object is then handed to
    ``func`` in place of the boolean (second positional slot) or as the
    ``configs`` keyword. Calls that do not use the legacy flag are
    forwarded unchanged.

    Args:
        func(callable): the function to wrap.

    Returns:
        callable: the wrapper, carrying the metadata of ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        def _warn_and_build_configs(keep_name_table):
            # stacklevel=3 skips this helper and `wrapper`, so the warning is
            # attributed to the code that called the decorated function.
            # Without it the warning points at this module and is hidden by
            # Python's default DeprecationWarning filter.
            warnings.warn(
                "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
                DeprecationWarning,
                stacklevel=3)
            configs = SaveLoadConfig()
            configs.keep_name_table = keep_name_table
            return configs

        if len(args) > 1 and isinstance(args[1], bool):
            # legacy positional form: func(model_path, True)
            args = list(args)
            args[1] = _warn_and_build_configs(args[1])
        elif 'keep_name_table' in kwargs:
            # legacy keyword form: func(model_path, keep_name_table=True)
            # NOTE: a `configs` keyword passed alongside is replaced here.
            kwargs['configs'] = _warn_and_build_configs(
                kwargs.pop('keep_name_table'))

        return func(*args, **kwargs)

    return wrapper


@dygraph_only
def save_dygraph(state_dict, model_path):
'''
Expand Down Expand Up @@ -100,41 +134,55 @@ def save_dygraph(state_dict, model_path):

# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
#@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
# @dygraph_only
@deprecate_keep_name_table
def load_dygraph(model_path, configs=None):
'''
:api_attr: imperative
Load parameter state_dict from disk.
Load parameter state dict from disk.
.. note::
Due to some historical reasons, if you load ``state_dict`` from the saved
result of `paddle.io.save_inference_model`, the structured variable name
will cannot be restored. You need to set the argument `use_structured_name=False`
when using `Layer.set_state_dict` later.
Args:
model_path(str) : The file prefix store the state_dict. (The path should Not contain suffix '.pdparams')
keep_name_table(bool, optional) : Whether keep structed name to parameter name conversion table in output dict.
Default : False
model_path(str) : The file prefix store the state_dict.
(The path should Not contain suffix '.pdparams')
configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.
Returns:
state_dict(dict) : the dict store the state_dict
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
paddle.disable_static()
state_dict = emb.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
emb = paddle.nn.Embedding([10, 10])
adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
parameter_list = emb.parameters() )
state_dict = adam.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy")
para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
scheduler = paddle.optimizer.lr_scheduler.NoamLR(
d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam(
learning_rate=scheduler,
parameters=emb.parameters())
state_dict = adam.state_dict()
paddle.save(state_dict, "paddle_dy")
'''
para_state_dict, opti_state_dict = paddle.load("paddle_dy")
'''
# deal with argument `model_path`
model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
Expand All @@ -145,66 +193,44 @@ def load_dygraph(model_path, keep_name_table=False):
opti_dict = None
params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt"

# deal with argument `configs`
if configs is None:
configs = SaveLoadConfig()

if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save` save format
# TODO(chenweihang): [Why not support `io.save_infernece_model` save format here]
# Load state dict by `jit.save/io.save_inference_model` save format
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# Although we reluctantly restore the `state_dict` in some scenarios,
# this may not be complete and there are some limitations, so this function
# will be considered later. The limitations include:
# 1. `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_dict`,
# but this argument is currently not public
# 2. if `save_inference_model` save all persistable variables in a single file,
# user need to give the variable name list to load `state_dict`
# `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_state_dict`
# NOTE(chenweihang): `jit.save` doesn't save optimizer state

# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if not os.path.exists(var_info_path):
raise RuntimeError(
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`."
)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
# 3. load `__variables__`
# TODO(chenweihang): now only supports loading from default save format:
# - all persistable vars saved in one file named `__variables__`
# for other case, we may need to modify the arguments of this API
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
raise RuntimeError(
"The parameter file to be loaded was not found. "
"Now only supports loading from the default save format, "
"and does not support custom params_filename and "
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model.",
var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state

# 2. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path,
configs.model_filename)

# 3. load layer parameters & buffers
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
with guard():
persistable_var_dict = _construct_params_and_buffers(
model_prefix,
programs,
configs.separate_params,
configs.params_filename,
append_suffix=False)

# 4. construct state_dict
para_dict = dict()
for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy()
else:
# Load state dict by `save_dygraph` save format
para_dict = {}
Expand All @@ -213,7 +239,7 @@ def load_dygraph(model_path, keep_name_table=False):
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

if not keep_name_table and "StructuredToParameterName@@" in para_dict:
if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]

if os.path.exists(opti_file_path):
Expand Down
18 changes: 16 additions & 2 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,15 @@ def _load_persistable_vars(model_path,
return load_var_dict


# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
    """Re-key ``var_dict`` through ``program_holder._suffix_varname_dict``.

    Every key of ``var_dict`` is looked up in the holder's
    ``_suffix_varname_dict`` and the value is stored under the name found
    there. The input dict is not modified.

    Args:
        var_dict(dict): loaded variables keyed by their (suffixed) names.
        program_holder: object exposing a ``_suffix_varname_dict`` mapping.

    Returns:
        dict: a new dict holding the same values under the mapped names.

    Raises:
        KeyError: if a key of ``var_dict`` is missing from the mapping.
    """
    return {
        program_holder._suffix_varname_dict[suffixed_name]: loaded_var
        for suffixed_name, loaded_var in var_dict.items()
    }


def _construct_program_holders(model_path, model_filename=None):
# make sure the path has been checked
program_holder_dict = dict()
Expand Down Expand Up @@ -517,7 +526,8 @@ def _construct_program_holders(model_path, model_filename=None):
def _construct_params_and_buffers(model_path,
programs,
separate_params=False,
params_filename=None):
params_filename=None,
append_suffix=True):
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path,
Expand All @@ -526,6 +536,10 @@ def _construct_params_and_buffers(model_path,
else:
var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename)

if not append_suffix:
var_dict = _remove_varname_suffix(var_dict, programs['forward'])

return var_dict


Expand Down Expand Up @@ -685,7 +699,7 @@ def _construct(model_path, configs=None):
# 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename)

# 2. load layer parameters & parameter attributes
# 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename)

Expand Down
50 changes: 50 additions & 0 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __init__(self):
self._model_filename = None
self._params_filename = None
self._separate_params = False
# used for `paddle.load`
self._keep_name_table = False

# NOTE: Users rarely use following configs, so these configs are not open to users,
# reducing user learning costs, but we retain the configuration capabilities
Expand Down Expand Up @@ -600,6 +602,54 @@ def separate_params(self, value):
% type(value))
self._separate_params = value

@property
def keep_name_table(self):
    """
    Whether to keep the ``structured_name -> parameter_name`` dict in the loaded state dict.

    This dict is debugging information written when `paddle.save` is called.
    It is generally only useful for debugging and does not affect actual
    training or inference, so by default it is not retained in the result
    of `paddle.load`. Default: False.

    .. note::
        Only used for ``paddle.load``.

    Examples:
        .. code-block:: python

            import paddle

            paddle.disable_static()

            linear = paddle.nn.Linear(5, 1)

            state_dict = linear.state_dict()
            paddle.save(state_dict, "paddle_dy")

            configs = paddle.SaveLoadConfig()
            configs.keep_name_table = True
            para_state_dict, _ = paddle.load("paddle_dy", configs)

            print(para_state_dict)
            # the name_table is 'StructuredToParameterName@@'
            # {'bias': array([0.], dtype=float32),
            #  'StructuredToParameterName@@':
            #     {'bias': u'linear_0.b_0', 'weight': u'linear_0.w_0'},
            #  'weight': array([[ 0.04230034],
            #     [-0.1222527 ],
            #     [ 0.7392676 ],
            #     [-0.8136974 ],
            #     [ 0.01211023]], dtype=float32)}
    """
    return self._keep_name_table


@keep_name_table.setter
def keep_name_table(self, value):
    # Accept strict bool values only; anything else is rejected up front.
    if isinstance(value, bool):
        self._keep_name_table = value
        return
    raise TypeError(
        "The SaveLoadConfig.keep_name_table should be bool value, but received input's type is %s."
        % type(value))


@switch_to_static_graph
def save(layer, model_path, input_spec=None, configs=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_new_directory(self):
'paddle.distributed.prepare_context', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig',
'paddle.jit.save', 'paddle.jit.load', 'paddle.SaveLoadConfig',
'paddle.NoamDecay', 'paddle.PiecewiseDecay',
'paddle.NaturalExpDecay', 'paddle.ExponentialDecay',
'paddle.InverseTimeDecay', 'paddle.PolynomialDecay',
Expand Down
16 changes: 16 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,22 @@ def testOnlyLoadParams(self):
para_state_dict, opti_state_dict = paddle.load(
os.path.join('saved_dy', 'emb_dy.pdopt'))

def test_load_compatible_with_keep_name_table(self):
    """The deprecated ``keep_name_table`` argument of ``paddle.load`` still works.

    Covers both legacy call styles (second positional bool and keyword)
    and checks that the name table is actually kept in the loaded dict.
    """
    model_prefix = os.path.join('saved_dy', 'emb_dy')
    with fluid.dygraph.guard():
        emb = fluid.dygraph.Embedding([10, 10])
        state_dict = emb.state_dict()
        paddle.save(state_dict, model_prefix)

        # legacy positional form
        para_state_dict, opti_state_dict = paddle.load(model_prefix, True)
        self.assertIsNotNone(para_state_dict)
        # the flag must have an effect, not merely be accepted
        self.assertIn("StructuredToParameterName@@", para_state_dict)
        self.assertIsNone(opti_state_dict)

        # legacy keyword form
        para_state_dict, opti_state_dict = paddle.load(
            model_prefix, keep_name_table=True)
        self.assertIsNotNone(para_state_dict)
        self.assertIn("StructuredToParameterName@@", para_state_dict)
        self.assertIsNone(opti_state_dict)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 209273e

Please sign in to comment.