# task_load_save_test.py — 156 lines (141 loc), 6.13 KB
# Forked from facebookresearch/pytext
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import tempfile
import unittest
import torch
from pytext.common.constants import Stage
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.config.component import create_optimizer, create_scheduler
from pytext.data import Data
from pytext.data.sources import TSVDataSource
from pytext.optimizer import Adam, Optimizer
from pytext.optimizer.scheduler import Scheduler
from pytext.task import create_task
from pytext.task.serialize import (
CheckpointManager,
get_latest_checkpoint_path,
load,
save,
)
from pytext.task.tasks import DocumentClassificationTask
from pytext.trainers.training_state import TrainingState
from pytext.utils import test
tests_module = test.import_tests_module()
class TaskLoadSaveTest(unittest.TestCase):
def assertModulesEqual(self, mod1, mod2, message=None):
for p1, p2 in itertools.zip_longest(mod1.parameters(), mod2.parameters()):
self.assertTrue(p1.equal(p2), message)
def test_load_saved_model(self):
with tempfile.NamedTemporaryFile() as snapshot_file:
train_data = tests_module.test_file("train_data_tiny.tsv")
eval_data = tests_module.test_file("test_data_tiny.tsv")
config = PyTextConfig(
task=DocumentClassificationTask.Config(
data=Data.Config(
source=TSVDataSource.Config(
train_filename=train_data,
eval_filename=eval_data,
field_names=["label", "slots", "text"],
)
)
),
version=LATEST_VERSION,
save_snapshot_path=snapshot_file.name,
)
task = create_task(config.task)
model = task.model
save(config, model, meta=None, tensorizers=task.data.tensorizers)
task2, config2, training_state_none = load(snapshot_file.name)
self.assertEqual(config, config2)
self.assertModulesEqual(model, task2.model)
self.assertIsNone(training_state_none)
model.eval()
task2.model.eval()
inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
self.assertEqual(model(*inputs).tolist(), task2.model(*inputs).tolist())
def assertOptimizerEqual(self, optim_1, optim_2, msg=None):
self.assertTrue(type(optim_1) is Adam and type(optim_2) is Adam, msg)
state_dict_1 = optim_1.state_dict()
state_dict_2 = optim_2.state_dict()
self.assertEqual(len(state_dict_1), len(state_dict_2))
params_1 = optim_1.state_dict()["param_groups"][0]["params"]
params_2 = optim_1.state_dict()["param_groups"][0]["params"]
self.assertEqual(len(params_1), len(params_2), msg)
def assertCheckpointEqual(
self,
model,
config,
training_state,
model_restored,
config_restored,
training_state_restored,
):
optimizer_restored = training_state_restored.optimizer
scheduler_restored = training_state_restored.scheduler
self.assertOptimizerEqual(training_state.optimizer, optimizer_restored)
self.assertEqual(training_state.start_time, training_state_restored.start_time)
self.assertEqual(training_state.epoch, training_state_restored.epoch)
self.assertEqual(training_state.rank, training_state_restored.rank)
self.assertEqual(training_state.stage, training_state_restored.stage)
self.assertEqual(
training_state.epochs_since_last_improvement,
training_state_restored.epochs_since_last_improvement,
)
self.assertIsNotNone(scheduler_restored)
self.assertIsNotNone(config_restored)
self.assertModulesEqual(model, model_restored)
model.eval()
model_restored.eval()
inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
self.assertEqual(model(*inputs).tolist(), model_restored(*inputs).tolist())
def test_load_checkpoint(self):
with tempfile.NamedTemporaryFile() as checkpoint_file:
train_data = tests_module.test_file("train_data_tiny.tsv")
eval_data = tests_module.test_file("test_data_tiny.tsv")
config = PyTextConfig(
task=DocumentClassificationTask.Config(
data=Data.Config(
source=TSVDataSource.Config(
train_filename=train_data,
eval_filename=eval_data,
field_names=["label", "slots", "text"],
)
)
),
version=LATEST_VERSION,
save_snapshot_path=checkpoint_file.name,
)
task = create_task(config.task)
model = task.model
# test checkpoint saving and loading
optimizer = create_optimizer(Adam.Config(), model)
scheduler = create_scheduler(Scheduler.Config(), optimizer)
training_state = TrainingState(
model=model,
optimizer=optimizer,
scheduler=scheduler,
start_time=0,
epoch=0,
rank=0,
stage=Stage.TRAIN,
epochs_since_last_improvement=0,
best_model_state=None,
best_model_metric=None,
tensorizers=task.data.tensorizers,
)
id = "epoch-1"
saved_path = save(
config, model, None, task.data.tensorizers, training_state, id
)
# TODO: fix get_latest_checkpoint_path T53664139
# self.assertEqual(saved_path, get_latest_checkpoint_path())
task_restored, config_restored, training_state_restored = load(saved_path)
self.assertCheckpointEqual(
model,
config,
training_state,
task_restored.model,
config_restored,
training_state_restored,
)