Skip to content

Commit

Permalink
add option to get logits from output layer (facebookresearch#435)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#435

Adds a workflow that writes the model's logits to a file.
This is useful for models that we don't want to run evaluation on, e.g. models that produce sentence encodings.

Reviewed By: liaimi

Differential Revision: D14670793

fbshipit-source-id: f790d6c03169f94c8d5e78785c09e78f30a8d697
  • Loading branch information
rutyrinott authored and facebook-github-bot committed Apr 2, 2019
1 parent 8e9e152 commit e8cec67
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 23 deletions.
65 changes: 52 additions & 13 deletions pytext/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytext.workflow import (
export_saved_model_to_caffe2,
export_saved_model_to_torchscript,
get_logits as workflow_get_logits,
prepare_task_metadata,
test_model_from_snapshot_path,
train_model,
Expand Down Expand Up @@ -264,19 +265,9 @@ def test(context, model_snapshot, test_path, use_cuda, use_tensorboard, field_na
loaded from the snapshot rather than any passed config file.
Otherwise, a config file will be loaded.
"""
if model_snapshot:
print(f"Loading model snapshot and config from {model_snapshot}")
if use_cuda is None:
raise Exception(
"if --model-snapshot is set --use-cuda/--no-cuda must be set"
)
else:
print(f"No model snapshot provided, loading from config")
config = context.obj.load_config()
model_snapshot = config.save_snapshot_path
use_cuda = config.use_cuda_if_available
use_tensorboard = config.use_tensorboard
print(f"Configured model snapshot {model_snapshot}")
model_snapshot, use_cuda, use_tensorboard = _get_model_snapshot(
context, model_snapshot, use_cuda, use_tensorboard
)
print("\n=== Starting testing...")
metric_channels = []
if use_tensorboard:
Expand All @@ -294,6 +285,23 @@ def test(context, model_snapshot, test_path, use_cuda, use_tensorboard, field_na
mc.close()


def _get_model_snapshot(context, model_snapshot, use_cuda, use_tensorboard):
if model_snapshot:
print(f"Loading model snapshot and config from {model_snapshot}")
if use_cuda is None:
raise Exception(
"if --model-snapshot is set --use-cuda/--no-cuda must be set"
)
else:
print(f"No model snapshot provided, loading from config")
config = context.obj.load_config()
model_snapshot = config.save_snapshot_path
use_cuda = config.use_cuda_if_available
use_tensorboard = config.use_tensorboard
print(f"Configured model snapshot {model_snapshot}")
return model_snapshot, use_cuda, use_tensorboard


@main.command()
@click.pass_context
def train(context):
Expand Down Expand Up @@ -390,5 +398,36 @@ def predict_py(context, model_file):
break


@main.command()
@click.option(
    "--model-snapshot",
    default="",
    help="load model snapshot and test configuration from this file",
)
@click.option("--test-path", default="", help="path to test data")
@click.option("--output-path", default="", help="path to save logits")
@click.option(
    "--use-cuda/--no-cuda",
    default=None,
    help="Run supported parts of the model on GPU if available.",
)
@click.option(
    "--field_names",
    default=None,
    help="""Field names for the test-path. If this is not set, the first line of
each file will be assumed to be a header containing the field names.""",
)
@click.pass_context
def get_logits(context, model_snapshot, test_path, use_cuda, output_path, field_names):
    """print logits from a trained model snapshot to output_path
    """
    # Tensorboard is irrelevant for this command, so False is passed in and
    # the third element of the returned tuple is discarded.
    snapshot, cuda_flag, _ = _get_model_snapshot(
        context, model_snapshot, use_cuda, False
    )
    print("\n=== Starting get_logits...")
    # Hand off to the workflow-level implementation (imported under an alias
    # because this command shares its name).
    workflow_get_logits(
        snapshot_path=snapshot,
        use_cuda_if_available=cuda_flag,
        output_path=output_path,
        test_path=test_path,
        field_names=field_names,
    )


# Invoke the `main` entry point (the object the commands above are registered
# on via `@main.command()`) only when this module is executed as a script.
if __name__ == "__main__":
    main()
1 change: 0 additions & 1 deletion pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytext.config import ConfigBase, PyTextConfig
from pytext.config.component import ComponentType, create_component, create_trainer
from pytext.data.data import Data
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters import (
ClassificationMetricReporter,
MetricReporter,
Expand Down
49 changes: 40 additions & 9 deletions pytext/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
from typing import Any, Dict, List, Optional, Tuple, get_type_hints

import torch
from pytext.common.constants import Stage
from pytext.config import PyTextConfig, TestConfig
from pytext.config.component import create_exporter
from pytext.data.data import Batcher
from pytext.data.data_handler import CommonMetadata
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.metric_reporters.channel import Channel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.task import NewTask, Task, create_task, load, save
from pytext.utils import set_random_seeds
from pytext.utils.distributed import dist_init

from .utils import cuda, precision, timing

Expand Down Expand Up @@ -194,14 +196,7 @@ def test_model_from_snapshot_path(
test_out_path = train_config.task.metric_reporter.output_path

if isinstance(task, NewTask):
if test_path:
data_source = TSVDataSource(
test_file=SafeFileWrapper(test_path),
schema=task.data.data_source.schema,
field_names=field_names,
)
else:
data_source = task.data.data_source
data_source = _get_data_source(test_path, field_names, task)
test_results = task.test(data_source)
else:
if not test_path:
Expand All @@ -210,6 +205,42 @@ def test_model_from_snapshot_path(
return test_results, test_out_path, metric_channels


def _get_data_source(test_path, field_names, task):
if test_path:
data_source = TSVDataSource(
test_file=SafeFileWrapper(test_path),
schema=task.data.data_source.schema,
field_names=field_names,
)
else:
data_source = task.data.data_source
return data_source


def get_logits(
    snapshot_path: str,
    use_cuda_if_available: bool,
    output_path: Optional[str] = None,
    test_path: Optional[str] = None,
    field_names: Optional[List[str]] = None,
):
    """Run a saved model over a data source and write its outputs to a file.

    The task is loaded from ``snapshot_path``. Only ``NewTask`` instances are
    handled; for any other task type this function does nothing after
    loading. Model outputs for every test batch are accumulated with
    ``MetricReporter.aggregate_data`` and each accumulated entry is written
    to ``output_path`` as one line.

    Args:
        snapshot_path: path of the saved snapshot to load.
        use_cuda_if_available: forwarded to ``_set_cuda``.
        output_path: file to write the results to (UTF-8, overwritten).
        test_path: optional data file; when falsy the task's own data source
            is used (see ``_get_data_source``).
        field_names: optional field names for ``test_path``.

    Raises:
        ValueError: if the task is a ``NewTask`` and ``output_path`` is
            empty or None.
    """
    _set_cuda(use_cuda_if_available)
    task, _ = load(snapshot_path)
    if isinstance(task, NewTask):
        # Fail before running inference rather than at open() afterwards.
        if not output_path:
            raise ValueError("output_path must be set to write logits")
        task.model.eval()
        data_source = _get_data_source(test_path, field_names, task)
        task.data.batcher = Batcher()
        batches = task.data.batches(Stage.TEST, data_source=data_source)
        results = []
        # Inference only: disable autograd so no graph is built or retained
        # across batches.
        with torch.no_grad():
            for batch in batches:
                model_inputs = task.model.arrange_model_inputs(batch)
                model_outputs = task.model(*model_inputs)
                MetricReporter.aggregate_data(results, model_outputs)
        with open(output_path, "w", encoding="utf-8") as fout:
            fout.writelines(f"{row}\n" for row in results)


def batch_predict(model_file: str, examples: List[Dict[str, Any]]):
    """Load the task saved in ``model_file`` and return ``task.predict(examples)``."""
    # load() yields a (task, config) pair; only the task is needed here.
    task, _train_config = load(model_file)
    predictions = task.predict(examples)
    return predictions

0 comments on commit e8cec67

Please sign in to comment.