diff --git a/pytext/exporters/__init__.py b/pytext/exporters/__init__.py index 6845f680f..d6554ad31 100644 --- a/pytext/exporters/__init__.py +++ b/pytext/exporters/__init__.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pytext.exporters.custom_exporters import DenseFeatureExporter from pytext.exporters.exporter import ModelExporter -__all__ = ["ModelExporter"] +__all__ = ["ModelExporter", "DenseFeatureExporter"] diff --git a/pytext/exporters/custom_exporters.py b/pytext/exporters/custom_exporters.py new file mode 100644 index 000000000..b7895c2b1 --- /dev/null +++ b/pytext/exporters/custom_exporters.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from typing import Dict + +from pytext.config import ConfigBase +from pytext.config.field_config import FeatureConfig, FloatVectorConfig +from pytext.exporters.exporter import ModelExporter +from pytext.fields import FieldMeta + + +class DenseFeatureExporter(ModelExporter): + """ + Exporter for models that have DenseFeatures as input to the decoder + """ + + @classmethod + def get_feature_metadata( + cls, feature_config: FeatureConfig, feature_meta: Dict[str, FieldMeta] + ): + # add all features EXCEPT dense features. The features exported here + # go through the representation layer + ( + input_names_rep, + dummy_model_input_rep, + feature_itos_map_rep, + ) = cls._get_exportable_metadata( + lambda x: isinstance(x, ConfigBase) + and not isinstance(x, FloatVectorConfig), + feature_config, + feature_meta, + ) + + # need feature lengths only for non-dense features + cls._add_feature_lengths(input_names_rep, dummy_model_input_rep) + + # add dense features. These features don't go through the representation + # layer, instead they go directly to the decoder + ( + input_names_dense, + dummy_model_input_dense, + feature_itos_map_dense, + ) = cls._get_exportable_metadata( + lambda x: isinstance(x, FloatVectorConfig), feature_config, feature_meta + ) + + feature_itos_map_rep.update(feature_itos_map_dense) + return ( + input_names_rep + input_names_dense, + tuple(dummy_model_input_rep + dummy_model_input_dense), + feature_itos_map_rep, + ) diff --git a/pytext/exporters/exporter.py b/pytext/exporters/exporter.py index 4a6e4d94d..ad56c69e5 100644 --- a/pytext/exporters/exporter.py +++ b/pytext/exporters/exporter.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import torch from caffe2.python import core @@ -73,29 +73,14 @@ def get_feature_metadata( ): # The number of names in input_names *must* be equal to the number of # tensors passed in dummy_input - input_names: List[str] = [] - dummy_model_input: List = [] - feature_itos_map = {} - - for name, feat_config in feature_config._asdict().items(): - if isinstance(feat_config, ConfigBase): - input_names.extend(feat_config.export_input_names) - if getattr(feature_meta[name], "vocab", None): - feature_itos_map[feat_config.export_input_names[0]] = feature_meta[ - name - ].vocab.itos - dummy_model_input.append(feature_meta[name].dummy_model_input) - - if "tokens_vals" in input_names: - dummy_model_input.append( - torch.tensor([1, 1], dtype=torch.long) - ) # token lengths - input_names.append("tokens_lens") - if "seq_tokens_vals" in input_names: - dummy_model_input.append( - torch.tensor([1, 1], dtype=torch.long) - ) # seq lengths - input_names.append("seq_tokens_lens") + ( + input_names, + dummy_model_input, + feature_itos_map, + ) = cls._get_exportable_metadata( + lambda x: isinstance(x, ConfigBase), feature_config, feature_meta + ) + cls._add_feature_lengths(input_names, dummy_model_input) return input_names, tuple(dummy_model_input), feature_itos_map def __init__(self, config, input_names, dummy_model_input, vocab_map, output_names): @@ -233,3 +218,43 @@ def export_to_metrics(self, model, metric_channels): for mc in metric_channels or []: mc.export(model, self.dummy_model_input) + + @classmethod + def _get_exportable_metadata( + cls, + exportable_filter: Callable, + feature_config: FeatureConfig, + feature_meta: Dict[str, FieldMeta], + ) -> Tuple[List[str], List, Dict]: + # The number of names in input_names *must* be equal to the number of + # tensors passed in dummy_input + input_names: List[str] = [] + dummy_model_input: List = [] + feature_itos_map = {} + + for name, feat_config in feature_config._asdict().items(): + if exportable_filter(feat_config): + input_names.extend(feat_config.export_input_names) + if getattr(feature_meta[name], "vocab", None): + feature_itos_map[feat_config.export_input_names[0]] = feature_meta[ + name + ].vocab.itos + dummy_model_input.append(feature_meta[name].dummy_model_input) + return input_names, dummy_model_input, feature_itos_map + + @classmethod + def _add_feature_lengths(cls, input_names: List[str], dummy_model_input: List): + """If any of the input_names have tokens or seq_tokens, add the length + of those tokens to dummy_input + """ + + if "tokens_vals" in input_names: + dummy_model_input.append( + torch.tensor([1, 1], dtype=torch.long) + ) # token lengths + input_names.append("tokens_lens") + if "seq_tokens_vals" in input_names: + dummy_model_input.append( + torch.tensor([1, 1], dtype=torch.long) + ) # seq lengths + input_names.append("seq_tokens_lens") diff --git a/pytext/exporters/test/text_model_exporter_test.py b/pytext/exporters/test/text_model_exporter_test.py index cfbeccb5a..5f492f85f 100644 --- a/pytext/exporters/test/text_model_exporter_test.py +++ b/pytext/exporters/test/text_model_exporter_test.py @@ -28,6 +28,7 @@ CharFeatureField, DictFeatureField, FieldMeta, + FloatVectorField, SeqFeatureField, TextFeatureField, ) @@ -96,6 +97,41 @@ DOC_CONFIGS = [ """ +{ + "model": { + "representation": { + "DocNNRepresentation": {} + }, + "output_layer": { + "loss": { + "CrossEntropyLoss": {} + } + } + }, + "features": { + "word_feat": {}, + "dict_feat": {}, + "char_feat": { + "embed_dim": 5, + "cnn": { + "kernel_num": 2, + "kernel_sizes": [2, 3] + } + }, + "dense_feat": { + "dim":10 + } + }, + "featurizer": { + "SimpleFeaturizer": {} + }, + "trainer": { + "epochs": 1 + }, + "exporter": {} +} +""", + """ { "model": { "representation": { @@ -292,6 +328,9 @@ # Handle different batch_sizes BATCH_SIZE = 1 +# Fixed dimension of dense_features since it needs to be specified in config +DENSE_FEATURE_DIM = 10 + class ModelExporterTest(hu.HypothesisTestCase): @given( @@ -685,9 +724,11 @@ def _get_metadata(self, num_doc_classes, num_word_classes): w_vocab = Vocab(Counter()) dict_vocab = Vocab(Counter()) c_vocab = Vocab(Counter()) + d_vocab = Vocab(Counter()) w_vocab.itos = W_VOCAB dict_vocab.itos = DICT_VOCAB c_vocab.itos = CHAR_VOCAB + d_vocab.itos = [] text_feat_meta = FieldMeta() text_feat_meta.unk_token_idx = UNK_IDX @@ -712,11 +753,24 @@ def _get_metadata(self, num_doc_classes, num_word_classes): char_feat_meta.pretrained_embeds_weight = None char_feat_meta.dummy_model_input = CharFeatureField.dummy_model_input + dense_feat_meta = FieldMeta() + dense_feat_meta.vocab_size = 0 + dense_feat_meta.vocab = d_vocab + dense_feat_meta.vocab_export_name = "dense_vals" + dense_feat_meta.pretrained_embeds_weight = None + # ugh, dims are fixed + dense_feat_meta.dummy_model_input = torch.tensor( + [[1.0] * DENSE_FEATURE_DIM, [1.0] * DENSE_FEATURE_DIM], + dtype=torch.float, + device="cpu", + ) + meta = CommonMetadata() meta.features = { DatasetFieldName.TEXT_FIELD: text_feat_meta, DatasetFieldName.DICT_FIELD: dict_feat_meta, DatasetFieldName.CHAR_FIELD: char_feat_meta, + DatasetFieldName.DENSE_FIELD: dense_feat_meta, } meta.target = labels if len(labels) == 1: @@ -788,6 +842,9 @@ def _get_rand_input( c_vocab_size, size=(batch_size, num_words, num_chars) ).astype(np.int64) ) + dense_features = torch.from_numpy( + np.random.rand(batch_size, DENSE_FEATURE_DIM).astype(np.float32) + ) inputs = [] if features.word_feat: inputs.append(text) @@ -796,6 +853,8 @@ def _get_rand_input( if features.char_feat: inputs.append(chars) inputs.append(lengths) + if features.dense_feat: + inputs.append(dense_features) return tuple(inputs) def _get_config(self, cls, config_str): diff --git a/pytext/fields/field.py b/pytext/fields/field.py index 586eb9718..dc9611dd4 100644 --- a/pytext/fields/field.py +++ b/pytext/fields/field.py @@ -312,7 +312,7 @@ def __init__(self, dim=0, dim_error_check=False, **kwargs): ) self.dim_error_check = dim_error_check # dims in data should match config self.dummy_model_input = torch.tensor( - [[1.0] * dim], dtype=torch.float, device="cpu" + [[1.0] * dim, [1.0] * dim], dtype=torch.float, device="cpu" ) def _parse_vector(self, s): diff --git a/pytext/task/tasks.py b/pytext/task/tasks.py index 2b8a76d21..15f0cd35f 100644 --- a/pytext/task/tasks.py +++ b/pytext/task/tasks.py @@ -21,6 +21,7 @@ QueryDocumentPairwiseRankingDataHandler, SeqModelDataHandler, ) +from pytext.exporters import DenseFeatureExporter from pytext.metric_reporters import ( ClassificationMetricReporter, CompositionalMetricReporter, @@ -115,6 +116,7 @@ class Config(Task.Config): metric_reporter: ClassificationMetricReporter.Config = ( ClassificationMetricReporter.Config() ) + exporter: Optional[DenseFeatureExporter.Config] = None @classmethod def format_prediction(cls, predictions, scores, context, target_meta):