From 178d9a44b3c2265961f99572c9ef7df967a7d669 Mon Sep 17 00:00:00 2001 From: Shicong Zhao Date: Wed, 3 Apr 2019 17:24:33 -0700 Subject: [PATCH] RNNG: add tok/beam size to forward params (#434) Summary: Pull Request resolved: https://github.com/facebookresearch/pytext/pull/434 1 add tok/beam size to forward params 2 change the way of getting input from using enum to incrementing index, because the enum way is wrong when dict feat is optional and contextual feat is given (input position changed) 3 this change is backward compatible, but not forward compatible (old model will work with this code but new model will not work with existing predictor) Differential Revision: D14718588 fbshipit-source-id: 2289c97840a3e77d0a8f807f2d558be9b98da710 --- pytext/data/compositional_data_handler.py | 4 ++++ pytext/models/semantic_parsers/rnng/rnng_parser.py | 12 +++--------- pytext/models/test/rnng_test.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pytext/data/compositional_data_handler.py b/pytext/data/compositional_data_handler.py index b14b18b1c..dfedeef25 100644 --- a/pytext/data/compositional_data_handler.py +++ b/pytext/data/compositional_data_handler.py @@ -146,6 +146,8 @@ def _train_input_from_batch(self, batch): if name == ACTION_FEATURE_FIELD: input = input.tolist() # Action needn't be passed as Tensor obj. m_inputs.append(input) + # beam size and topk + m_inputs.extend([1, 1]) return m_inputs def _test_input_from_batch(self, batch): @@ -156,6 +158,8 @@ def _test_input_from_batch(self, batch): getattr(batch, DatasetFieldName.DICT_FIELD, None), None, getattr(batch, DatasetFieldName.PRETRAINED_MODEL_EMBEDDING, None), + 1, + 1, ] def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/pytext/models/semantic_parsers/rnng/rnng_parser.py b/pytext/models/semantic_parsers/rnng/rnng_parser.py index d137025d1..d8d0cbb8b 100644 --- a/pytext/models/semantic_parsers/rnng/rnng_parser.py +++ b/pytext/models/semantic_parsers/rnng/rnng_parser.py @@ -91,6 +91,7 @@ class RNNGConstraints(ConfigBase): # version 0 - initial implementation # version 1 - beam search # version 2 - use zero init state rather than random + # version 3 - add beam search input params version: int = 2 lstm: BiLSTM.Config = BiLSTM.Config() ablation: AblationParams = AblationParams() @@ -125,8 +126,6 @@ def from_config(cls, model_config, feature_config, metadata: CommonMetadata): lstm_dim=model_config.lstm.lstm_dim, max_open_NT=model_config.max_open_NT, dropout=model_config.dropout, - beam_size=model_config.beam_size, - top_k=model_config.top_k, actions_vocab=metadata.actions_vocab, shift_idx=metadata.shift_idx, reduce_idx=metadata.reduce_idx, @@ -146,8 +145,6 @@ def __init__( lstm_dim: int, max_open_NT: int, dropout: float, - beam_size: int, - top_k: int, actions_vocab, shift_idx: int, reduce_idx: int, @@ -238,9 +235,6 @@ def __init__( self.valid_IN_idxs = valid_IN_idxs self.valid_SL_idxs = valid_SL_idxs - self.beam_size = beam_size - self.top_k = top_k - num_actions = len(actions_vocab) lstm_count = ablation.use_buffer + ablation.use_stack + ablation.use_action if lstm_count == 0: @@ -278,6 +272,8 @@ def forward( dict_feat: Optional[Tuple[torch.Tensor, ...]] = None, actions: Optional[List[List[int]]] = None, contextual_token_embeddings: Optional[torch.Tensor] = None, + beam_size=1, + top_k=1, ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """RNNG forward function. @@ -295,8 +291,6 @@ def forward( (batch_size, action_length) (batch_size, action_length, number_of_actions) """ - beam_size = self.beam_size - top_k = self.top_k if self.stage != Stage.TEST: beam_size = 1 diff --git a/pytext/models/test/rnng_test.py b/pytext/models/test/rnng_test.py index 2c9d52f8b..96929d1ca 100644 --- a/pytext/models/test/rnng_test.py +++ b/pytext/models/test/rnng_test.py @@ -79,8 +79,6 @@ def setUp(self): lstm_dim=20, max_open_NT=10, dropout=0.2, - beam_size=3, - top_k=3, actions_vocab=actions_vocab, shift_idx=4, reduce_idx=3, @@ -218,7 +216,9 @@ def test_forward_shapes(self): # Beam Search Test self.parser.eval(Stage.TEST) - results = self.parser(tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat) + results = self.parser( + tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat, beam_size=3, top_k=3 + ) self.assertEqual(len(results), 3) for actions, scores in results: self.assertGreater(actions.shape[1], tokens.shape[1])