Skip to content

Commit

Permalink
fix - test flow only works with TSVDataSource (facebookresearch#448)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#448

This diff D14526577 enabled passing test file for new task, but works only with TSVDataSource.  Generalize it so it works for other data sources.

Reviewed By: hikushalhere

Differential Revision: D14781441

fbshipit-source-id: a88b5e0a8d46704450263554a100e7108ca54d31
  • Loading branch information
borguz authored and facebook-github-bot committed Apr 4, 2019
1 parent d6054c1 commit f501bf9
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions pytext/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
import torch
from pytext.common.constants import Stage
from pytext.config import PyTextConfig, TestConfig
from pytext.config.component import create_exporter
from pytext.config.component import ComponentType, create_component, create_exporter
from pytext.data.data import Batcher
from pytext.data.data_handler import CommonMetadata
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.metric_reporters.channel import Channel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.task import NewTask, Task, create_task, load, save
Expand Down Expand Up @@ -205,7 +203,9 @@ def test_model_from_snapshot_path(
test_out_path = train_config.task.metric_reporter.output_path

if isinstance(task, NewTask):
data_source = _get_data_source(test_path, field_names, task)
data_source = _get_data_source(
test_path, train_config.task.data.source, field_names, task
)
test_results = task.test(data_source)
else:
if not test_path:
Expand All @@ -214,12 +214,13 @@ def test_model_from_snapshot_path(
return test_results, test_out_path, metric_channels


def _get_data_source(test_path, field_names, task):
if test_path:
data_source = TSVDataSource(
test_file=SafeFileWrapper(test_path),
schema=task.data.data_source.schema,
field_names=field_names,
def _get_data_source(test_path, source_config, field_names, task):
if test_path and hasattr(source_config, "test_filename"):
source_config.test_filename = test_path
if field_names and hasattr(source_config, "field_names"):
source_config.field_names = field_names
data_source = create_component(
ComponentType.DATA_SOURCE, source_config, task.data.data_source.schema
)
else:
data_source = task.data.data_source
Expand Down

0 comments on commit f501bf9

Please sign in to comment.