Support tensor and optimizer serialization #6087

Merged
10 commits merged on Aug 29, 2021
7 changes: 7 additions & 0 deletions oneflow/api/python/framework/dtype.cpp
@@ -31,6 +31,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def("__repr__", [](const Symbol<DType>& d) { return d->name(); })
.def(py::self == py::self)
.def(py::hash(py::self))
.def(py::pickle(
[](const Symbol<DType>& dtype) { // __getstate__
return static_cast<int>(dtype->data_type());
},
[](int t) { // __setstate__
return CHECK_JUST(DType::Get(DataType(t)));
}))
.def_property_readonly(
"bytes", [](const Symbol<DType>& dtype) { return dtype->bytes().GetOrThrow(); });

11 changes: 11 additions & 0 deletions python/oneflow/framework/tensor.py
@@ -150,6 +150,15 @@ def _ne(self, other):
return self.ne(other)


def _getstate(self):
    assert self.is_local, "Only local tensors can be pickled"
return {"data": self.numpy(), "dtype": self.dtype}


def _setstate(self, pickle_dict):
return self.__init__(pickle_dict["data"], dtype=pickle_dict["dtype"])


def is_nonzero(input):
r"""
is_nonzero(input) -> (bool)
@@ -390,6 +399,8 @@ def RegisterMethods():
Tensor.backward = _backward
Tensor.__getitem__ = _getitem
Tensor.__setitem__ = _setitem
Tensor.__setstate__ = _setstate
Tensor.__getstate__ = _getstate
Tensor.__str__ = _str
Tensor.__repr__ = _repr
Tensor.__eq__ = _eq
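
Together, `__getstate__` and `__setstate__` make local tensors picklable by round-tripping through a numpy array plus the (now picklable) dtype; consistent tensors are rejected by the assertion. A minimal usage sketch, assuming a build with both the dtype and tensor changes:

```python
import pickle

import numpy as np
import oneflow as flow

# Round-trip a local tensor: __getstate__ captures {"data": ndarray, "dtype": dtype},
# and __setstate__ rebuilds the tensor from that dict.
t = flow.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=flow.float32)
restored = pickle.loads(pickle.dumps(t))
assert restored.dtype == t.dtype
assert np.array_equal(restored.numpy(), t.numpy())
```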
116 changes: 114 additions & 2 deletions python/oneflow/nn/optimizer/optimizer.py
@@ -14,6 +14,7 @@
limitations under the License.
"""
import collections
from itertools import chain
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Union
@@ -42,6 +43,8 @@ def __init__(
self._options["clip_grad_norm_type"] = parameters["clip_grad_norm_type"]

def __getitem__(self, key):
if key == "params":
return self._parameters
return self._options[key]

def __setitem__(self, key, value):
@@ -50,6 +53,9 @@ def __setitem__(self, key, value):
def __contains__(self, key):
return self._options.__contains__(key)

def items(self):
return self.__dict__.items()

@property
def options(self):
return self._options
@@ -72,10 +78,116 @@ def add_param_group(self, param_group) -> None:
raise NotImplementedError()

def load_state_dict(self, state_dict) -> None:
raise NotImplementedError()
r"""
Load the optimizer state previously produced by the `state_dict` method.

It is largely copied from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.load_state_dict
"""

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict["param_groups"]

if len(groups) != len(saved_groups):
raise ValueError(
"loaded state dict has a different number of " "parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)

# Update the state
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
)
}

def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, Tensor):
if value.is_local:
value = value.to(param.device)
else:
value = value.to_consistent(
placement=value.placement, sbp=value.sbp
)
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, collections.abc.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = dict()
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
self._state = state

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
group._options = deepcopy(new_group["_options"])
group._enable_clip_grad = new_group["_enable_clip_grad"]
return group

param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.param_groups = param_groups

def state_dict(self):
raise NotImplementedError()
r"""
Returns the state of the optimizer as a :class:`dict`.

It contains two entries:

* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups, where each group is a dict.

It is largely copied from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.state_dict
"""

# Save order indices instead of Tensors
param_mappings = {}
start_index = 0

def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "_parameters"}
param_mappings.update(
{
id(p): i
for i, p in enumerate(group["params"], start_index)
if id(p) not in param_mappings
}
)
packed["params"] = [param_mappings[id(p)] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use order indices as keys
packed_state = {
(param_mappings[id(k)] if isinstance(k, Tensor) else k): v
for k, v in self._state.items()
}
return {
"state": packed_state,
"param_groups": param_groups,
}

def step(self, closure: Union[Callable, None] = None) -> Union[Tensor, None]:
raise NotImplementedError()
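
`state_dict()` and `load_state_dict()` give optimizers a PyTorch-style save/restore path: parameters are replaced by order indices on save, and the indices are mapped back onto the live parameters on load, with tensor state cast to each parameter's device or placement. A minimal usage sketch mirroring the tests below; names, shapes, and hyperparameters are illustrative:

```python
import pickle

import oneflow as flow

# Build a tiny optimizer and take one step so it has per-parameter state.
w = flow.nn.Parameter(flow.ones(2, 2))
opt = flow.optim.Adam([w], lr=0.1)
(w * w).sum().backward()
opt.step()
opt.zero_grad()

# state_dict() keys state by order indices instead of Tensors, so it pickles cleanly.
with open("adam_state.pkl", "wb") as f:
    pickle.dump(opt.state_dict(), f)

# Later: rebuild an optimizer over the same parameters and restore its state.
new_opt = flow.optim.Adam([w], lr=0.1)
with open("adam_state.pkl", "rb") as f:
    new_opt.load_state_dict(pickle.load(f))
```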
33 changes: 33 additions & 0 deletions python/oneflow/test/modules/test_optim_adam.py
@@ -14,6 +14,7 @@
limitations under the License.
"""

import tempfile
import unittest
from collections import OrderedDict

@@ -35,6 +36,8 @@ def compare_with_numpy_adam(
weight_decay,
eps,
do_bias_correction,
reload_state_step,
save_load_by_pickle,
):
random_grad_seq = []
for _ in range(train_iters):
@@ -67,6 +70,18 @@ def train_one_iter(grad):

for i in range(train_iters):
train_one_iter(random_grad_seq[i])
if i == reload_state_step:
state_dict = adam.state_dict()
adam = flow.optim.Adam([x])
if save_load_by_pickle:
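# delete=False keeps the file on disk after the handle closes so it can be
# reopened by name below (reopening a still-open NamedTemporaryFile is not portable).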
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
import pickle

pickle.dump(state_dict, f)
with open(file_name, "rb") as f:
state_dict = pickle.load(f)
adam.load_state_dict(state_dict)
return x

def train_by_numpy():
@@ -116,6 +131,8 @@ def compare_with_numpy_adam_clip_grad(
do_bias_correction,
clip_grad_max_norm,
clip_grad_norm_type,
reload_state_step,
save_load_by_pickle,
):
random_grad_seq = []
for _ in range(train_iters):
@@ -151,6 +168,18 @@ def train_one_iter(grad):

for i in range(train_iters):
train_one_iter(random_grad_seq[i])
if i == reload_state_step:
state_dict = adam.state_dict()
adam = flow.optim.Adam([x])
if save_load_by_pickle:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
import pickle

pickle.dump(state_dict, f)
with open(file_name, "rb") as f:
state_dict = pickle.load(f)
adam.load_state_dict(state_dict)
return x

def train_by_numpy():
@@ -203,6 +232,8 @@ def test_adam(test_case):
arg_dict["weight_decay"] = [0.0, 0.1]
arg_dict["eps"] = [1e-08, 1e-07]
arg_dict["do_bias_correction"] = [True, False]
arg_dict["reload_state_step"] = [5] # save and load optim state
arg_dict["save_load_by_pickle"] = [False, True]

for arg in GenArgList(arg_dict):
compare_with_numpy_adam(test_case, *arg)
@@ -219,6 +250,8 @@ def test_adam_clip_grad(test_case):
arg_dict["do_bias_correction"] = [True, False]
arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0]
arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5]
arg_dict["reload_state_step"] = [5] # save and load optim state
arg_dict["save_load_by_pickle"] = [False, True]

for arg in GenArgList(arg_dict):
compare_with_numpy_adam_clip_grad(test_case, *arg)
40 changes: 39 additions & 1 deletion python/oneflow/test/modules/test_optim_adamw.py
@@ -14,6 +14,7 @@
limitations under the License.
"""

import tempfile
import unittest
from collections import OrderedDict

@@ -26,7 +27,14 @@


def compare_with_numpy_adamw(
test_case, device, x_shape, learning_rate, train_iters, weight_decay
test_case,
device,
x_shape,
learning_rate,
train_iters,
weight_decay,
reload_state_step,
save_load_by_pickle,
):
random_grad_seq = []
for _ in range(train_iters):
@@ -50,6 +58,18 @@ def train_one_iter(grad):

for i in range(train_iters):
train_one_iter(random_grad_seq[i])
if i == reload_state_step:
state_dict = adam.state_dict()
adam = flow.optim.AdamW([x])
if save_load_by_pickle:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
import pickle

pickle.dump(state_dict, f)
with open(file_name, "rb") as f:
state_dict = pickle.load(f)
adam.load_state_dict(state_dict)
return x

def train_by_numpy():
@@ -91,6 +111,8 @@ def compare_with_numpy_adamw_clip_grad(
weight_decay,
clip_grad_max_norm,
clip_grad_norm_type,
reload_state_step,
save_load_by_pickle,
):
random_grad_seq = []
for _ in range(train_iters):
@@ -123,6 +145,18 @@ def train_one_iter(grad):

for i in range(train_iters):
train_one_iter(random_grad_seq[i])
if i == reload_state_step:
state_dict = adam.state_dict()
adam = flow.optim.AdamW([x])
if save_load_by_pickle:
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name
import pickle

pickle.dump(state_dict, f)
with open(file_name, "rb") as f:
state_dict = pickle.load(f)
adam.load_state_dict(state_dict)
return x

def train_by_numpy():
@@ -167,6 +201,8 @@ def test_adamw(test_case):
arg_dict["learning_rate"] = [1]
arg_dict["train_iters"] = [10]
arg_dict["weight_decay"] = [0.001, 0.0]
arg_dict["reload_state_step"] = [5] # save and load optim state
arg_dict["save_load_by_pickle"] = [False, True]
for arg in GenArgList(arg_dict):
compare_with_numpy_adamw(test_case, *arg)

@@ -179,6 +215,8 @@ def test_adamw_clip_grad(test_case):
arg_dict["weight_decay"] = [0.001, 0.0]
arg_dict["clip_grad_max_norm"] = [0, 0.5, 1.0]
arg_dict["clip_grad_norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5]
arg_dict["reload_state_step"] = [5] # save and load optim state
arg_dict["save_load_by_pickle"] = [False, True]
for arg in GenArgList(arg_dict):
compare_with_numpy_adamw_clip_grad(test_case, *arg)
