From 7243b80a8e5e59ccf138a648658a03346d3d69e2 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 11 May 2020 18:18:57 -0700 Subject: [PATCH] Fix state_dict/load_state_dict with frozen_weight --- .../python/onnxruntime_test_ort_trainer.py | 51 +++++++++++++- onnxruntime/test/testdata/ckpt_mnist.pt | Bin 1595972 -> 1595944 bytes orttraining/orttraining/python/ort_trainer.py | 64 ++++++++++++------ 3 files changed, 93 insertions(+), 22 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index c1ac2faad1c2d..f55cc4ad2d4aa 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -6,6 +6,7 @@ import pytest import sys import copy +import numpy as np from numpy.testing import assert_allclose, assert_array_equal import onnx @@ -378,7 +379,7 @@ def testMNISTStateDict(self): loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) state_dict = trainer.state_dict() - assert state_dict.keys() == {'model_.fc1.bias', 'model_.fc1.weight', 'model_.fc2.bias', 'model_.fc2.weight'} + assert state_dict.keys() == {'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight'} def testMNISTSaveAsONNX(self): torch.manual_seed(1) @@ -446,7 +447,53 @@ def testMNISTFrozenWeight(self): data = data.reshape(data.shape[0], -1) loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - assert 'fc1.weight' not in trainer.state_dict() and 'fc2.weight' in trainer.state_dict() + + fc1_trainstep_1 = trainer.state_dict()['fc1.weight'] + fc2_trainstep_1 = trainer.state_dict()['fc2.weight'] + + loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) + + fc1_trainstep_2 = trainer.state_dict()['fc1.weight'] + fc2_trainstep_2 = trainer.state_dict()['fc2.weight'] + assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \ + not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) + + def testMNISTFrozenWeightCheckpoint(self): + torch.manual_seed(1) + device = torch.device("cuda") + + mnist = MNISTWrapper() + train_loader, test_loader = mnist.get_loaders() + model, model_desc = mnist.get_model() + + trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight']) + + learningRate = 0.02 + epoch = 0 + + # do one train step + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) + + # do one eval step + data, target = next(iter(train_loader)) + data, target = data.to(device), target.to(device) + data = data.reshape(data.shape[0], -1) + + loss, _ = trainer.eval_step(data, target) + + # save checkpoint, load model and compare + state_dict = trainer.state_dict() + + new_model, _ = mnist.get_model() + trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=['fc1.weight']) + trainer.load_state_dict(state_dict) + + ckpt_loss, _ = trainer.eval_step(data, target) + assert loss == ckpt_loss def testBertTrainingBasic(self): expected_losses = [ diff --git a/onnxruntime/test/testdata/ckpt_mnist.pt b/onnxruntime/test/testdata/ckpt_mnist.pt index 01a9a9b4488c1331915e070bc7197a8b0a59b6c7..e3b5723604577474d142c79d9c0632c102f8c1a7 100644 GIT binary patch delta 382 zcmX@IC2_@;#0l?s%qkQf3zIOgG`2LFJdIIZn+qtD zmTah3o|>7SQBo+G!K9&)!JHu&!3UJEG=UjpWN2Bn_M1hHkv8ko0(d37W`UXfQq$L~aC1oZS7b<2jX=r3HX9z~{0kvD2 zz-=`!G@ES6q-um@Do9g#YG!&yNuee-ZH7jsCT12ElRKHLoP%jhnZPo3D+Vzl~d