Skip to content

Commit

Permalink
Add dense feature support (facebookresearch#449)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#449

Pull Request resolved: facebookresearch#444

Add dense features to decoder in Contextual Intent Slot model.

Reviewed By: gardenia22

Differential Revision: D14754861

fbshipit-source-id: de7e8c946d533bf9d96a11428772646dbac45a4e
  • Loading branch information
Derek Liu authored and facebook-github-bot committed Apr 5, 2019
1 parent 6abb019 commit 2704c1c
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 8 deletions.
3 changes: 3 additions & 0 deletions pytext/config/contextual_intent_slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .field_config import (
CharFeatConfig,
DictFeatConfig,
FloatVectorConfig,
PretrainedModelEmbeddingConfig,
TargetConfigBase,
WordFeatConfig,
Expand All @@ -18,6 +19,7 @@ class ModelInputConfig(ModuleConfig):
char_feat: Optional[CharFeatConfig] = None
pretrained_model_embedding: Optional[PretrainedModelEmbeddingConfig] = None
seq_word_feat: Optional[WordFeatConfig] = WordFeatConfig()
dense_feat: Optional[FloatVectorConfig] = None


TargetConfig = List[TargetConfigBase]
Expand All @@ -29,6 +31,7 @@ class ModelInput:
CHAR = "char_feat"
PRETRAINED = "pretrained_model_embedding"
SEQ = "seq_word_feat"
DENSE = "dense_feat"


class ExtraField:
Expand Down
15 changes: 13 additions & 2 deletions pytext/data/contextual_intent_slot_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DocLabelField,
Field,
FloatField,
FloatVectorField,
PretrainedModelEmbeddingField,
RawField,
SeqFeatureField,
Expand All @@ -37,6 +38,7 @@ class RawData:
DICT_FEAT = "dict_feat"
DOC_WEIGHT = "doc_weight"
WORD_WEIGHT = "word_weight"
DENSE_FEAT = "dense_feat"


class ContextualIntentSlotModelDataHandler(JointModelDataHandler):
Expand Down Expand Up @@ -112,6 +114,7 @@ def from_config(
ModelInput.CHAR: CharFeatureField,
ModelInput.PRETRAINED: PretrainedModelEmbeddingField,
ModelInput.SEQ: SeqFeatureField,
ModelInput.DENSE: FloatVectorField,
},
)

Expand Down Expand Up @@ -218,6 +221,10 @@ def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
ExtraField.UTTERANCE: row_data[RawData.TEXT],
ExtraField.TOKEN_RANGE: features_list[-1].token_ranges,
}

if RawData.DENSE_FEAT in row_data:
res[ModelInput.DENSE] = row_data.get(RawData.DENSE_FEAT)

if WordLabelConfig._name in self.labels:
# TODO move it into word label field
res[WordLabelConfig._name] = data.align_slot_labels(
Expand All @@ -230,16 +237,20 @@ def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
def _train_input_from_batch(self, batch):
text_input = getattr(batch, ModelInput.TEXT)
seq_input = getattr(batch, ModelInput.SEQ)
return (
result = (
# text_input[0] contains the word embeddings,
# text_input[1] contains the lengths of each word
text_input[0],
*(
getattr(batch, key)
for key in self.features
if key not in [ModelInput.TEXT, ModelInput.SEQ]
if key not in [ModelInput.TEXT, ModelInput.SEQ, ModelInput.DENSE]
),
seq_input[0],
text_input[1],
seq_input[1],
)
# Append dense feature to decoder layer at the end.
if ModelInput.DENSE in self.features:
result = result + (getattr(batch, ModelInput.DENSE),)
return result
22 changes: 21 additions & 1 deletion pytext/data/test/contextual_intent_slot_data_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

from pytext.common.constants import DFColumn
from pytext.config.contextual_intent_slot import ModelInputConfig, TargetConfig
from pytext.config.contextual_intent_slot import ModelInput, ModelInputConfig
from pytext.config.field_config import DocLabelConfig, WordLabelConfig
from pytext.data import ContextualIntentSlotModelDataHandler
from pytext.data.featurizer import SimpleFeaturizer
Expand Down Expand Up @@ -56,3 +56,23 @@ def test_intermediate_result(self):
self.assertEqual(data.examples[0].raw_word_label, "")
self.assertListEqual(data.examples[0].token_range, [(0, 4), (5, 9), (10, 14)])
self.assertEqual(data.examples[0].utterance, '["Hey", "Youd love this"]')


class ContextualIntentSlotModelDataHandlerDenseTest(unittest.TestCase):
    """Reads the tiny dense-feature TSV fixture through the data handler."""

    def test_read_file_with_dense_features(self):
        """The first raw row exposes the dense column as its raw string."""
        config = ContextualIntentSlotModelDataHandler.Config()
        # Add the dense feature column to the set of columns read from file.
        config.columns_to_read.append(ModelInput.DENSE)
        fixture_path = tests_module.test_file(
            "contextual_intent_slot_train_tiny_dense.tsv"
        )

        feature_config = ModelInputConfig()
        label_configs = [DocLabelConfig(), WordLabelConfig()]
        featurizer = SimpleFeaturizer(SimpleFeaturizer.Config(), ModelInputConfig())
        handler = ContextualIntentSlotModelDataHandler.from_config(
            config, feature_config, label_configs, featurizer=featurizer
        )

        rows = list(handler.read_from_file(fixture_path, handler.raw_columns))
        first_row = rows[0]
        self.assertEqual(first_row[ModelInput.DENSE], "[0,1,2,3,4]")
17 changes: 14 additions & 3 deletions pytext/models/decoders/intent_slot_model_decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List
from typing import List, Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -67,15 +67,26 @@ def __init__(

self.word_decoder = nn.Linear(in_dim_word, out_dim_word)

def forward(self, x_d: torch.Tensor, x_w: torch.Tensor) -> List[torch.Tensor]:
logit_d = self.doc_decoder(x_d)
def forward(
    self, x_d: torch.Tensor, x_w: torch.Tensor, dense: Optional[torch.Tensor] = None
) -> List[torch.Tensor]:
    """Compute document-level and word-level logits.

    Args:
        x_d: Document representation fed to ``self.doc_decoder``.
        x_w: Word representation fed to ``self.word_decoder``; its dim 1 is
            used as the repeat count for per-document tensors (presumably
            the sequence length -- confirm against the caller).
        dense: Optional dense features. When given, they are concatenated
            to ``x_d`` along dim 1 and, tiled along dim 1, to ``x_w`` along
            dim 2.

    Returns:
        A two-element list ``[doc_logits, word_logits]``.
    """
    has_dense = dense is not None

    # Document branch: optionally widen the input with the dense features.
    doc_input = torch.cat((x_d, dense), 1) if has_dense else x_d
    logit_d = self.doc_decoder(doc_input)

    # Word branch: gather every per-document tensor that must be tiled
    # across dim 1 of x_w, then concatenate once along the feature dim.
    seq_len = x_w.size()[1]
    word_parts = [x_w]
    if self.use_doc_probs_in_word:
        # Document probability distribution, repeated for every word.
        doc_prob = F.softmax(logit_d, 1)
        word_parts.append(doc_prob.unsqueeze(1).repeat(1, seq_len, 1))
    if has_dense:
        word_parts.append(dense.unsqueeze(1).repeat(1, seq_len, 1))

    word_input = x_w if len(word_parts) == 1 else torch.cat(word_parts, 2)
    return [logit_d, self.word_decoder(word_input)]

def get_decoder(self) -> List[nn.Module]:
Expand Down
13 changes: 11 additions & 2 deletions pytext/models/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Union

from pytext.config import ConfigBase
from pytext.config.contextual_intent_slot import ModelInput
from pytext.data import CommonMetadata
from pytext.models.model import Model
from pytext.models.module import create_module
Expand Down Expand Up @@ -39,14 +40,22 @@ def from_config(cls, model_config, feat_config, metadata: CommonMetadata):
representation = create_module(
model_config.representation, embed_dim=embedding.embedding_dim
)
dense_feat_dim = 0
for decoder_feat in (ModelInput.DENSE,): # Only 1 right now.
if getattr(feat_config, decoder_feat, False):
dense_feat_dim = getattr(feat_config, ModelInput.DENSE).dim

doc_label_meta, word_label_meta = metadata.target
decoder = create_module(
model_config.decoder,
in_dim_doc=representation.doc_representation_dim,
in_dim_word=representation.word_representation_dim,
in_dim_doc=representation.doc_representation_dim + dense_feat_dim,
in_dim_word=representation.word_representation_dim + dense_feat_dim,
out_dim_doc=doc_label_meta.vocab_size,
out_dim_word=word_label_meta.vocab_size,
)

if dense_feat_dim > 0:
decoder.num_decoder_modules = 1
output_layer = create_module(
model_config.output_layer, doc_label_meta, word_label_meta
)
Expand Down
10 changes: 10 additions & 0 deletions tests/data/contextual_intent_slot_train_tiny_dense.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
cu:other ["Hey", "Youd love this"] {"tokenFeatList": [{"tokenIdx": 1, "features": {"b_ozlo_ner_category:content_reaction": 0.12}}]} 0.2 0.5 [0,1,2,3,4]
cu:address_Person 0:4:person ["this is crazy still to me"] {"tokenFeatList": [{"tokenIdx": 5, "features": {"b_uce_contacts": 1.0, "uce_contacts": 1.0}}]} 0.2 0.5 [0,1,2,3,4]
cu:other ["Whats up. How are you doing"] 0.2 0.5 [0,1,2,3,4]
cu:other ["Hey", "Hey"] 0.2 0.5 [0,1,2,3,4]
cu:other ["Dinner tonight?", "yup"] 0.2 0.5 [0,1,2,3,4]
cu:other ["wanna hangout?", "maybe"] 0.2 0.5 [0,1,2,3,4]
cu:other ["wya?", "home"] 0.2 0.5 [0,1,2,3,4]
cu:other ["tommi sushi again?", "why not"] {"tokenFeatList": [{"tokenIdx": 0, "features": {"b_ozlo_ner_category:poi": 0.9}}, {"tokenIdx": 1, "features": {"b_ozlo_ner_category:poi": 1.0}}]} 0.2 0.5 [0,1,2,3,4]
cu:other ["going out?"] 0.2 0.5 [0,1,2,3,4]
cu:other ["I like this!!"] 0.2 0.5 [0,1,2,3,4]

0 comments on commit 2704c1c

Please sign in to comment.