Remove 'model_.' prefix from onnx model initializers in training (#3881)
* Remove 'model_.' prefix for onnx model initializers in training

* Fix test case; remove redundant device test

* rename

* Fix state_dict/load_state_dict with frozen_weight

* nit

* Add monkey patch for pt opset 10

* remove pt patch in CI

* nit: newline
BowenBao authored May 20, 2020
1 parent 08763e8 commit 0a5395b
Showing 6 changed files with 179 additions and 229 deletions.
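For context before the diffs: ONNX Runtime's training frontend wraps the user's module and loss function in a combined module before export, and PyTorch derives parameter names (and therefore ONNX initializer names) from attribute paths. Below is a self-contained sketch of that naming behavior; WithLoss is illustrative, not the wrapper class in this repository, which stores the model in an attribute named model_:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 2)

class WithLoss(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model_ = model  # attribute name becomes a prefix on every parameter

print([n for n, _ in Net().named_parameters()])
# ['fc1.weight', 'fc1.bias']
print([n for n, _ in WithLoss(Net()).named_parameters()])
# ['model_.fc1.weight', 'model_.fc1.bias']

Stripping that prefix after export lets state_dict() keys line up with the original module's named_parameters(), which is what the updated tests below assert.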
94 changes: 91 additions & 3 deletions onnxruntime/test/python/onnxruntime_test_ort_trainer.py
@@ -6,6 +6,7 @@
 import pytest
 import sys
 import copy
+import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
 
 import onnx
@@ -260,9 +261,9 @@ def get_model(self):
         model_desc = MNISTWrapper.mnist_model_description()
         return model, model_desc
 
-    def get_trainer(self, model, model_desc, device, onnx_opset_ver=12):
+    def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=[]):
         return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
-                          torch.float32), device, _opset_version=onnx_opset_ver)
+                          torch.float32), device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights)
 
 class TestOrtTrainer(unittest.TestCase):
 
@@ -386,7 +387,7 @@ def testMNISTStateDict(self):
         loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
 
         state_dict = trainer.state_dict()
-        assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'}
+        assert state_dict.keys() == {'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight'}
 
     def testMNISTSaveAsONNX(self):
         torch.manual_seed(1)
@@ -435,6 +436,93 @@ def testMNISTDevice(self):
 
         loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
 
+    def testMNISTInitializerNames(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device)
+        learningRate = 0.02
+        epoch = 0
+
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        assert set([n.name for n in trainer.onnx_model_.graph.initializer]) \
+            == set([n for n, t in model.named_parameters()])
+
+    def testMNISTFrozenWeight(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight'])
+
+        learningRate = 0.02
+        epoch = 0
+
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        fc1_trainstep_1 = trainer.state_dict()['fc1.weight']
+        fc2_trainstep_1 = trainer.state_dict()['fc2.weight']
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        fc1_trainstep_2 = trainer.state_dict()['fc1.weight']
+        fc2_trainstep_2 = trainer.state_dict()['fc2.weight']
+        assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \
+            not np.array_equal(fc2_trainstep_1, fc2_trainstep_2)
+
+    def testMNISTFrozenWeightCheckpoint(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight'])
+
+        learningRate = 0.02
+        epoch = 0
+
+        # do one train step
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        # do one eval step
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.eval_step(data, target)
+
+        # save checkpoint, load model and compare
+        state_dict = trainer.state_dict()
+
+        new_model, _ = mnist.get_model()
+        trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=['fc1.weight'])
+        trainer.load_state_dict(state_dict)
+
+        ckpt_loss, _ = trainer.eval_step(data, target)
+        assert loss == ckpt_loss
+
     def testBertTrainingBasic(self):
         expected_losses = [
             11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
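The new testMNISTInitializerNames pins down an export invariant that can also be checked outside ORTTrainer. A minimal standalone sketch, assuming only torch and onnx are installed (TinyNet is illustrative): parameter names should survive export as initializer names, with registered buffers possibly appearing as extra initializers.

import io

import onnx
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 3)
        self.fc2 = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyNet()
f = io.BytesIO()
torch.onnx.export(model, torch.randn(1, 4), f)
onnx_model = onnx.load_model_from_string(f.getvalue())

param_names = set(n for n, _ in model.named_parameters())
initializer_names = set(i.name for i in onnx_model.graph.initializer)
# Parameters are a subset: buffers (if any) may add extra initializers.
assert param_names.issubset(initializer_names)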
Binary file modified onnxruntime/test/testdata/ckpt_mnist.pt
82 changes: 63 additions & 19 deletions orttraining/orttraining/python/ort_trainer.py
@@ -11,6 +11,7 @@
 import onnxruntime as ort
 from distutils.version import LooseVersion
 from .checkpointing_utils import list_checkpoint_files, get_checkpoint_name, CombineZeroCheckpoint
+import onnxruntime.capi.pt_patch
 
 DEFAULT_OPSET_VERSION = 10
 
@@ -326,11 +327,29 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op
                        do_constant_folding=False,
                        **other_export_options)
 
-    model = onnx.load_model_from_string(f.getvalue())
+    onnx_model = onnx.load_model_from_string(f.getvalue())
 
-    model = FuseSofmaxNLLToSoftmaxCE(model)
+    # Remove 'model_.' prefix introduced by model wrapper for initializers.
+    replace_name_dict = {}
+    for n in onnx_model.graph.initializer:
+        if n.name.startswith('model_.'):
+            replace_name_dict[n.name] = n.name[len('model_.'):]
+            n.name = replace_name_dict[n.name]
+    for n in onnx_model.graph.node:
+        for i, name in enumerate(n.input):
+            if name in replace_name_dict:
+                n.input[i] = replace_name_dict[name]
 
-    return model
+    # The ONNX model's initializers may contain non-trainable registered buffers
+    # that are not part of the PyTorch model's named parameters.
+    assert set([n for n, t in model.model_.named_parameters()]).issubset(
+        set([n.name for n in onnx_model.graph.initializer])), \
+        "Initializer names do not match between PyTorch model and ONNX model, " \
+        "please report a bug to ONNX Runtime."
+
+    onnx_model = FuseSofmaxNLLToSoftmaxCE(onnx_model)
+
+    return onnx_model
 
 def create_ort_training_session_with_optimizer(model, device, training_optimizer_name, lr_params_feed_name,
                                                map_optimizer_attributes, world_rank=-1, world_size=1,
@@ -361,6 +380,11 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer
     torch_params = {}
     optimizer_attributes_map = {}
     optimizer_int_attributes_map = {}
+
+    unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]]
+    if unused_frozen_weights:
+        raise RuntimeError("{} in frozen_weights not found in model weights.".format(unused_frozen_weights))
+
     weights_to_train = set()
     for initializer in model.graph.initializer:
         if initializer.name in frozen_weights:
@@ -639,15 +663,36 @@ def train(self):
     def eval(self):
         self.is_train = False
 
+    def _update_onnx_model_initializers(self, state_tensors):
+        # replace the initializers with new value
+        new_weights = []
+        replace_indices = []
+        for i, w in enumerate(self.onnx_model_.graph.initializer):
+            if w.name in state_tensors:
+                new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
+                replace_indices.append(i)
+        replace_indices.sort(reverse=True)
+        for w_i in replace_indices:
+            del self.onnx_model_.graph.initializer[w_i]
+        self.onnx_model_.graph.initializer.extend(new_weights)
+
     def state_dict(self):
         if not self.session:
             warnings.warn("ONNXRuntime training session is not initialized yet. "
                           "Please run train_step or eval_step at least once before calling state_dict().")
             return {}
 
+        # extract trained weights
         session_state = self.session.get_state()
         torch_state = {}
         for name in session_state:
             torch_state[name] = torch.from_numpy(session_state[name])
+
+        # extract untrained weights and buffers
+        for n in self.onnx_model_.graph.initializer:
+            if n.name not in torch_state:
+                torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
+
         return torch_state
@@ -659,30 +704,29 @@ def load_state_dict(self, state_dict, strict=False):
             self.strict_ = strict
             return
 
-        session_state = {}
+        # update onnx model from loaded state dict
+        cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer]
+        new_initializers = {}
+
         for name in state_dict:
-            session_state[name] = state_dict[name].numpy()
-        self.session.load_state(session_state, strict)
+            if name in cur_initializers_names:
+                new_initializers[name] = state_dict[name].numpy()
+            elif strict:
+                raise RuntimeError("Checkpoint tensor: {} is not present in the model.".format(name))
+
+        self._update_onnx_model_initializers(new_initializers)
+
+        # create new session based on updated onnx model
+        self.state_dict_ = None
+        self._init_session()
 
     def save_as_onnx(self, path):
         if not self.session:
             warnings.warn("ONNXRuntime training session is not initialized yet. "
                           "Please run train_step or eval_step at least once before calling save_as_onnx().")
             return
         state_tensors = self.session.get_state()
-        # replace the initializers with new value
-        new_weights = []
-        replace_indices = []
-        i = 0
-        for w in self.onnx_model_.graph.initializer:
-            if w.name in state_tensors:
-                new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name))
-                replace_indices.append(i)
-            i += 1
-        replace_indices.sort(reverse=True)
-        for w_i in replace_indices:
-            del self.onnx_model_.graph.initializer[w_i]
-        self.onnx_model_.graph.initializer.extend(new_weights)
+        self._update_onnx_model_initializers(state_tensors)
 
         with open(path, "wb") as f:
             f.write(self.onnx_model_.SerializeToString())
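Both load_state_dict and save_as_onnx now funnel through the new _update_onnx_model_initializers helper, whose core is a numpy round-trip over protobuf tensors. A minimal sketch of that round-trip, assuming an in-memory onnx GraphProto (replace_initializer is a hypothetical helper, not part of this commit):

from onnx import numpy_helper

def replace_initializer(graph, name, array):
    # Swap one initializer for a new value by converting the numpy ndarray
    # back into a TensorProto that carries the same name.
    for i, init in enumerate(graph.initializer):
        if init.name == name:
            del graph.initializer[i]
            graph.initializer.append(numpy_helper.from_array(array, name))
            return
    raise KeyError("no initializer named {}".format(name))

# numpy_helper also covers the reverse direction used by state_dict():
# numpy_helper.to_array(initializer) yields an ndarray of the stored weights.

Deleting the old entry before appending matters: graph.initializer is a repeated protobuf field, so appending alone would leave two tensors with the same name in the graph.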
21 changes: 21 additions & 0 deletions orttraining/orttraining/python/pt_patch.py
@@ -0,0 +1,21 @@
+import torch
+
+from torch.onnx import symbolic_opset10
+from torch.onnx.symbolic_helper import parse_args
+
+@parse_args('v', 'v', 'v', 'v', 'i', 'none')
+def nll_loss(g, self, target, weight=None, reduction='mean', ignore_index=-100):
+    if not weight and not ignore_index:
+        return g.op("nll_loss", self, target)
+    elif ignore_index:
+        ignore_index_ = g.op("Constant", value_t=torch.tensor(ignore_index, dtype=torch.int64))
+        eq_ = g.op("Equal", target, ignore_index_)
+        not_eq_ = g.op("Not", eq_)
+        weight_ = g.op("Cast", not_eq_, to_i=1)  # FLOAT = 1; // float
+        not_eq_int64_ = g.op("Cast", not_eq_, to_i=7)  # INT64 = 7; // int64_t
+        target_ = g.op("Mul", target, not_eq_int64_)
+        # if weight:
+        #     weight_ = g.op("Mul", weight_, weight)
+        return g.op("nll_loss", self, target_, weight_)
+
+symbolic_opset10.nll_loss = nll_loss
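The final assignment is what makes this a monkey patch: rebinding symbolic_opset10.nll_loss changes what PyTorch's exporter resolves for opset 10, so the patch takes effect purely as an import side effect. That is why ort_trainer.py above adds a bare import of onnxruntime.capi.pt_patch, and why the CI no longer needs to copy a patched symbolic_opset10.py into the torch install (see the install_deps.sh diff below). A hedged usage sketch:

# Importing the module is enough; the import exists only for its side effect.
import onnxruntime.capi.pt_patch  # noqa: F401 (rebinds symbolic_opset10.nll_loss)

# After this import, an export at opset 10 can emit the custom "nll_loss" op
# that the ONNX Runtime training build recognizes, e.g. (names illustrative):
# torch.onnx.export(model_with_loss, example_inputs, "model.onnx", opset_version=10)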
12 changes: 4 additions & 8 deletions tools/ci_build/github/linux/docker/scripts/install_deps.sh
@@ -44,7 +44,7 @@ function GetFile {
   if command -v aria2c > /dev/null; then
     aria2c -q -d $(dirname $path) -o $(basename $path) "$uri"
   else
-    curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail 
+    curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail
   fi
 
   return $?
@@ -76,15 +76,15 @@ fi
 if [[ $SYS_LONG_BIT = "64" && "$GLIBC_VERSION" -gt "9" ]]; then
   echo "Installing azcopy"
   mkdir -p /tmp/azcopy
-  GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz 
+  GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz
   tar --strip 1 -xf /tmp/azcopy/azcopy.tar.gz -C /tmp/azcopy
   cp /tmp/azcopy/azcopy /usr/bin
   echo "Installing cmake"
-  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5-Linux-x86_64.tar.gz /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz 
+  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5-Linux-x86_64.tar.gz /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz
   tar -zxf /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz --strip=1 -C /usr
 else
   echo "Installing cmake"
-  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz /tmp/src/cmake-3.13.5.tar.gz 
+  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz /tmp/src/cmake-3.13.5.tar.gz
   tar -xf /tmp/src/cmake-3.13.5.tar.gz -C /tmp/src
   pushd .
   cd /tmp/src/cmake-3.13.5
@@ -112,10 +112,6 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
     ${PYTHON_EXE} -m pip install sympy==1.1.1
     if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then
       ${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
-
-      # patch pytorch onnx export opset version 10 to export nll_loss
-      PATH_TO_SYMBOLIC10=$(${PYTHON_EXE} -c 'import torch; import os; print(os.path.join(os.path.dirname(torch.__file__), "onnx/"))')
-      cp "${SCRIPT_DIR}/pyt_patch/symbolic_opset10.py" "${PATH_TO_SYMBOLIC10}"
     fi
     if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
       ${PYTHON_EXE} -m pip install transformers
