From 0a5395bb78810d0beec2c4b0f9094998580f45aa Mon Sep 17 00:00:00 2001
From: Bowen Bao
Date: Wed, 20 May 2020 10:06:31 -0700
Subject: [PATCH] Remove 'model_.' prefix from onnx model initializers in
 training (#3881)

* Remove 'model_.' prefix for onnx model initializers in training

* fix test case

remove redundant device test

* rename

* Fix state_dict/load_state_dict with frozen_weight

* nit

* Add monkey patch for pt opset 10

* remove pt patch in CI

* nit: newline
---
 .../python/onnxruntime_test_ort_trainer.py    |  94 ++++++++-
 onnxruntime/test/testdata/ckpt_mnist.pt       | Bin 1595972 -> 1595944 bytes
 orttraining/orttraining/python/ort_trainer.py |  82 ++++++--
 orttraining/orttraining/python/pt_patch.py    |  21 ++
 .../linux/docker/scripts/install_deps.sh      |  12 +-
 .../scripts/pyt_patch/symbolic_opset10.py     | 199 ------------------
 6 files changed, 179 insertions(+), 229 deletions(-)
 create mode 100644 orttraining/orttraining/python/pt_patch.py
 delete mode 100644 tools/ci_build/github/linux/docker/scripts/pyt_patch/symbolic_opset10.py

diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
index 5ab6e538810f8..98c9b6b8999b1 100644
--- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
+++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py
@@ -6,6 +6,7 @@ import pytest
 import sys
 import copy
+import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
 
 import onnx
@@ -260,9 +261,9 @@ def get_model(self):
         model_desc = MNISTWrapper.mnist_model_description()
         return model, model_desc
 
-    def get_trainer(self, model, model_desc, device, onnx_opset_ver=12):
+    def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=[]):
         return ORTTrainer(model, MNISTWrapper.my_loss, model_desc, "SGDOptimizer", None, IODescription('Learning_Rate', [1, ],
-                          torch.float32), device, _opset_version=onnx_opset_ver)
+                          torch.float32), device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights)
 
 class TestOrtTrainer(unittest.TestCase):
@@ -386,7 +387,7 @@ def testMNISTStateDict(self):
         loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
 
         state_dict = trainer.state_dict()
-        assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'}
+        assert state_dict.keys() == {'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight'}
 
     def testMNISTSaveAsONNX(self):
         torch.manual_seed(1)
@@ -435,6 +436,93 @@ def testMNISTDevice(self):
 
             loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
 
+    def testMNISTInitializerNames(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device)
+        learningRate = 0.02
+        epoch = 0
+
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        assert set([n.name for n in trainer.onnx_model_.graph.initializer]) \
+            == set([n for n, t in model.named_parameters()])
+
+    def testMNISTFrozenWeight(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight'])
+
+        learningRate = 0.02
+        epoch = 0
+
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        fc1_trainstep_1 = trainer.state_dict()['fc1.weight']
+        fc2_trainstep_1 = trainer.state_dict()['fc2.weight']
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        fc1_trainstep_2 = trainer.state_dict()['fc1.weight']
+        fc2_trainstep_2 = trainer.state_dict()['fc2.weight']
+        assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \
+            not np.array_equal(fc2_trainstep_1, fc2_trainstep_2)
+
+    def testMNISTFrozenWeightCheckpoint(self):
+        torch.manual_seed(1)
+        device = torch.device("cuda")
+
+        mnist = MNISTWrapper()
+        train_loader, test_loader = mnist.get_loaders()
+        model, model_desc = mnist.get_model()
+
+        trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight'])
+
+        learningRate = 0.02
+        epoch = 0
+
+        # do one train step
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.train_step(data, target, torch.tensor([learningRate]))
+
+        # do one eval step
+        data, target = next(iter(train_loader))
+        data, target = data.to(device), target.to(device)
+        data = data.reshape(data.shape[0], -1)
+
+        loss, _ = trainer.eval_step(data, target)
+
+        # save checkpoint, load model and compare
+        state_dict = trainer.state_dict()
+
+        new_model, _ = mnist.get_model()
+        trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=['fc1.weight'])
+        trainer.load_state_dict(state_dict)
+
+        ckpt_loss, _ = trainer.eval_step(data, target)
+        assert loss == ckpt_loss
+
     def testBertTrainingBasic(self):
         expected_losses = [
             11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543,
diff --git a/onnxruntime/test/testdata/ckpt_mnist.pt b/onnxruntime/test/testdata/ckpt_mnist.pt
index 01a9a9b4488c1331915e070bc7197a8b0a59b6c7..e3b5723604577474d142c79d9c0632c102f8c1a7 100644
GIT binary patch
[binary delta omitted]

diff --git a/tools/ci_build/github/linux/docker/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_deps.sh
     aria2c -q -d $(dirname $path) -o $(basename $path) "$uri"
   else
-    curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail 
+    curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail
   fi
 
   return $?
@@ -76,15 +76,15 @@ fi
 if [[ $SYS_LONG_BIT = "64" && "$GLIBC_VERSION" -gt "9" ]]; then
   echo "Installing azcopy"
   mkdir -p /tmp/azcopy
-  GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz 
+  GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz
   tar --strip 1 -xf /tmp/azcopy/azcopy.tar.gz -C /tmp/azcopy
   cp /tmp/azcopy/azcopy /usr/bin
   echo "Installing cmake"
-  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5-Linux-x86_64.tar.gz /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz 
+  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5-Linux-x86_64.tar.gz /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz
   tar -zxf /tmp/src/cmake-3.13.5-Linux-x86_64.tar.gz --strip=1 -C /usr
 else
   echo "Installing cmake"
-  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz /tmp/src/cmake-3.13.5.tar.gz 
+  GetFile https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz /tmp/src/cmake-3.13.5.tar.gz
   tar -xf /tmp/src/cmake-3.13.5.tar.gz -C /tmp/src
   pushd .
   cd /tmp/src/cmake-3.13.5
@@ -112,10 +112,6 @@ elif [ $DEVICE_TYPE = "gpu" ]; then
     ${PYTHON_EXE} -m pip install sympy==1.1.1
     if [[ $BUILD_EXTR_PAR = *--enable_training* ]]; then
         ${PYTHON_EXE} -m pip install --upgrade --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
-
-        # patch pytorch onnx export opset version 10 to export nll_loss
-        PATH_TO_SYMBOLIC10=$(${PYTHON_EXE} -c 'import torch; import os; print(os.path.join(os.path.dirname(torch.__file__), "onnx/"))')
-        cp "${SCRIPT_DIR}/pyt_patch/symbolic_opset10.py" "${PATH_TO_SYMBOLIC10}"
     fi
     if [[ $BUILD_EXTR_PAR = *--enable_training_python_frontend_e2e_tests* ]]; then
         ${PYTHON_EXE} -m pip install transformers
diff --git a/tools/ci_build/github/linux/docker/scripts/pyt_patch/symbolic_opset10.py b/tools/ci_build/github/linux/docker/scripts/pyt_patch/symbolic_opset10.py
deleted file mode 100644
index 6ed324f52f954..0000000000000
--- a/tools/ci_build/github/linux/docker/scripts/pyt_patch/symbolic_opset10.py
+++ /dev/null
@@ -1,199 +0,0 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-import torch
-from torch.nn.modules.utils import _single, _pair, _triple
-import torch.onnx
-# This import monkey-patches graph manipulation methods on Graph, used for the
-# ONNX symbolics
-import torch.onnx.utils
-
-import torch.onnx.symbolic_helper as sym_help
-from torch.onnx.symbolic_helper import parse_args, _unimplemented
-import torch.onnx.symbolic_opset9
-
-
-# EDITING THIS FILE? READ THIS FIRST!
-# see Note [Edit Symbolic Files] in symbolic_helper.py
-
-# This file exports ONNX ops for opset 10
-# Opset 10 is supported by ONNX release 1.5.0
-# release on 04/24/19
-
-
-@parse_args('v', 'i', 'i', 'none')
-def sort(g, self, dim, decending, out=None):
-    return sym_help._sort_helper(g, self, dim, decending=decending, out=out)
-
-
-@parse_args('v', 'v', 'i', 'i', 'i', 'none')
-def topk(g, self, k, dim, largest, sorted, out=None):
-    return sym_help._topk_helper(g, self, k, dim, largest=largest, sorted=sorted, out=out)
-
-
-def _max_pool(name, tuple_fn, ndims, return_indices):
-    @parse_args('v', 'is', 'is', 'is', 'is', 'i')
-    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
-        if not stride:
-            stride = kernel_size
-        kwargs = {
-            'kernel_shape_i': tuple_fn(kernel_size),
-            'pads_i': tuple_fn(padding) * 2,
-            'strides_i': tuple_fn(stride),
-            'ceil_mode_i': ceil_mode,
-        }
-        if set(tuple_fn(dilation)) != {1}:
-            kwargs['dilations_i'] = tuple_fn(dilation)
-        # easy but hacky way to get flattened indices values
-        # to be used to convert the indices values to non-flattened.
-        # In ONNX the indices are computed as a flatten 1-D tensor,
-        # so the values in indices are in [0, N x C x D1 x ... x Dn).
-        # To convert the indices to the same format used by Pytorch,
-        # we first execute a maxpool with a kernel and stride of 1 on the same input.
-        # This will result in a tensor of indices in which each index will have it's own value.
-        # Using this tensor as a reference, we extract the first index of each axis and subtract
-        # it from each index of this axis in the indices to convert.
-        # This step will result in a tensor were each dimension has values of indices within
-        # the dimension it is in.
-        # For more information :
-        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
-        if return_indices:
-            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
-            _, flattened_indices = g.op("MaxPool", input, outputs=2,
-                                        kernel_shape_i=[1 for _ in range(ndims)],
-                                        strides_i=[1 for _ in range(ndims)])
-            # convert indices to have non-flattened indices values
-            from torch.onnx.symbolic_opset9 import sub
-            s = sym_help._slice_helper(g, flattened_indices, axes=[2 + i for i in range(ndims)],
-                                       starts=tuple_fn(0), ends=tuple_fn(1))
-            indices = sub(g, indices, s)
-            return r, indices
-        else:
-            r = g.op("MaxPool", input, outputs=1, **kwargs)
-            return r
-
-    return symbolic_fn
-
-
-max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
-max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
-max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
-max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
-max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
-max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
-
-
-def _avg_pool(name, tuple_fn):
-    @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
-    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
-        if not stride:
-            stride = kernel_size
-        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)
-        if count_include_pad:
-            input = g.op("Pad", input,
-                         pads_i=((0,) * 2 + padding) * 2,
-                         mode_s='constant',
-                         value_f=0.)
-            padding = (0,) * len(padding)
-        output = g.op("AveragePool", input,
-                      kernel_shape_i=tuple_fn(kernel_size),
-                      strides_i=tuple_fn(stride),
-                      pads_i=padding * 2,
-                      ceil_mode_i=ceil_mode)
-        return output
-    return symbolic_fn
-
-
-avg_pool1d = _avg_pool('avg_pool1d', _single)
-avg_pool2d = _avg_pool('avg_pool2d', _pair)
-avg_pool3d = _avg_pool('avg_pool3d', _triple)
-
-
-def _interpolate(name, dim, interpolate_mode):
-    def symbolic_fn(g, input, output_size, *args):
-        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
-        sym_help._interpolate_warning(interpolate_mode)
-        align_corners = sym_help._maybe_get_scalar(align_corners)
-        if align_corners:
-            return _unimplemented(name, "align_corners == True")
-        if scales is None:
-            scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
-        return g.op("Resize", input, scales, mode_s=interpolate_mode)
-    return symbolic_fn
-
-
-upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
-upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
-upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
-upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
-upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
-upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
-
-
-def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
-    scales, mode = sym_help._interpolate_get_scales_and_mode(g, input, size, scale_factor,
-                                                             mode, align_corners)
-    return g.op("Resize", input, scales, mode_s=mode)
-
-
-def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
-    if dynamic_slice:
-        starts = g.op("Unsqueeze", starts, axes_i=[0])
-        ends = g.op("Unsqueeze", ends, axes_i=[0])
-        axes = g.op("Unsqueeze", axes, axes_i=[0])
-    else:
-        assert len(starts) == len(ends)
-        assert len(starts) == len(axes)
-        assert steps is None or len(starts) == len(steps)
-        if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807 \
-                and (steps is None or (len(steps) == 1 and steps[0] == 1)):
-            return input
-        axes = g.op("Constant", value_t=torch.tensor(axes))
-        starts = g.op("Constant", value_t=torch.tensor(starts))
-        ends = g.op("Constant", value_t=torch.tensor(ends))
-    if steps is None:
-        return g.op("Slice", input, starts, ends, axes)
-    steps = g.op("Constant", value_t=torch.tensor(steps))
-    return g.op("Slice", input, starts, ends, axes, steps)
-
-
-@parse_args('v', 'v', 'v', 'v', 'i')
-def slice(g, self, dim, start, end, step):
-    if (start.node().kind() != 'onnx::Constant' or
-            end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant'):
-        dynamic_slice = True
-    else:
-        start = [sym_help._parse_arg(start, 'i')]
-        end = [sym_help._parse_arg(end, 'i')]
-        dim = [sym_help._parse_arg(dim, 'i')]
-        dynamic_slice = False
-    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
-
-
-@parse_args('v', 'is')
-def flip(g, input, dims):
-    return sym_help._slice_helper(g, input, axes=dims,
-                                  starts=[-1] * len(dims),
-                                  ends=[-9223372036854775807] * len(dims),
-                                  steps=[-1] * len(dims))
-
-
-def fmod(g, input, other):
-    return g.op("Mod", input, other, fmod_i=1)
-
-# put nll in version 10 because ORT does not support some ops (like Equal) beyound opset 10.
-
-
-@parse_args('v', 'v', 'v', 'v', 'i', 'none')
-def nll_loss(g, self, target, weight=None, reduction='mean', ignore_index=-100):
-    if not weight and not ignore_index:
-        return g.op("nll_loss", self, target)
-    elif ignore_index:
-        ignore_index_ = g.op("Constant", value_t=torch.tensor(ignore_index, dtype=torch.int64))
-        eq_ = g.op("Equal", target, ignore_index_)
-        not_eq_ = g.op("Not", eq_)
-        weight_ = g.op("Cast", not_eq_, to_i=1)  # FLOAT = 1;   // float
-        not_eq_int64_ = g.op("Cast", not_eq_, to_i=7)  # INT64 = 7;   // int64_t
-        target_ = g.op("Mul", target, not_eq_int64_)
-        # if weight:
-        #     weight_ = g.op("Mul", weight_, weight)
-        return g.op("nll_loss", self, target_, weight_)
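
Note: the user-visible effect of this patch is easier to see in trainer code than in the hunks above. The following sketch is illustrative only and not part of the patch; it assumes it runs inside onnxruntime_test_ort_trainer.py, where the MNISTWrapper fixture and its get_trainer helper (extended above with a frozen_weights parameter) are defined, and it mirrors the new testMNISTInitializerNames/testMNISTFrozenWeight tests.

    # Illustrative sketch only -- not part of the patch. Reuses the MNISTWrapper
    # fixture defined earlier in this test file.
    import torch

    mnist = MNISTWrapper()
    train_loader, test_loader = mnist.get_loaders()
    model, model_desc = mnist.get_model()
    device = torch.device("cuda")

    # frozen_weights now takes the plain PyTorch parameter name ('fc1.weight'),
    # not the exporter-prefixed 'model_.fc1.weight' used before this patch.
    trainer = mnist.get_trainer(model, model_desc, device,
                                frozen_weights=['fc1.weight'])

    data, target = next(iter(train_loader))
    data, target = data.to(device), target.to(device)
    data = data.reshape(data.shape[0], -1)
    loss, _ = trainer.train_step(data, target, torch.tensor([0.02]))

    # Initializer and state_dict names now line up with the torch module's own
    # named_parameters() names, so checkpoints round-trip without key rewriting.
    assert trainer.state_dict().keys() == {n for n, _ in model.named_parameters()}

Before this change, the assertion at the end would have needed the 'model_.'-prefixed key set, as the old testMNISTStateDict expectation (removed in the first hunk) shows.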