forked from facebookresearch/pytext
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredictor_test.py
88 lines (76 loc) · 3.36 KB
/
predictor_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import tempfile
import unittest
import numpy as np
from pytext import batch_predict_caffe2_model
from pytext.config import LATEST_VERSION, PyTextConfig
from pytext.data import Data
from pytext.data.sources import TSVDataSource
from pytext.data.tensorizers import (
FloatListTensorizer,
LabelTensorizer,
TokenTensorizer,
)
from pytext.models.doc_model import DocModel
from pytext.task import create_task
from pytext.task.serialize import save
from pytext.task.tasks import DocumentClassificationTask
from pytext.utils import test
# Object returned by pytext.utils.test.import_tests_module(); the test below
# calls its test_file() helper to obtain the TSV fixture arguments that are
# passed to the data source config.
tests_module = test.import_tests_module()
class PredictorTest(unittest.TestCase):
    """Compares Caffe2 batch prediction output with the PyTorch task's scores."""

    @unittest.skip("C2 deprecated")
    def test_batch_predict_caffe2_model(self):
        """Export a small document-classification task and cross-check scores.

        A task is created from the tiny TSV fixtures, exported to a temporary
        Caffe2 model file, and saved together with its config and tensorizers.
        ``batch_predict_caffe2_model`` is then invoked three times (without
        ``cache_size``, with ``cache_size=2`` and with ``cache_size=-1``);
        each run must return 4 results whose scores agree with
        ``task.predict`` on ``task.data.data_source.test``.
        """
        with tempfile.NamedTemporaryFile() as snapshot_file:
            with tempfile.NamedTemporaryFile() as caffe2_model_file:
                snapshot_path = snapshot_file.name
                caffe2_path = caffe2_model_file.name

                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")

                # Assemble the config bottom-up: model inputs, model, data
                # source, data, task, and finally the top-level config.
                model_inputs = DocModel.Config.ModelInput(
                    tokens=TokenTensorizer.Config(),
                    dense=FloatListTensorizer.Config(
                        column="dense", dim=1, error_check=True
                    ),
                    labels=LabelTensorizer.Config(),
                )
                model_config = DocModel.Config(inputs=model_inputs)
                tsv_source = TSVDataSource.Config(
                    train_filename=train_data,
                    eval_filename=eval_data,
                    test_filename=eval_data,
                    field_names=["label", "slots", "text", "dense"],
                )
                data_config = Data.Config(source=tsv_source)
                task_config = DocumentClassificationTask.Config(
                    model=model_config,
                    data=data_config,
                )
                config = PyTextConfig(
                    task=task_config,
                    version=21,
                    save_snapshot_path=snapshot_path,
                    export_caffe2_path=caffe2_path,
                )

                task = create_task(config.task)
                task.export(task.model, caffe2_path)
                save(
                    config,
                    task.model,
                    meta=None,
                    tensorizers=task.data.tensorizers,
                )

                # Reference predictions that every Caffe2 run is compared to.
                pt_results = task.predict(task.data.data_source.test)

                def check_against_pytorch(caffe2_results):
                    # Pair results positionally; for each pair compare the
                    # first row of the PyTorch "score" with the first element
                    # of every value in the Caffe2 result mapping.
                    for expected, actual in zip(pt_results, caffe2_results):
                        expected_scores = expected["score"].tolist()[0]
                        actual_scores = [
                            score[0] for score in actual.values()
                        ]
                        np.testing.assert_array_almost_equal(
                            expected_scores, actual_scores
                        )

                # Same checks for each cache_size variant, in the original
                # order: argument omitted, then 2, then -1.
                for extra_kwargs in ({}, {"cache_size": 2}, {"cache_size": -1}):
                    results = batch_predict_caffe2_model(
                        snapshot_path, caffe2_path, **extra_kwargs
                    )
                    self.assertEqual(4, len(results))
                    check_against_pytorch(results)