Skip to content

Commit 0b7db65

Browse files
m3rlin45facebook-github-bot
authored andcommitted
Document top-level API a little, make testing from the command line easier. (facebookresearch#106)
Summary: Pull Request resolved: facebookresearch#106 Add docstrings and types to a couple functions exported to the top-level PyText API. Update the test cli mode to take in a PyTextConfig rather than a TestConfig for easier command line usage. Reviewed By: ahhegazy Differential Revision: D13367488 fbshipit-source-id: 2347e63f4e31a737566ef6aabd3399a79e6ff023
1 parent dc81bac commit 0b7db65

File tree

5 files changed

+90
-38
lines changed

5 files changed

+90
-38
lines changed

pytext/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3+
import json
34
import uuid
5+
from typing import Callable, Mapping, Optional
46

57
import numpy as np
68
from caffe2.python import workspace
79
from caffe2.python.predictor import predictor_exporter
810

911
from .builtin_task import register_builtin_tasks
12+
from .config import PyTextConfig, config_from_json
1013
from .config.component import create_featurizer
1114
from .data.featurizer import InputRecord
1215
from .utils.onnx_utils import CAFFE2_DB_TYPE, convert_caffe2_blob_name
@@ -15,6 +18,9 @@
1518
register_builtin_tasks()
1619

1720

21+
Predictor = Callable[[Mapping[str, str]], Mapping[str, np.array]]
22+
23+
1824
def _predict(workspace_id, feature_config, predict_net, featurizer, input):
1925
workspace.SwitchWorkspace(workspace_id)
2026
features = featurizer.featurize(InputRecord(**input))
@@ -49,7 +55,26 @@ def _predict(workspace_id, feature_config, predict_net, featurizer, input):
4955
}
5056

5157

52-
def create_predictor(config, model_file=None):
58+
def load_config(filename: str) -> PyTextConfig:
59+
"""
60+
Load a PyText configuration file from a file path.
61+
See pytext.config.pytext_config for more info on configs.
62+
"""
63+
with open(filename) as file:
64+
config_json = json.loads(file.read())
65+
if "config" not in config_json:
66+
return config_from_json(PyTextConfig, config_json)
67+
return config_from_json(PyTextConfig, config_json["config"])
68+
69+
70+
def create_predictor(
71+
config: PyTextConfig, model_file: Optional[str] = None
72+
) -> Predictor:
73+
"""
74+
Create a simple prediction API from a training config and an exported caffe2
75+
model file. This model file should be created by calling export on a trained
76+
model snapshot.
77+
"""
5378
workspace_id = str(uuid.uuid4())
5479
workspace.SwitchWorkspace(workspace_id, True)
5580
predict_net = predictor_exporter.prepare_prediction_net(

pytext/config/serialize.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
from .pytext_config import PyTextConfig, TestConfig
88

99

10-
class Mode(Enum):
11-
TRAIN = "train"
12-
TEST = "test"
13-
14-
1510
class ConfigParseError(Exception):
1611
pass
1712

@@ -219,13 +214,10 @@ def _get_class_type(cls):
219214
return cls.__origin__ if hasattr(cls, "__origin__") else cls
220215

221216

222-
def parse_config(mode, config_json):
217+
def parse_config(config_json):
223218
"""
224219
Parse PyTextConfig object from parameter string or parameter file
225220
"""
226-
config_cls = {Mode.TRAIN: PyTextConfig, Mode.TEST: TestConfig}[mode]
227-
# TODO T32608471 should assume the entire json is PyTextConfig later, right
228-
# now we're matching the file format for pytext trainer.py inside fbl
229221
if "config" not in config_json:
230-
return config_from_json(config_cls, config_json)
231-
return config_from_json(config_cls, config_json["config"])
222+
return config_from_json(PyTextConfig, config_json)
223+
return config_from_json(PyTextConfig, config_json["config"])

pytext/config/test/pytext_all_config_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import unittest
77

88
from pytext.builtin_task import register_builtin_tasks
9-
from pytext.config.serialize import Mode, parse_config
9+
from pytext.config.serialize import parse_config
1010

1111

1212
register_builtin_tasks()
@@ -32,6 +32,5 @@ def test_load_all_configs(self):
3232
print("--- loading:", filename)
3333
with open(filename) as file:
3434
config_json = json.load(file)
35-
# Most configs don't work in Mode.TEST
36-
config = parse_config(Mode.TRAIN, config_json)
35+
config = parse_config(config_json)
3736
self.assertIsNotNone(config)

pytext/main.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
from pytext import create_predictor
1212
from pytext.config import PyTextConfig, TestConfig
13-
from pytext.config.serialize import Mode, config_from_json, config_to_json, parse_config
13+
from pytext.config.serialize import config_from_json, config_to_json, parse_config
1414
from pytext.task import load
1515
from pytext.utils.documentation_helper import (
1616
ROOT_CONFIG,
@@ -22,7 +22,7 @@
2222
from pytext.workflow import (
2323
batch_predict,
2424
export_saved_model_to_caffe2,
25-
test_model,
25+
test_model_from_snapshot_path,
2626
train_model,
2727
)
2828
from torch.multiprocessing.spawn import spawn
@@ -180,33 +180,57 @@ def gen_default_config(context, task_name, options):
180180

181181

182182
@main.command()
183+
@click.option(
184+
"--model-snapshot",
185+
default="",
186+
help="load model snapshot and test configuration from this file",
187+
)
188+
@click.option("--test-path", default="", help="path to test data")
189+
@click.option(
190+
"--use-cuda/--no-cuda",
191+
default=None,
192+
help="Run supported parts of the model on GPU if available.",
193+
)
183194
@click.pass_context
184-
def test(context):
185-
"""Test a trained model snapshot."""
186-
config_json = context.obj.load_config()
187-
config = parse_config(Mode.TEST, config_json)
195+
def test(context, model_snapshot, test_path, use_cuda):
196+
"""Test a trained model snapshot.
197+
198+
If model-snapshot is provided, the models and configuration will then be loaded from
199+
the snapshot rather than any passed config file.
200+
Otherwise, a config file will be loaded.
201+
"""
202+
if model_snapshot:
203+
print(f"Loading model snapshot and config from {model_snapshot}")
204+
if use_cuda is None:
205+
raise Exception(
206+
"if --model-snapshot is set --use-cuda/--no-cuda must be set"
207+
)
208+
else:
209+
print(f"No model snapshot provided, loading from config")
210+
config = parse_config(context.obj.load_config())
211+
model_snapshot = config.save_snapshot_path
212+
use_cuda = config.use_cuda_if_available
213+
print(f"Configured model snapshot {model_snapshot}")
188214
print("\n=== Starting testing...")
189-
test_model(config)
215+
test_model_from_snapshot_path(model_snapshot, use_cuda, test_path)
190216

191217

192218
@main.command()
193219
@click.pass_context
194220
def train(context):
195221
"""Train a model and save the best snapshot."""
196-
config_json = context.obj.load_config()
197-
config = parse_config(Mode.TRAIN, config_json)
222+
config = parse_config(context.obj.load_config())
198223
print("\n===Starting training...")
199224
if config.distributed_world_size == 1:
200225
train_model(config)
201226
else:
202227
train_model_distributed(config)
203228
print("\n=== Starting testing...")
204-
test_config = TestConfig(
205-
load_snapshot_path=config.save_snapshot_path,
206-
test_path=config.task.data_handler.test_path,
207-
use_cuda_if_available=config.use_cuda_if_available,
229+
test_model_from_snapshot_path(
230+
config.save_snapshot_path,
231+
config.use_cuda_if_available,
232+
config.task.data_handler.test_path,
208233
)
209-
test_model(test_config)
210234

211235

212236
@main.command()
@@ -215,7 +239,7 @@ def train(context):
215239
@click.pass_context
216240
def export(context, model, output_path):
217241
"""Convert a pytext model snapshot to a caffe2 model."""
218-
config = parse_config(Mode.TRAIN, context.obj.load_config())
242+
config = parse_config(context.obj.load_config())
219243
model = model or config.save_snapshot_path
220244
output_path = output_path or config.export_caffe2_path
221245
print(f"Exporting {model} to {output_path}")
@@ -227,7 +251,7 @@ def export(context, model, output_path):
227251
@click.pass_context
228252
def predict(context, exported_model):
229253
"""Start a repl executing examples against a caffe2 model."""
230-
config = parse_config(Mode.TRAIN, context.obj.load_config())
254+
config = parse_config(context.obj.load_config())
231255
print(f"Loading model from {exported_model or config.export_caffe2_path}")
232256
predictor = create_predictor(config, exported_model)
233257

pytext/workflow.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33
import os
4-
from typing import Any, Dict, List, Tuple, get_type_hints
4+
from typing import Any, Dict, List, Optional, Tuple, get_type_hints
55

66
import torch
77
from pytext.config import PyTextConfig, TestConfig
@@ -103,16 +103,28 @@ def export_saved_model_to_caffe2(
103103

104104

105105
def test_model(test_config: TestConfig, metrics_channel: Channel = None) -> Any:
106-
_set_cuda(test_config.use_cuda_if_available)
106+
return test_model_from_snapshot_path(
107+
test_config.load_snapshot_path,
108+
test_config.use_cuda_if_available,
109+
test_config.test_path,
110+
metrics_channel,
111+
)
112+
107113

108-
task, train_config = load(test_config.load_snapshot_path)
114+
def test_model_from_snapshot_path(
115+
snapshot_path: str,
116+
use_cuda_if_available: bool,
117+
test_path: Optional[str] = None,
118+
metrics_channel: Optional[Channel] = None,
119+
):
120+
_set_cuda(use_cuda_if_available)
121+
task, train_config = load(snapshot_path)
122+
if not test_path:
123+
test_path = train_config.task.data_handler.test_path
109124
if metrics_channel is not None:
110125
task.metric_reporter.add_channel(metrics_channel)
111126

112-
return (
113-
task.test(test_config.test_path),
114-
train_config.task.metric_reporter.output_path,
115-
)
127+
return (task.test(test_path), train_config.task.metric_reporter.output_path)
116128

117129

118130
def batch_predict(model_file: str, examples: List[Dict[str, Any]]):

0 commit comments

Comments
 (0)