Skip to content

Commit

Permalink
Fix state_dict/load_state_dict with frozen_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenBao committed May 12, 2020
1 parent 40d121b commit 7243b80
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 22 deletions.
51 changes: 49 additions & 2 deletions onnxruntime/test/python/onnxruntime_test_ort_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import sys
import copy
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

import onnx
Expand Down Expand Up @@ -378,7 +379,7 @@ def testMNISTStateDict(self):
loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))

state_dict = trainer.state_dict()
assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'}
assert state_dict.keys() == {'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight'}

def testMNISTSaveAsONNX(self):
torch.manual_seed(1)
Expand Down Expand Up @@ -446,7 +447,53 @@ def testMNISTFrozenWeight(self):
data = data.reshape(data.shape[0], -1)

loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
assert 'fc1.weight' not in trainer.state_dict() and 'fc2.weight' in trainer.state_dict()

fc1_trainstep_1 = trainer.state_dict()['fc1.weight']
fc2_trainstep_1 = trainer.state_dict()['fc2.weight']

loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))

fc1_trainstep_2 = trainer.state_dict()['fc1.weight']
fc2_trainstep_2 = trainer.state_dict()['fc2.weight']
assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \
not np.array_equal(fc2_trainstep_1, fc2_trainstep_2)

def testMNISTFrozenWeightCheckpoint(self):
torch.manual_seed(1)
device = torch.device("cuda")

mnist = MNISTWrapper()
train_loader, test_loader = mnist.get_loaders()
model, model_desc = mnist.get_model()

trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight'])

learningRate = 0.02
epoch = 0

# do one train step
data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)

loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))

# do one eval step
data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)
data = data.reshape(data.shape[0], -1)

loss, _ = trainer.eval_step(data, target)

# save checkpoint, load model and compare
state_dict = trainer.state_dict()

new_model, _ = mnist.get_model()
trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=['fc1.weight'])
trainer.load_state_dict(state_dict)

ckpt_loss, _ = trainer.eval_step(data, target)
assert loss == ckpt_loss

def testBertTrainingBasic(self):
expected_losses = [
Expand Down
Binary file modified onnxruntime/test/testdata/ckpt_mnist.pt
Binary file not shown.
64 changes: 44 additions & 20 deletions orttraining/orttraining/python/ort_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,11 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
for i, name in enumerate(n.input):
if name in replace_name_dict:
n.input[i] = replace_name_dict[name]
assert set([n.name for n in onnx_model.graph.initializer]) == \
set([n for n, t in model.model_.named_parameters()]), \

# onnx model initializer may contain non-trainable registered buffers that are not part
# of pytorch model named parameteres.
assert set([n for n, t in model.model_.named_parameters()]).issubset(
set([n.name for n in onnx_model.graph.initializer])), \
"Initializer names do not match between PyTorch model and ONNX model, " \
"please report a bug to ONNX Runtime."

Expand Down Expand Up @@ -433,8 +436,7 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer

unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]]
if unused_frozen_weights:
warnings.warn("Ignoring {} in frozen_weights as they are not found in model weights."\
.format(unused_frozen_weights))
raise RuntimeError("{} in frozen_weights not found in model weights.".format(unused_frozen_weights))

weights_to_train = set()
for initializer in model.graph.initializer:
Expand Down Expand Up @@ -763,15 +765,38 @@ def train(self):
def eval(self):
self.is_train = False

def _update_onnx_model_initializers(self, state_tensors):
# replace the initializers with new value
new_weights = []
replace_indices = []
i = 0
for w in self.onnx_model_.graph.initializer:
if w.name in state_tensors:
new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
replace_indices.append(i)
i += 1
replace_indices.sort(reverse=True)
for w_i in replace_indices:
del self.onnx_model_.graph.initializer[w_i]
self.onnx_model_.graph.initializer.extend(new_weights)

def state_dict(self):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling state_dict().")
return {}

# extract trained weights
session_state = self.session.get_state()
torch_state = {}
for name in session_state:
torch_state[name] = torch.from_numpy(session_state[name])

# extract untrained weights and buffer
for n in self.onnx_model_.graph.initializer:
if n.name not in torch_state:
torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))

return torch_state

def load_state_dict(self, state_dict, strict=False):
Expand All @@ -783,30 +808,29 @@ def load_state_dict(self, state_dict, strict=False):
self.strict_ = strict
return

session_state = {}
# update onnx model from loaded state dict
cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer]
new_initializers = {}

for name in state_dict:
session_state[name] = state_dict[name].numpy()
self.session.load_state(session_state, strict)
if name in cur_initializers_names:
new_initializers[name] = state_dict[name].numpy()
elif strict:
raise RuntimeError("Checkpoint tensor: {} is not present in the model.".format(name))

self._update_onnx_model_initializers(new_initializers)

# create new session based on updated onnx model
self.state_dict_ = None
self._init_session()

def save_as_onnx(self, path):
if not self.session:
warnings.warn("ONNXRuntime training session is not initialized yet. "
"Please run train_step or eval_step at least once before calling save_as_onnx().")
return
state_tensors = self.session.get_state()
# replace the initializers with new value
new_weights = []
replace_indices = []
i = 0
for w in self.onnx_model_.graph.initializer:
if w.name in state_tensors:
new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
replace_indices.append(i)
i += 1
replace_indices.sort(reverse=True)
for w_i in replace_indices:
del self.onnx_model_.graph.initializer[w_i]
self.onnx_model_.graph.initializer.extend(new_weights)
self._update_onnx_model_initializers(state_tensors)

with open(path, "wb") as f:
f.write(self.onnx_model_.SerializeToString())
Expand Down

0 comments on commit 7243b80

Please sign in to comment.