diff --git a/pytext/config/pytext_config.py b/pytext/config/pytext_config.py index 075258112..0298fd9fc 100644 --- a/pytext/config/pytext_config.py +++ b/pytext/config/pytext_config.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved from collections import OrderedDict -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union class ConfigBaseMeta(type): @@ -112,6 +112,9 @@ class TestConfig(ConfigBase): load_snapshot_path: str # Test data path test_path: str = "test.tsv" + #: Field names for the TSV. If this is not set, the first line of each file + #: will be assumed to be a header containing the field names. + field_names: Optional[List[str]] = None use_cuda_if_available: bool = True # Whether to use TensorBoard use_tensorboard: bool = True diff --git a/pytext/data/data.py b/pytext/data/data.py index dd1769e77..9381d65b7 100644 --- a/pytext/data/data.py +++ b/pytext/data/data.py @@ -234,19 +234,22 @@ def __initialize_tensorizers(self): initializer.send(row) @generator_iterator - def batches(self, stage: Stage, rank=0, world_size=1): + def batches(self, stage: Stage, rank=0, world_size=1, data_source=None): """Create batches of tensors to pass to model train_batch. This function yields dictionaries that mirror the `tensorizers` dict passed to `__init__`, ie. the keys will be the same, and the tensors will be the shape expected from the respective tensorizers. `stage` is used to determine which data source is used to create batches. 
+ if data_source is provided, it is used instead of the configured data_source + this is to allow setting a different data_source for testing a model """ + data_source = data_source or self.data_source rows = shard( { - Stage.TRAIN: self.data_source.train, - Stage.TEST: self.data_source.test, - Stage.EVAL: self.data_source.eval, + Stage.TRAIN: data_source.train, + Stage.TEST: data_source.test, + Stage.EVAL: data_source.eval, }[stage], rank, world_size, diff --git a/pytext/main.py b/pytext/main.py index 81413d15b..59d9f6364 100644 --- a/pytext/main.py +++ b/pytext/main.py @@ -250,8 +250,14 @@ def gen_default_config(context, task_name, options): default=True, help="Whether to visualize test metrics using TensorBoard.", ) +@click.option( + "--field_names", + default=None, + help="""Field names for the test-path. If this is not set, the first line of + each file will be assumed to be a header containing the field names.""", +) @click.pass_context -def test(context, model_snapshot, test_path, use_cuda, use_tensorboard): +def test(context, model_snapshot, test_path, use_cuda, use_tensorboard, field_names): """Test a trained model snapshot. 
If model-snapshot is provided, the models and configuration will then be @@ -277,7 +283,11 @@ def test(context, model_snapshot, test_path, use_cuda, use_tensorboard): metric_channels.append(TensorBoardChannel()) try: test_model_from_snapshot_path( - model_snapshot, use_cuda, test_path, metric_channels + model_snapshot, + use_cuda, + test_path, + metric_channels, + field_names=field_names, ) finally: for mc in metric_channels: diff --git a/pytext/task/new_task.py b/pytext/task/new_task.py index 3d8a7b5ea..75b1a6a7e 100644 --- a/pytext/task/new_task.py +++ b/pytext/task/new_task.py @@ -170,9 +170,11 @@ def train( rank=rank, ) - def test(self): + def test(self, data_source): return self.trainer.test( - self.data.batches(Stage.TEST), self.model, self.metric_reporter + self.data.batches(Stage.TEST, data_source=data_source), + self.model, + self.metric_reporter, ) def export(self, model, export_path, metric_channels=None, export_onnx_path=None): diff --git a/pytext/task/serialize.py b/pytext/task/serialize.py index 567996ef4..043dfbf39 100644 --- a/pytext/task/serialize.py +++ b/pytext/task/serialize.py @@ -44,6 +44,7 @@ def load(load_path: str): print(f"Loading model from {load_path}...") state = torch.load(load_path, map_location=lambda storage, loc: storage) config = pytext_config_from_json(state[CONFIG_JSON]) + task = create_task( config.task, metadata=state[DATA_STATE], model_state=state[MODEL_STATE] ) diff --git a/pytext/workflow.py b/pytext/workflow.py index 580054c0a..8681f339f 100644 --- a/pytext/workflow.py +++ b/pytext/workflow.py @@ -7,6 +7,8 @@ from pytext.config import PyTextConfig, TestConfig from pytext.config.component import create_exporter from pytext.data.data_handler import CommonMetadata +from pytext.data.sources.data_source import SafeFileWrapper +from pytext.data.sources.tsv import TSVDataSource from pytext.metric_reporters.channel import Channel from pytext.task import NewTask, Task, create_task, load, save from pytext.utils import 
set_random_seeds @@ -162,6 +164,7 @@ def test_model( test_config.test_path, metric_channels, test_out_path, + test_config.field_names, ) @@ -171,6 +174,7 @@ def test_model_from_snapshot_path( test_path: Optional[str] = None, metric_channels: Optional[List[Channel]] = None, test_out_path: str = "", + field_names: Optional[List[str]] = None, ): _set_cuda(use_cuda_if_available) task, train_config = load(snapshot_path) @@ -190,7 +194,15 @@ def test_model_from_snapshot_path( test_out_path = train_config.task.metric_reporter.output_path if isinstance(task, NewTask): - test_results = task.test() + if test_path: + data_source = TSVDataSource( + test_file=SafeFileWrapper(test_path), + schema=task.data.data_source.schema, + field_names=field_names, + ) + else: + data_source = task.data.data_source + test_results = task.test(data_source) else: if not test_path: test_path = train_config.task.data_handler.test_path