Skip to content

Commit

Permalink
Make dense feature optional for bytes model (facebookresearch#429)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#429

^title

Reviewed By: ppuliu

Differential Revision: D14678130

fbshipit-source-id: 297d61cff25175bd51bfdf118abe0d22e7b70e25
  • Loading branch information
seayoung1112 authored and facebook-github-bot committed Mar 29, 2019
1 parent 962f01e commit aa46761
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pytext/models/doc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,29 @@ class ModelInput(Model.Config.ModelInput):
inputs: ModelInput = ModelInput()
embedding: WordEmbedding.Config = WordEmbedding.Config()

input_names = ["tokens", "tokens_lens"]
output_names = ["scores"]

def arrange_model_inputs(self, tensor_dict):
tokens, seq_lens = tensor_dict["tokens"]
return (tokens, seq_lens)

def arrange_targets(self, tensor_dict):
return tensor_dict["labels"]

def get_export_input_names(self, tensorizers):
return ["tokens", "tokens_lens"]

def get_export_output_names(self, tensorizers):
return ["scores"]

def vocab_to_export(self, tensorizers):
return {"tokens": list(tensorizers["tokens"].vocab)}

def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
exporter = ModelExporter(
ModelExporter.Config(),
self.input_names,
self.get_export_input_names(tensorizers),
self.arrange_model_inputs(tensor_dict),
self.vocab_to_export(tensorizers),
self.output_names,
self.get_export_output_names(tensorizers),
)
return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)

Expand Down
3 changes: 3 additions & 0 deletions pytext/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytext.config.doc_classification import ModelInput
from pytext.config.field_config import FeatureConfig
from pytext.config.pytext_config import ConfigBase, ConfigBaseMeta
from pytext.config.serialize import _is_optional
from pytext.data import CommonMetadata
from pytext.data.tensorizers import Tensorizer
from pytext.models.module import create_module
Expand All @@ -25,6 +26,8 @@ class ModelInputMeta(ConfigBaseMeta):
def __new__(metacls, typename, bases, namespace):
annotations = namespace.get("__annotations__", {})
for type in annotations.values():
if _is_optional(type):
type = type.__args__[0]
if not issubclass(type, Tensorizer.Config):
raise TypeError(
"ModelInput configuration should only include tensorizers"
Expand Down
1 change: 1 addition & 0 deletions pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def from_config(cls, config: Config, unused_metadata=None, model_state=None):
tensorizers = {
name: create_component(ComponentType.TENSORIZER, tensorizer)
for name, tensorizer in config.model.inputs._asdict().items()
if tensorizer
}
schema: Dict[str, Type] = {}
for tensorizer in tensorizers.values():
Expand Down

0 comments on commit aa46761

Please sign in to comment.