RNNG: add top_k/beam size to forward params (#434)
Summary:
Pull Request resolved: #434

1. Add top_k/beam size to the forward parameters.
2. Change how forward inputs are read, from fixed enum positions to an incrementing index: the enum approach is wrong when the dict feature is optional and a contextual feature is given, because the input positions shift (see the sketch after this list).
3. This change is backward compatible but not forward compatible: an old model will work with this code, but a new model will not work with the existing predictor.
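
To illustrate point 2, here is a minimal sketch of the two input-unpacking strategies. All names are hypothetical stand-ins, not the actual PyText code; it only shows why fixed enum positions break when an optional input is omitted, while an incrementing index stays correct.

from typing import Any, List, Optional, Tuple


def unpack_fixed(inputs: List[Any]) -> Tuple[Any, Any, Any, Any]:
    # Enum-style fixed slots: if dict_feat is absent from the list,
    # contextual_feat silently lands in dict_feat's position (or the
    # read runs past the end of the list).
    TOKENS, SEQ_LENS, DICT_FEAT, CONTEXTUAL_FEAT = 0, 1, 2, 3
    return (
        inputs[TOKENS],
        inputs[SEQ_LENS],
        inputs[DICT_FEAT],
        inputs[CONTEXTUAL_FEAT],
    )


def unpack_incremental(
    inputs: List[Any], has_dict_feat: bool, has_contextual_feat: bool
) -> Tuple[Any, Any, Optional[Any], Optional[Any]]:
    # Incrementing index: consume only the inputs that are present, so
    # every feature combination maps to the right variables.
    idx = 0
    tokens, idx = inputs[idx], idx + 1
    seq_lens, idx = inputs[idx], idx + 1
    dict_feat = None
    if has_dict_feat:
        dict_feat, idx = inputs[idx], idx + 1
    contextual_feat = None
    if has_contextual_feat:
        contextual_feat, idx = inputs[idx], idx + 1
    return tokens, seq_lens, dict_feat, contextual_feat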

Differential Revision: D14718588

fbshipit-source-id: 2289c97840a3e77d0a8f807f2d558be9b98da710
seayoung1112 authored and facebook-github-bot committed Apr 4, 2019
1 parent 944cc11 commit 178d9a4
Showing 3 changed files with 10 additions and 12 deletions.
4 changes: 4 additions & 0 deletions pytext/data/compositional_data_handler.py
@@ -146,6 +146,8 @@ def _train_input_from_batch(self, batch):
             if name == ACTION_FEATURE_FIELD:
                 input = input.tolist()  # Action needn't be passed as Tensor obj.
             m_inputs.append(input)
+        # beam size and topk
+        m_inputs.extend([1, 1])
         return m_inputs
 
     def _test_input_from_batch(self, batch):
@@ -156,6 +158,8 @@ def _test_input_from_batch(self, batch):
             getattr(batch, DatasetFieldName.DICT_FIELD, None),
             None,
             getattr(batch, DatasetFieldName.PRETRAINED_MODEL_EMBEDDING, None),
+            1,
+            1,
         ]
 
     def preprocess_row(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
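Note on the hunks above: the two appended 1s line up positionally with the new beam_size and top_k forward parameters, so training-time decoding stays greedy. A simplified sketch of the effect (placeholder strings stand in for the real tensors):

m_inputs = ["tokens", "seq_lens", "actions"]  # hypothetical placeholders
m_inputs.extend([1, 1])  # beam size and topk, as in the diff above
# model(*m_inputs) then reaches forward(..., beam_size=1, top_k=1)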
12 changes: 3 additions & 9 deletions pytext/models/semantic_parsers/rnng/rnng_parser.py
@@ -91,6 +91,7 @@ class RNNGConstraints(ConfigBase):
         # version 0 - initial implementation
         # version 1 - beam search
         # version 2 - use zero init state rather than random
+        # version 3 - add beam search input params
         version: int = 2
         lstm: BiLSTM.Config = BiLSTM.Config()
         ablation: AblationParams = AblationParams()
@@ -125,8 +126,6 @@ def from_config(cls, model_config, feature_config, metadata: CommonMetadata):
             lstm_dim=model_config.lstm.lstm_dim,
             max_open_NT=model_config.max_open_NT,
             dropout=model_config.dropout,
-            beam_size=model_config.beam_size,
-            top_k=model_config.top_k,
             actions_vocab=metadata.actions_vocab,
             shift_idx=metadata.shift_idx,
             reduce_idx=metadata.reduce_idx,
@@ -146,8 +145,6 @@ def __init__(
         lstm_dim: int,
         max_open_NT: int,
         dropout: float,
-        beam_size: int,
-        top_k: int,
         actions_vocab,
         shift_idx: int,
         reduce_idx: int,
@@ -238,9 +235,6 @@ def __init__(
         self.valid_IN_idxs = valid_IN_idxs
         self.valid_SL_idxs = valid_SL_idxs
 
-        self.beam_size = beam_size
-        self.top_k = top_k
-
         num_actions = len(actions_vocab)
         lstm_count = ablation.use_buffer + ablation.use_stack + ablation.use_action
         if lstm_count == 0:
@@ -278,6 +272,8 @@ def forward(
         dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
         actions: Optional[List[List[int]]] = None,
         contextual_token_embeddings: Optional[torch.Tensor] = None,
+        beam_size=1,
+        top_k=1,
     ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
         """RNNG forward function.
@@ -295,8 +291,6 @@
                 (batch_size, action_length)
                 (batch_size, action_length, number_of_actions)
         """
-        beam_size = self.beam_size
-        top_k = self.top_k
 
         if self.stage != Stage.TEST:
             beam_size = 1
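The backward-compatibility claim in point 3 follows from the defaults added above: beam_size and top_k default to 1, so call sites that predate this change keep working. A self-contained sketch of that contract, not the real class (the stage argument stands in for the parser's self.stage and the Stage enum):

def forward(tokens, seq_lens, dict_feat=None, actions=None,
            contextual_token_embeddings=None, beam_size=1, top_k=1,
            stage="TRAIN"):
    # Mirrors the diff: any non-TEST stage forces greedy decoding.
    if stage != "TEST":
        beam_size = 1
    return beam_size, top_k


assert forward("toks", "lens") == (1, 1)  # old call site: behavior unchanged
assert forward("toks", "lens", beam_size=3, top_k=3, stage="TEST") == (3, 3)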
6 changes: 3 additions & 3 deletions pytext/models/test/rnng_test.py
@@ -79,8 +79,6 @@ def setUp(self):
             lstm_dim=20,
             max_open_NT=10,
             dropout=0.2,
-            beam_size=3,
-            top_k=3,
             actions_vocab=actions_vocab,
             shift_idx=4,
             reduce_idx=3,
@@ -218,7 +216,9 @@ def test_forward_shapes(self):

        # Beam Search Test
         self.parser.eval(Stage.TEST)
-        results = self.parser(tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat)
+        results = self.parser(
+            tokens=tokens, seq_lens=seq_lens, dict_feat=dict_feat, beam_size=3, top_k=3
+        )
         self.assertEqual(len(results), 3)
         for actions, scores in results:
             self.assertGreater(actions.shape[1], tokens.shape[1])
