Skip to content

Commit

Permalink
fix using provided test file in test workflow with NewTask (facebookr…
Browse files Browse the repository at this point in the history
…esearch#431)

Summary:
Pull Request resolved: facebookresearch#431

If the task is an instance of NewTask, we currently ignore the provided test file and test on the file specified in the config instead. This change makes it possible to actually set the test file.

Reviewed By: bethebunny

Differential Revision: D14526577

fbshipit-source-id: 7b4d366b2daec30fe0ebc5aec428498bc22d23b1
  • Loading branch information
rutyrinott authored and facebook-github-bot committed Mar 29, 2019
1 parent 4632b6a commit ae0321f
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 10 deletions.
5 changes: 4 additions & 1 deletion pytext/config/pytext_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections import OrderedDict
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union


class ConfigBaseMeta(type):
Expand Down Expand Up @@ -112,6 +112,9 @@ class TestConfig(ConfigBase):
load_snapshot_path: str
# Test data path
test_path: str = "test.tsv"
#: Field names for the TSV. If this is not set, the first line of each file
#: will be assumed to be a header containing the field names.
field_names: Optional[List[str]] = None
use_cuda_if_available: bool = True
# Whether to use TensorBoard
use_tensorboard: bool = True
Expand Down
11 changes: 7 additions & 4 deletions pytext/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,19 +234,22 @@ def __initialize_tensorizers(self):
initializer.send(row)

@generator_iterator
def batches(self, stage: Stage, rank=0, world_size=1):
def batches(self, stage: Stage, rank=0, world_size=1, data_source=None):
"""Create batches of tensors to pass to model train_batch.
This function yields dictionaries that mirror the `tensorizers` dict passed to
`__init__`, ie. the keys will be the same, and the tensors will be the shape
expected from the respective tensorizers.
`stage` is used to determine which data source is used to create batches.
if data_source is provided, it is used instead of the configured data_source
this is to allow setting a different data_source for testing a model
"""
data_source = data_source or self.data_source
rows = shard(
{
Stage.TRAIN: self.data_source.train,
Stage.TEST: self.data_source.test,
Stage.EVAL: self.data_source.eval,
Stage.TRAIN: data_source.train,
Stage.TEST: data_source.test,
Stage.EVAL: data_source.eval,
}[stage],
rank,
world_size,
Expand Down
14 changes: 12 additions & 2 deletions pytext/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,14 @@ def gen_default_config(context, task_name, options):
default=True,
help="Whether to visualize test metrics using TensorBoard.",
)
@click.option(
"--field_names",
default=None,
help="""Field names for the test-path. If this is not set, the first line of
each file will be assumed to be a header containing the field names.""",
)
@click.pass_context
def test(context, model_snapshot, test_path, use_cuda, use_tensorboard):
def test(context, model_snapshot, test_path, use_cuda, use_tensorboard, field_names):
"""Test a trained model snapshot.
If model-snapshot is provided, the models and configuration will then be
Expand All @@ -277,7 +283,11 @@ def test(context, model_snapshot, test_path, use_cuda, use_tensorboard):
metric_channels.append(TensorBoardChannel())
try:
test_model_from_snapshot_path(
model_snapshot, use_cuda, test_path, metric_channels
model_snapshot,
use_cuda,
test_path,
metric_channels,
field_names=field_names,
)
finally:
for mc in metric_channels:
Expand Down
6 changes: 4 additions & 2 deletions pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ def train(
rank=rank,
)

def test(self):
def test(self, data_source):
return self.trainer.test(
self.data.batches(Stage.TEST), self.model, self.metric_reporter
self.data.batches(Stage.TEST, data_source=data_source),
self.model,
self.metric_reporter,
)

def export(self, model, export_path, metric_channels=None, export_onnx_path=None):
Expand Down
1 change: 1 addition & 0 deletions pytext/task/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def load(load_path: str):
print(f"Loading model from {load_path}...")
state = torch.load(load_path, map_location=lambda storage, loc: storage)
config = pytext_config_from_json(state[CONFIG_JSON])

task = create_task(
config.task, metadata=state[DATA_STATE], model_state=state[MODEL_STATE]
)
Expand Down
14 changes: 13 additions & 1 deletion pytext/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pytext.config import PyTextConfig, TestConfig
from pytext.config.component import create_exporter
from pytext.data.data_handler import CommonMetadata
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.metric_reporters.channel import Channel
from pytext.task import NewTask, Task, create_task, load, save
from pytext.utils import set_random_seeds
Expand Down Expand Up @@ -162,6 +164,7 @@ def test_model(
test_config.test_path,
metric_channels,
test_out_path,
test_config.field_names,
)


Expand All @@ -171,6 +174,7 @@ def test_model_from_snapshot_path(
test_path: Optional[str] = None,
metric_channels: Optional[List[Channel]] = None,
test_out_path: str = "",
field_names: Optional[List[str]] = None,
):
_set_cuda(use_cuda_if_available)
task, train_config = load(snapshot_path)
Expand All @@ -190,7 +194,15 @@ def test_model_from_snapshot_path(
test_out_path = train_config.task.metric_reporter.output_path

if isinstance(task, NewTask):
test_results = task.test()
if test_path:
data_source = TSVDataSource(
test_file=SafeFileWrapper(test_path),
schema=task.data.data_source.schema,
field_names=field_names,
)
else:
data_source = task.data.data_source
test_results = task.test(data_source)
else:
if not test_path:
test_path = train_config.task.data_handler.test_path
Expand Down

0 comments on commit ae0321f

Please sign in to comment.